# Source: lili_code/runner/runner.py (113 lines, 4.3 KiB, Python)

import time
import torch
import numpy as np
from tqdm import tqdm
import pytorch_warmup as warmup
from models.registry import build_net
from .registry import build_trainer, build_evaluator
from .optimizer import build_optimizer
from .scheduler import build_scheduler
from datasets import build_dataloader
from .recorder import build_recorder
from .net_utils import save_model, load_network
class Runner(object):
    """Drive training and validation of the network described by ``cfg``.

    Construction builds the recorder, the network (wrapped in
    ``DataParallel`` and moved to CUDA), optionally loads weights, and then
    builds the optimizer, LR scheduler, evaluator and a linear warm-up
    scheduler.  ``train()`` runs the epoch loop; ``validate()`` scores the
    network and stores a "best" checkpoint whenever the metric improves.
    """

    def __init__(self, cfg):
        """Build every component needed for training and evaluation.

        Args:
            cfg: Project configuration object.  This class reads ``gpus``,
                ``load_from``, ``finetune_from``, ``total_iter``,
                ``log_interval``, ``epochs``, ``save_ep``, ``eval_ep`` and
                ``dataset.train`` / ``dataset.val`` from it; the ``build_*``
                helpers receive the whole object.
        """
        self.cfg = cfg
        self.recorder = build_recorder(self.cfg)
        self.net = build_net(self.cfg)
        self.net = torch.nn.parallel.DataParallel(
            self.net, device_ids=list(range(self.cfg.gpus))).cuda()
        self.recorder.logger.info('Network: \n' + str(self.net))
        # Weights (if any are configured) are loaded before the optimizer is
        # created from the network.
        self.resume()
        self.optimizer = build_optimizer(self.cfg, self.net)
        self.scheduler = build_scheduler(self.cfg, self.optimizer)
        self.evaluator = build_evaluator(self.cfg)
        # Linear LR warm-up with a fixed period of 5000; ``dampen()`` is
        # called once per training iteration in ``train_epoch``.
        self.warmup_scheduler = warmup.LinearWarmup(
            self.optimizer, warmup_period=5000)
        # Best validation metric seen so far; ``validate`` only replaces it
        # with strictly larger values.
        self.metric = 0.

    def resume(self):
        """Load weights into the network when the config asks for it.

        Does nothing when both ``cfg.load_from`` and ``cfg.finetune_from``
        are falsy.
        """
        if not (self.cfg.load_from or self.cfg.finetune_from):
            return
        load_network(
            self.net,
            self.cfg.load_from,
            finetune_from=self.cfg.finetune_from,
            logger=self.recorder.logger,
        )

    def to_cuda(self, batch):
        """Move the entries of ``batch`` to the GPU in place.

        The ``'meta'`` entry is always left untouched.  Any other value is
        replaced by ``value.cuda()`` when it exposes a callable ``cuda``
        attribute; values without one (e.g. strings or plain lists) stay
        where they are instead of raising ``AttributeError``.

        Args:
            batch: Mutable mapping produced by the data loader.

        Returns:
            The same ``batch`` object, after the in-place update.
        """
        for key in batch:
            if key == 'meta':
                continue
            mover = getattr(batch[key], 'cuda', None)
            if callable(mover):
                batch[key] = mover()
        return batch

    def train_epoch(self, epoch, train_loader):
        """Run one pass over ``train_loader``, stopping at ``cfg.total_iter``.

        ``self.trainer`` must already be set (``train()`` does this).

        Args:
            epoch: Index of the current epoch.  Not used inside this method;
                kept so the signature stays stable for callers.
            train_loader: Sized iterable of batches (``len()`` is taken).
        """
        self.net.train()
        tick = time.time()
        last_index = len(train_loader) - 1
        for i, data in enumerate(train_loader):
            # The iteration budget is global across epochs.
            if self.recorder.step >= self.cfg.total_iter:
                break
            # Time spent waiting for the loader to deliver this batch.
            data_time = time.time() - tick
            self.recorder.step += 1

            data = self.to_cuda(data)
            output = self.trainer.forward(self.net, data)

            self.optimizer.zero_grad()
            output['loss'].backward()
            self.optimizer.step()
            self.scheduler.step()
            self.warmup_scheduler.dampen()

            # Full iteration time, data loading included.
            batch_time = time.time() - tick
            tick = time.time()

            self.recorder.update_loss_stats(output['loss_stats'])
            self.recorder.batch_time.update(batch_time)
            self.recorder.data_time.update(data_time)

            if i % self.cfg.log_interval == 0 or i == last_index:
                self.recorder.lr = self.optimizer.param_groups[0]['lr']
                self.recorder.record('train')

    def train(self):
        """Build the trainer and data loaders, then run the epoch loop."""
        self.recorder.logger.info('start training...')
        self.trainer = build_trainer(self.cfg)
        train_loader = build_dataloader(
            self.cfg.dataset.train, self.cfg, is_train=True)
        val_loader = build_dataloader(
            self.cfg.dataset.val, self.cfg, is_train=False)
        self._run_epochs(train_loader, val_loader)

    def _run_epochs(self, train_loader, val_loader):
        """Train for up to ``cfg.epochs`` epochs with periodic save/validate.

        A checkpoint is written every ``cfg.save_ep`` epochs and validation
        runs every ``cfg.eval_ep`` epochs.  Both also run on the final pass,
        where "final" means either the last configured epoch or the epoch in
        which the ``cfg.total_iter`` budget is used up, so training never
        ends without a last checkpoint and evaluation.
        """
        last_epoch = self.cfg.epochs - 1
        for epoch in range(self.cfg.epochs):
            # NOTE: the first line reports iteration progress even though it
            # is labelled "Epoch"; the text is kept unchanged for log parity.
            print('Epoch: [{}/{}]'.format(self.recorder.step, self.cfg.total_iter))
            print('Epoch: [{}/{}]'.format(epoch, self.cfg.epochs))
            self.recorder.epoch = epoch
            self.train_epoch(epoch, train_loader)

            budget_spent = self.recorder.step >= self.cfg.total_iter
            final_pass = budget_spent or epoch == last_epoch
            if final_pass or (epoch + 1) % self.cfg.save_ep == 0:
                self.save_ckpt()
            if final_pass or (epoch + 1) % self.cfg.eval_ep == 0:
                self.validate(val_loader)
            if budget_spent:
                break

    def validate(self, val_loader):
        """Evaluate the network on ``val_loader`` and track the best metric.

        Each batch is moved to the GPU, run through the network on
        ``data['img']`` without gradients, and handed to the evaluator.  If
        ``evaluator.summarize()`` returns a falsy value (``None`` or zero)
        nothing else happens.  Otherwise a "best" checkpoint is saved when
        the metric exceeds the stored best, and the best value is logged.

        Args:
            val_loader: Iterable of batches exposing a ``dataset`` attribute.
        """
        self.net.eval()
        for data in tqdm(val_loader, desc='Validate'):
            data = self.to_cuda(data)
            with torch.no_grad():
                output = self.net(data['img'])
                self.evaluator.evaluate(val_loader.dataset, output, data)

        metric = self.evaluator.summarize()
        if not metric:
            return
        if metric > self.metric:
            self.metric = metric
            self.save_ckpt(is_best=True)
        self.recorder.logger.info('Best metric: ' + str(self.metric))

    def save_ckpt(self, is_best=False):
        """Persist network, optimizer, scheduler and recorder state.

        Args:
            is_best: Forwarded to ``save_model``; ``validate`` passes ``True``
                when the metric has just improved.
        """
        save_model(self.net, self.optimizer, self.scheduler,
                   self.recorder, is_best)