101 lines
3.0 KiB
Python
101 lines
3.0 KiB
Python
from collections import deque, defaultdict
|
|
import torch
|
|
import os
|
|
import datetime
|
|
from .logger import get_logger
|
|
|
|
|
|
class SmoothedValue(object):
|
|
"""Track a series of values and provide access to smoothed values over a
|
|
window or the global series average.
|
|
"""
|
|
|
|
def __init__(self, window_size=20):
|
|
self.deque = deque(maxlen=window_size)
|
|
self.total = 0.0
|
|
self.count = 0
|
|
|
|
def update(self, value):
|
|
self.deque.append(value)
|
|
self.count += 1
|
|
self.total += value
|
|
|
|
@property
|
|
def median(self):
|
|
d = torch.tensor(list(self.deque))
|
|
return d.median().item()
|
|
|
|
@property
|
|
def avg(self):
|
|
d = torch.tensor(list(self.deque))
|
|
return d.mean().item()
|
|
|
|
@property
|
|
def global_avg(self):
|
|
return self.total / self.count
|
|
|
|
|
|
class Recorder(object):
|
|
def __init__(self, cfg):
|
|
self.cfg = cfg
|
|
self.work_dir = self.get_work_dir()
|
|
cfg.work_dir = self.work_dir
|
|
self.log_path = os.path.join(self.work_dir, 'log.txt')
|
|
|
|
self.logger = get_logger('resa', self.log_path)
|
|
self.logger.info('Config: \n' + cfg.text)
|
|
|
|
# scalars
|
|
self.epoch = 0
|
|
self.step = 0
|
|
self.loss_stats = defaultdict(SmoothedValue)
|
|
self.batch_time = SmoothedValue()
|
|
self.data_time = SmoothedValue()
|
|
self.max_iter = self.cfg.total_iter
|
|
self.lr = 0.
|
|
|
|
def get_work_dir(self):
|
|
now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
|
hyper_param_str = '_lr_%1.0e_b_%d' % (self.cfg.optimizer.lr, self.cfg.batch_size)
|
|
work_dir = os.path.join(self.cfg.work_dirs, now + hyper_param_str)
|
|
if not os.path.exists(work_dir):
|
|
os.makedirs(work_dir)
|
|
return work_dir
|
|
|
|
def update_loss_stats(self, loss_dict):
|
|
for k, v in loss_dict.items():
|
|
self.loss_stats[k].update(v.detach().cpu())
|
|
|
|
def record(self, prefix, step=-1, loss_stats=None, image_stats=None):
|
|
self.logger.info(self)
|
|
# self.write(str(self))
|
|
|
|
def write(self, content):
|
|
with open(self.log_path, 'a+') as f:
|
|
f.write(content)
|
|
f.write('\n')
|
|
|
|
def state_dict(self):
|
|
scalar_dict = {}
|
|
scalar_dict['step'] = self.step
|
|
return scalar_dict
|
|
|
|
def load_state_dict(self, scalar_dict):
|
|
self.step = scalar_dict['step']
|
|
|
|
def __str__(self):
|
|
loss_state = []
|
|
for k, v in self.loss_stats.items():
|
|
loss_state.append('{}: {:.4f}'.format(k, v.avg))
|
|
loss_state = ' '.join(loss_state)
|
|
|
|
recording_state = ' '.join(['epoch: {}', 'step: {}', 'lr: {:.4f}', '{}', 'data: {:.4f}', 'batch: {:.4f}', 'eta: {}'])
|
|
eta_seconds = self.batch_time.global_avg * (self.max_iter - self.step)
|
|
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
|
return recording_state.format(self.epoch, self.step, self.lr, loss_state, self.data_time.avg, self.batch_time.avg, eta_string)
|
|
|
|
|
|
def build_recorder(cfg):
|
|
return Recorder(cfg)
|
|
|