lili_code/runner/net_utils.py

44 lines
1.4 KiB
Python

import torch
import os
from torch import nn
import numpy as np
import torch.nn.functional
from termcolor import colored
from .logger import get_logger
def save_model(net, optim, scheduler, recorder, is_best=False):
    """Write a training checkpoint to ``<recorder.work_dir>/ckpt``.

    The checkpoint is a dict holding the ``state_dict()`` of the network,
    optimizer, scheduler and recorder plus the current epoch, serialized
    with ``torch.save``.

    Args:
        net: Object whose ``state_dict()`` is stored under ``'net'``.
        optim: Object whose ``state_dict()`` is stored under ``'optim'``.
        scheduler: Object whose ``state_dict()`` is stored under
            ``'scheduler'``.
        recorder: Object providing ``work_dir``, ``epoch`` and
            ``state_dict()``.
        is_best: If true the file is named ``best.pth``; otherwise it is
            named ``<epoch>.pth``.

    Raises:
        OSError: If the checkpoint directory cannot be created.
    """
    model_dir = os.path.join(recorder.work_dir, 'ckpt')
    # Create the directory in-process instead of shelling out to `mkdir -p`:
    # works on every platform, is safe for paths containing spaces or shell
    # metacharacters, and reports failures instead of ignoring them.
    os.makedirs(model_dir, exist_ok=True)

    epoch = recorder.epoch
    ckpt_name = 'best' if is_best else epoch
    checkpoint = {
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, os.path.join(model_dir, '{}.pth'.format(ckpt_name)))
def load_network_specified(net, model_dir, logger=None):
    """Load only the compatible weights from a checkpoint into ``net``.

    Entries of the checkpoint's ``'net'`` state dict are copied into ``net``
    when a parameter with the same name and the same size exists there.
    Entries that are unknown to ``net`` or whose size differs are skipped
    (and reported through ``logger`` when one is given).

    Args:
        net: Module exposing ``state_dict()`` and ``load_state_dict()``.
        model_dir: Path of the checkpoint file (passed to ``torch.load``).
        logger: Optional object with an ``info(str)`` method used to report
            each skipped key.
    """
    # map_location='cpu' lets checkpoints saved from a GPU be read on hosts
    # without that device; load_state_dict copies the values onto whatever
    # device the target parameters live on.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from trusted sources.
    pretrained_net = torch.load(model_dir, map_location='cpu')['net']
    net_state = net.state_dict()

    state = {}
    for k, v in pretrained_net.items():
        if k not in net_state or v.size() != net_state[k].size():
            if logger:
                logger.info('skip weights: ' + k)
            continue
        state[k] = v

    # strict=False: the filtered dict may legitimately miss keys of `net`.
    net.load_state_dict(state, strict=False)
def load_network(net, model_dir, finetune_from=None, logger=None):
    """Restore network weights from a checkpoint.

    When ``finetune_from`` is given, weights are loaded from that path via
    ``load_network_specified`` (mismatching entries are skipped) and
    ``model_dir`` is not read. Otherwise the ``'net'`` entry of the
    checkpoint at ``model_dir`` is loaded with ``strict=True``.

    Args:
        net: Module exposing ``load_state_dict()``.
        model_dir: Path of the checkpoint file used for the strict load.
        finetune_from: Optional path of a checkpoint to finetune from.
        logger: Optional object with an ``info(str)`` method.
    """
    if finetune_from:
        if logger:
            logger.info('Finetune model from: ' + finetune_from)
        load_network_specified(net, finetune_from, logger)
        return

    # map_location='cpu' lets checkpoints saved from a GPU be read on hosts
    # without that device; load_state_dict copies the values onto whatever
    # device the target parameters live on.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from trusted sources.
    pretrained_model = torch.load(model_dir, map_location='cpu')
    net.load_state_dict(pretrained_model['net'], strict=True)