# Source listing metadata: 152 lines, 5.5 KiB, Python
|
import os
|
|||
|
|
|||
|
import torch
|
|||
|
from tqdm import tqdm
|
|||
|
|
|||
|
from utils.utils import get_lr
|
|||
|
|
|||
|
|
|||
|
def _move_batch_to_gpu(images, targets, local_rank):
    """Copy one batch onto GPU ``local_rank``.

    ``targets`` is a Python list of tensors; each tensor is moved individually
    and a new Python list is returned.
    """
    images = images.cuda(local_rank)
    targets = [ann.cuda(local_rank) for ann in targets]
    return images, targets


def _compute_total_loss(yolo_loss, outputs, targets):
    """Return the sum of ``yolo_loss(l, outputs[l], targets)`` over every output head ``l``.

    Per the original author's note, the model produces three outputs at
    different feature-map resolutions, each scored separately.
    """
    return sum(yolo_loss(l, output, targets) for l, output in enumerate(outputs))


def fit_one_epoch(model_train, model, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step,
                  epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
    """Run one training epoch followed by one validation pass, then log and save on rank 0.

    Args:
        model_train: Module used for forward passes. It is switched to train
            mode for training and eval mode for validation, and it is passed to
            ``eval_callback``. (Presumably a possibly-wrapped version of
            ``model`` -- confirm against the caller.)
        model: Module whose ``state_dict()`` is written to disk.
        yolo_loss: Callable ``yolo_loss(l, output, targets)`` returning a scalar
            loss tensor for output head ``l``.
        loss_history: Object with ``append_loss(epoch, train_loss, val_loss)``
            and a ``val_loss`` sequence. Only used when ``local_rank == 0``.
        eval_callback: Object with ``on_epoch_end(epoch, model)``. Only used
            when ``local_rank == 0``.
        optimizer: Torch optimizer stepped once per training batch.
        epoch: 0-based index of this epoch (displayed as ``epoch + 1``).
        epoch_step: Maximum number of training batches to consume from ``gen``.
        epoch_step_val: Maximum number of validation batches to consume from ``gen_val``.
        gen, gen_val: Iterables of batches. ``batch[0]`` is the image tensor and
            ``batch[1]`` is a list of target tensors.
        Epoch: Total number of epochs (used for display and the final-epoch save).
        cuda: If true, batches are moved to GPU ``local_rank``.
        fp16: If true, use autocast plus ``scaler`` for mixed-precision training.
        scaler: Gradient scaler used only when ``fp16`` is true.
        save_period: Save a numbered checkpoint every ``save_period`` epochs.
        save_dir: Directory for checkpoint files.
        local_rank: Device index. Progress bars, logging and saving happen only on rank 0.
    """
    loss = 0
    val_loss = 0
    # Count the batches actually processed so the averages stay correct even
    # if a loader yields fewer than epoch_step / epoch_step_val batches.
    num_train_batches = 0
    num_val_batches = 0

    # ------------------------------ training ------------------------------
    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
    model_train.train()
    for iteration, batch in enumerate(gen):
        # Cap the epoch at epoch_step batches even if `gen` can yield more.
        if iteration >= epoch_step:
            break

        # Per the original author's note, targets are already normalized.
        images, targets = batch[0], batch[1]
        if cuda:
            images, targets = _move_batch_to_gpu(images, targets, local_rank)

        optimizer.zero_grad()
        if not fp16:
            outputs = model_train(images)
            loss_value = _compute_total_loss(yolo_loss, outputs, targets)
            loss_value.backward()
            optimizer.step()
        else:
            # Mixed-precision path, only taken when fp16=True.
            from torch.cuda.amp import autocast
            with autocast():
                outputs = model_train(images)
                loss_value = _compute_total_loss(yolo_loss, outputs, targets)
            # Backward runs outside autocast, as recommended by the AMP docs.
            scaler.scale(loss_value).backward()
            scaler.step(optimizer)
            scaler.update()

        loss += loss_value.item()
        num_train_batches += 1

        if local_rank == 0:
            pbar.set_postfix(**{'loss': loss / num_train_batches,
                                'lr': get_lr(optimizer)})
            pbar.update(1)

    # ----------------------------- validation -----------------------------
    if local_rank == 0:
        pbar.close()
        print('Finish Train')
        print('Start Validation')
        pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)

    model_train.eval()
    for iteration, batch in enumerate(gen_val):
        if iteration >= epoch_step_val:
            break

        images, targets = batch[0], batch[1]
        with torch.no_grad():
            if cuda:
                images, targets = _move_batch_to_gpu(images, targets, local_rank)
            # Clears gradients left over from the last training step.
            optimizer.zero_grad()
            outputs = model_train(images)
            loss_value = _compute_total_loss(yolo_loss, outputs, targets)

        val_loss += loss_value.item()
        num_val_batches += 1

        if local_rank == 0:
            pbar.set_postfix(**{'val_loss': val_loss / num_val_batches})
            pbar.update(1)

    # ------------------------- logging and saving -------------------------
    if local_rank == 0:
        pbar.close()
        print('Finish Validation')
        # max(..., 1) avoids ZeroDivisionError when no batches were processed.
        train_loss_avg = loss / max(num_train_batches, 1)
        val_loss_avg = val_loss / max(num_val_batches, 1)

        loss_history.append_loss(epoch + 1, train_loss_avg, val_loss_avg)
        eval_callback.on_epoch_end(epoch + 1, model_train)
        print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
        print('Total Loss: %.3f || Val Loss: %.3f ' % (train_loss_avg, val_loss_avg))

        # Periodic numbered checkpoint, plus one at the final epoch.
        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(model.state_dict(), os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (
                epoch + 1, train_loss_avg, val_loss_avg)))

        # append_loss above has already recorded this epoch's val loss, so
        # "<= min" is true exactly when this epoch matches or beats the best so far.
        if len(loss_history.val_loss) <= 1 or val_loss_avg <= min(loss_history.val_loss):
            print('Save best model to best_epoch_weights.pth')
            torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))

        torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))
|