Graduation_Project/WZM/train.py

218 lines
7.6 KiB
Python

import os,random,copy
import torch
import torch.nn as nn
import argparse
import yaml
import shutil
import tensorboard_logger as tb_logger
import logging
import click
import utils
import data
import engine
from vocab import deserialize_vocab
def parser_options():
# Hyper Parameters setting
parser = argparse.ArgumentParser()
parser.add_argument('--path_opt', default='option/SYDNEY_GAC.yaml', type=str,
help='path to a yaml options file')
# parser.add_argument('--text_sim_path', default='data/ucm_precomp/train_caps.npy', type=str,help='path to t2t sim matrix')
opt = parser.parse_args()
# load model options
with open(opt.path_opt, 'r') as handle:
options = yaml.safe_load(handle)
return options
def main(options):
# choose model
if options['model']['name'] == "GAC":
from model import GAC as models
else:
raise NotImplementedError
# make ckpt save dir
if not os.path.exists(options['logs']['ckpt_save_path']):
os.makedirs(options['logs']['ckpt_save_path'])
# make vocab
vocab = deserialize_vocab(options['dataset']['vocab_path'])
# word2idx = vocab.word2idx
# vocab_size = len(vocab)
vocab_word = sorted(vocab.word2idx.items(), key=lambda x: x[1], reverse=False)
vocab_word = [tup[0] for tup in vocab_word]
# print("vocab_word : ", vocab_word)
# Create dataset, model, criterion and optimizer
train_loader, val_loader = data.get_loaders(vocab, options)
model = models.factory(options['model'],
vocab,
cuda=True,
data_parallel=False)
# the param with grad wil be updated
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
lr=options['optim']['lr'])
print('Model has {} parameters'.format(utils.params_count(model)))
# optionally resume from a checkpoint
if options['optim']['resume']:
if os.path.isfile(options['optim']['resume']):
print("=> loading checkpoint '{}'".format(options['optim']['resume']))
checkpoint = torch.load(options['optim']['resume'])
start_epoch = checkpoint['epoch']
best_rsum = checkpoint['best_rsum']
model.load_state_dict(checkpoint['model'])
# Eiters is used to show logs as the continuation of another
# training
model.Eiters = checkpoint['Eiters']
print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
.format(options['optim']['resume'], start_epoch, best_rsum))
rsum, all_scores = engine.validate(val_loader, model)
print(all_scores)
else:
print("=> no checkpoint found at '{}'".format(options['optim']['resume']))
else:
start_epoch = 0
# Train the Model
best_rsum = 0
best_score = ""
flag = 0 # mark if or not freeze
for epoch in range(start_epoch, options['optim']['epochs']):
# if epoch < 20:
# if flag == 0:
# for param in model.text_feature.embed.parameters():
# param.requires_grad = False
# flag = 1
# else:
# if flag == 1:
# for param in model.text_feature.embed.parameters():
# param.requires_grad = True
# flag = 2
utils.adjust_learning_rate(options, optimizer, epoch)
# train for one epoch
engine.train(train_loader, model, optimizer, epoch, opt=options)
# evaluate on validation set
if epoch % options['logs']['eval_step'] == 0:
rsum, all_scores = engine.validate(val_loader, model)
need_save = epoch >= 5
is_best = rsum > best_rsum
need_save = is_best or (rsum - best_rsum > -1)
if is_best:
best_score = all_scores
best_rsum = max(rsum, best_rsum)
# save ckpt
utils.save_checkpoint(
{
'epoch': epoch + 1,
'arch': 'baseline',
'model': model.state_dict(),
'best_rsum': best_rsum,
'options': options,
},
need_save,
is_best,
filename='ckpt_{}_{}_{:.2f}.pth.tar'.format(options['model']['name'] ,epoch, rsum),
prefix=options['logs']['ckpt_save_path'],
model_name=options['model']['name']
)
print("Current {}th fold.".format(options['k_fold']['current_num']))
print("Now score:")
print(all_scores)
print("Best score:")
print(best_score)
utils.log_to_txt(
contexts= "Epoch:{} ".format(epoch+1) + all_scores,
filename=options['logs']['ckpt_save_path']+ options['model']['name'] + "_" + options['dataset']['datatype'] +".txt"
)
utils.log_to_txt(
contexts= "Best: " + best_score,
filename=options['logs']['ckpt_save_path']+ options['model']['name'] + "_" + options['dataset']['datatype'] +".txt"
)
def generate_random_samples(options):
# load all anns
caps = utils.load_from_txt(options['dataset']['data_path']+'train_caps.txt')
fnames = utils.load_from_txt(options['dataset']['data_path']+'train_filename.txt')
# merge
assert len(caps) // 5 == len(fnames)
all_infos = []
for img_id in range(len(fnames)):
cap_id = [img_id * 5 ,(img_id+1) * 5]
all_infos.append([caps[cap_id[0]:cap_id[1]], fnames[img_id]])
# shuffle
random.shuffle(all_infos)
# split_trainval
percent = 0.8
train_infos = all_infos[:int(len(all_infos)*percent)]
val_infos = all_infos[int(len(all_infos)*percent):]
# save to txt
train_caps = []
train_fnames = []
for item in train_infos:
for cap in item[0]:
train_caps.append(cap)
train_fnames.append(item[1])
utils.log_to_txt(train_caps, options['dataset']['data_path']+'train_caps_verify.txt',mode='w')
utils.log_to_txt(train_fnames, options['dataset']['data_path']+'train_filename_verify.txt',mode='w')
val_caps = []
val_fnames = []
for item in val_infos:
for cap in item[0]:
val_caps.append(cap)
val_fnames.append(item[1])
utils.log_to_txt(val_caps, options['dataset']['data_path']+'val_caps_verify.txt',mode='w')
utils.log_to_txt(val_fnames, options['dataset']['data_path']+'val_filename_verify.txt',mode='w')
print("Generate random samples to {} complete.".format(options['dataset']['data_path']))
def update_options_savepath(options, k):
updated_options = copy.deepcopy(options)
updated_options['k_fold']['current_num'] = k
updated_options['logs']['ckpt_save_path'] = options['logs']['ckpt_save_path'] + \
options['k_fold']['experiment_name'] + "/" + str(k) + "/"
return updated_options
if __name__ == '__main__':
options = parser_options()
# make logger
tb_logger.configure(options['logs']['logger_name'], flush_secs=5)
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
# k_fold verify
for k in range(options['k_fold']['nums']):
print("=========================================")
print("Start {}th fold".format(k))
# generate random train and val samples
generate_random_samples(options)
# update save path
update_options = update_options_savepath(options, k)
# run experiment
main(update_options)