44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
import torch
|
|
import os
|
|
from torch import nn
|
|
import numpy as np
|
|
import torch.nn.functional
|
|
from termcolor import colored
|
|
from .logger import get_logger
|
|
|
|
def save_model(net, optim, scheduler, recorder, is_best=False):
    """Write a training checkpoint into ``<recorder.work_dir>/ckpt``.

    The checkpoint is a dict with the keys ``net``, ``optim``, ``scheduler``
    and ``recorder`` (each holding that object's ``state_dict()``) plus
    ``epoch``. It is saved as ``best.pth`` when ``is_best`` is true and as
    ``<epoch>.pth`` otherwise.

    Args:
        net: Object exposing ``state_dict()`` (the model).
        optim: Object exposing ``state_dict()`` (the optimizer).
        scheduler: Object exposing ``state_dict()`` (the LR scheduler).
        recorder: Object exposing ``work_dir``, ``epoch`` and
            ``state_dict()``.
        is_best: If true, store the checkpoint under the fixed name ``best``
            instead of the current epoch.

    Raises:
        OSError: If the checkpoint directory cannot be created.
    """
    model_dir = os.path.join(recorder.work_dir, 'ckpt')
    # Create the directory in-process rather than shelling out to `mkdir -p`:
    # this handles paths with spaces/shell metacharacters, works on every
    # platform, and raises instead of failing silently.
    os.makedirs(model_dir, exist_ok=True)

    epoch = recorder.epoch
    ckpt_name = 'best' if is_best else epoch
    checkpoint = {
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, os.path.join(model_dir, '{}.pth'.format(ckpt_name)))
|
|
|
|
|
|
def load_network_specified(net, model_dir, logger=None):
    """Load only the compatible weights of a checkpoint into ``net``.

    The checkpoint's ``'net'`` entry is filtered: a tensor is kept only when
    its key exists in ``net.state_dict()`` and its size matches. Everything
    else is skipped (and reported through ``logger`` when one is given). The
    surviving tensors are loaded non-strictly, so parameters of ``net`` that
    are absent from the checkpoint keep their current values.

    Args:
        net: Module that receives the weights.
        model_dir: Path handed directly to ``torch.load``; the loaded object
            must contain a ``'net'`` state dict.
        logger: Optional object with an ``info(str)`` method used to report
            skipped keys.
    """
    # Deserialize onto the CPU so a checkpoint saved from a GPU process can
    # still be read on a machine without that device; load_state_dict copies
    # the values onto whatever device the parameters of `net` live on.
    pretrained_net = torch.load(model_dir, map_location='cpu')['net']
    net_state = net.state_dict()

    state = {}
    for k, v in pretrained_net.items():
        if k not in net_state or v.size() != net_state[k].size():
            if logger:
                logger.info('skip weights: ' + k)
            continue
        state[k] = v

    net.load_state_dict(state, strict=False)
|
|
|
|
|
|
def load_network(net, model_dir, finetune_from=None, logger=None):
    """Restore ``net`` from a checkpoint, or fine-tune from another one.

    When ``finetune_from`` is truthy, the weights are taken from that path
    via ``load_network_specified`` (incompatible tensors are skipped) and
    ``model_dir`` is not read. Otherwise the ``'net'`` entry of the
    checkpoint at ``model_dir`` is loaded strictly.

    Args:
        net: Module that receives the weights.
        model_dir: Path handed directly to ``torch.load`` for the strict
            load.
        finetune_from: Optional checkpoint path to fine-tune from instead.
        logger: Optional object with an ``info(str)`` method.
    """
    if finetune_from:
        if logger:
            logger.info('Finetune model from: ' + finetune_from)
        load_network_specified(net, finetune_from, logger)
        return

    # Deserialize onto the CPU so checkpoints written from a GPU process can
    # be read on hosts without that device; load_state_dict copies the values
    # onto the device of each parameter of `net`.
    pretrained_model = torch.load(model_dir, map_location='cpu')
    net.load_state_dict(pretrained_model['net'], strict=True)
|