train.py, dataset.py를 주로 들어가봤겠지만
model.py는 수정하면 안된다고 해서 별로 들어갈 일이 없었을 것이다.
하지만 여기 수정하면 훈련시간 10배 줄어든다.
train.py 수정부분을 잠깐 보자. 원래는
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = EAST()
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[max_epoch // 3, max_epoch//3*2], gamma=0.01)
for epoch in range(max_epoch):
model.train()
epoch_loss, epoch_start, _epoch_loss = 0, time.time(), 0
with tqdm(total=num_batches) as pbar:
.
.
.
이렇게 EAST모델을 불러오고, optimizer, scheduler 선언하고 첫 epoch로 돌입한다.
그리고 model.py에서 pth를 검색해보면
vggnet을 받아오는 이 부분만 pth를 받아오게 돼있다.
이번 대회에서 사용하는 모델은 EAST모델이고 VGG는 일부분에 해당하므로 훈련 이후에 생성된 pth파일은 EAST모델에 해당하는 것이지, VGG에 해당하는 것이 아니다.
고로, EAST에서 load_state_dict을 받아오는 코드를 추가해줘야 한다.
한편, pth는 앞으로 훈련을 하면서 계속 갱신되고 바꿔줄 데이터이기 때문에 model.py가 아니라 train.py에서 쉽게 수정할 수 있게끔 코드를 수정해보자.
1. train.py
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model = EAST() # 이걸 수정하여 새로운 파라미터를 받아오게 한다.
model = EAST(pretrained=False, saved_model_path='./save_pth/_extended_experiment/5.pth')
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[max_epoch // 3, max_epoch//3*2], gamma=0.01)
for epoch in range(max_epoch):
model.train()
epoch_loss, epoch_start, _epoch_loss = 0, time.time(), 0
with tqdm(total=num_batches) as pbar:
여기서 pretrained = False는
EAST 클래스에서 True가 defalut 값이고, vgg의 pth파일을 받아오는 조건이 pretrain=true이기 때문에, false로 하여 pth를 2개 불러오는 일이 없도록 하였다.
2. model.py
class EAST(nn.Module):
# def __init__(self, pretrained=True): # 원래는 이렇게 설정돼있지만,
def __init__(self, pretrained=True, saved_model_path=None): # 파라미터를 받아오며
super(EAST, self).__init__()
self.extractor = Extractor(pretrained)
self.merge = Merge()
self.output = Output()
self.criterion = EASTLoss()
if saved_model_path is not None: # 새로운 조건의 추가로
self.load_state_dict(torch.load(saved_model_path)) # pth를 load해준다
def forward(self, x):
return self.output(self.merge(self.extractor(x)))
이렇게 하면 wandb 그래프의 양상이 확 달라지는 것을 볼 수 있을 것이다.
초록색은 baseline을 augmentation 없이 50epoch 수행했다.
파란색은 강한 augmentation을 입힌 뒤 수렴하지 못한 채 진동하여 35epoch에서 pth를 저장했다.
빨간색은 파란색의 pth를 이어받아 baseline을 augmentation 없이 50epoch 수행했다. 초록색과 같은 환경이지만 pretrain의 힘을 알 수 있는 대목이다.
분홍색은 빨간색의 pth를 이어받아 약한 augmentation을 수행한 그래프이다.
실제로 위 네 가지 훈련데이터를 test set에서 box를 출력해보면 파랑, 초록, 빨강, 분홍 순서대로 안정적인 box를 보여준다.
(파랑, 빨강, 분홍은 원본 100장에 추가 데이터200장을 추가하여 훈련했다.)
댓글