본문 바로가기
카테고리 없음

level 2 Data-Centric, 4. 이어서 학습하기, (pth file 사용법)

by yooom 2024. 2. 5.

train.py, dataset.py를 주로 들어가봤겠지만

model.py는 수정하면 안된다고 해서 별로 들어갈 일이 없었을 것이다.

하지만 여기 수정하면 훈련시간 10배 줄어든다.

 

train.py 수정부분을 잠깐 보자. 원래는

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = EAST()
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[max_epoch // 3, max_epoch//3*2], gamma=0.01)
for epoch in range(max_epoch):
        model.train()
        epoch_loss, epoch_start, _epoch_loss = 0, time.time(), 0
        with tqdm(total=num_batches) as pbar:
        .
        .
        .

이렇게 EAST모델을 불러오고, optimizer, scheduler 선언하고 첫 epoch로 돌입한다.

그리고 model.py에서 pth를 검색해보면

vggnet을 받아오는 이 부분만 pth를 받아오게 돼있다.

 

이번 대회에서 사용하는 모델은 EAST모델이고 VGG는 일부분에 해당하므로 훈련 이후에 생성된 pth파일은 EAST모델에 해당하는 것이지, VGG에 해당하는 것이 아니다.

고로, EAST에서 load_state_dict을 받아오는 코드를 추가해줘야 한다.

 

한편, pth는 앞으로 훈련을 하면서 계속 갱신되고 바꿔줄 데이터이기 때문에 model.py가 아니라 train.py에서 쉽게 수정할 수 있게끔 코드를 수정해보자.

 

1. train.py

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model = EAST() # 이걸 수정하여 새로운 파라미터를 받아오게 한다.
model = EAST(pretrained=False, saved_model_path='./save_pth/_extended_experiment/5.pth')
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[max_epoch // 3, max_epoch//3*2], gamma=0.01)

for epoch in range(max_epoch):
    model.train()
    epoch_loss, epoch_start, _epoch_loss = 0, time.time(), 0
    with tqdm(total=num_batches) as pbar:

 여기서 pretrained = False는

EAST 클래스에서 True가 defalut 값이고, vgg의 pth파일을 받아오는 조건이 pretrain=true이기 때문에, false로 하여 pth를 2개 불러오는 일이 없도록 하였다.

 

2. model.py

class EAST(nn.Module):
    # def __init__(self, pretrained=True): # 원래는 이렇게 설정돼있지만,
    def __init__(self, pretrained=True, saved_model_path=None): # 파라미터를 받아오며
        super(EAST, self).__init__()
        self.extractor = Extractor(pretrained)
        self.merge = Merge()
        self.output = Output()
        self.criterion = EASTLoss()
        if saved_model_path is not None: # 새로운 조건의 추가로
            self.load_state_dict(torch.load(saved_model_path)) # pth를 load해준다
 
    def forward(self, x):
        return self.output(self.merge(self.extractor(x)))

 

이렇게 하면 wandb 그래프의 양상이 확 달라지는 것을 볼 수 있을 것이다.

 

초록색은 baseline을 augmentation 없이 50epoch 수행했다.

파란색은 강한 augmentation을 입힌 뒤 수렴하지 못한 채 진동하여 35epoch에서 pth를 저장했다.

빨간색은 파란색의 pth를 이어받아 baseline을 augmentation 없이 50epoch 수행했다. 초록색과 같은 환경이지만 pretrain의 힘을 알 수 있는 대목이다.

분홍색은 빨간색의 pth를 이어받아 약한 augmentation을 수행한 그래프이다.

 

실제로 위 네 가지 훈련데이터를 test set에서 box를 출력해보면 파랑, 초록, 빨강, 분홍 순서대로 안정적인 box를 보여준다.

(파랑, 빨강, 분홍은 원본 100장에 추가 데이터200장을 추가하여 훈련했다.)

728x90

댓글