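"""Training script for the GAC cross-modal (image-text) retrieval model.

Loads hyper-parameters from a YAML options file, builds the model, and runs
k-fold training: each fold re-splits the training annotations into random
train/val subsets, trains, and checkpoints the best-scoring model.
"""
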
import os, random, copy

import torch
import torch.nn as nn
import argparse
import yaml
import shutil
import tensorboard_logger as tb_logger
import logging
import click

import utils
import data
import engine

from vocab import deserialize_vocab


def parser_options():
    # Hyper Parameters setting
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_opt', default='option/SYDNEY_GAC.yaml', type=str,
                        help='path to a yaml options file')
    # parser.add_argument('--text_sim_path', default='data/ucm_precomp/train_caps.npy', type=str, help='path to t2t sim matrix')
    opt = parser.parse_args()

    # load model options
    with open(opt.path_opt, 'r') as handle:
        options = yaml.safe_load(handle)

    return options
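
# The YAML file is loaded into a plain nested dict. A minimal sketch of the keys this
# script reads (the values below are illustrative placeholders, not the project's
# actual settings):
#
#   model:   {name: GAC}
#   logs:    {ckpt_save_path: ..., logger_name: ..., eval_step: 1}
#   dataset: {vocab_path: ..., data_path: ..., datatype: ...}
#   optim:   {lr: 0.0002, epochs: 50, resume: ""}
#   k_fold:  {nums: 1, experiment_name: ..., current_num: 0}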


def main(options):
    # choose model
    if options['model']['name'] == "GAC":
        from model import GAC as models
    else:
        raise NotImplementedError

    # make ckpt save dir
    if not os.path.exists(options['logs']['ckpt_save_path']):
        os.makedirs(options['logs']['ckpt_save_path'])

    # make vocab
    vocab = deserialize_vocab(options['dataset']['vocab_path'])
    # word2idx = vocab.word2idx
    # vocab_size = len(vocab)
    vocab_word = sorted(vocab.word2idx.items(), key=lambda x: x[1], reverse=False)
    vocab_word = [tup[0] for tup in vocab_word]
    # print("vocab_word : ", vocab_word)

    # Create dataset, model, criterion and optimizer
    train_loader, val_loader = data.get_loaders(vocab, options)

    model = models.factory(options['model'],
                           vocab,
                           cuda=True,
                           data_parallel=False)

    # only parameters with requires_grad=True will be updated
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=options['optim']['lr'])

    print('Model has {} parameters'.format(utils.params_count(model)))

    # initialise training state first so that a resumed checkpoint can overwrite it
    start_epoch = 0
    best_rsum = 0
    best_score = ""

    # optionally resume from a checkpoint
    if options['optim']['resume']:
        if os.path.isfile(options['optim']['resume']):
            print("=> loading checkpoint '{}'".format(options['optim']['resume']))
            checkpoint = torch.load(options['optim']['resume'])
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])

            # Eiters is used to show logs as the continuation of another training
            model.Eiters = checkpoint['Eiters']

            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
                  .format(options['optim']['resume'], start_epoch, best_rsum))
            rsum, all_scores = engine.validate(val_loader, model)
            print(all_scores)
        else:
            print("=> no checkpoint found at '{}'".format(options['optim']['resume']))

    # Train the Model
    flag = 0  # marks whether the text embedding is frozen (used by the commented-out schedule below)
    for epoch in range(start_epoch, options['optim']['epochs']):
        # optional warm-up schedule: freeze the text embedding for the first epochs
        # if epoch < 20:
        #     if flag == 0:
        #         for param in model.text_feature.embed.parameters():
        #             param.requires_grad = False
        #         flag = 1
        # else:
        #     if flag == 1:
        #         for param in model.text_feature.embed.parameters():
        #             param.requires_grad = True
        #         flag = 2

        utils.adjust_learning_rate(options, optimizer, epoch)

        # train for one epoch
        engine.train(train_loader, model, optimizer, epoch, opt=options)

        # evaluate on validation set
        if epoch % options['logs']['eval_step'] == 0:
            rsum, all_scores = engine.validate(val_loader, model)

            is_best = rsum > best_rsum
            # save whenever the score is best or within 1 point of the current best
            need_save = is_best or (rsum - best_rsum > -1)
            if is_best:
                best_score = all_scores
            best_rsum = max(rsum, best_rsum)

            # save ckpt
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': 'baseline',
                    'model': model.state_dict(),
                    'best_rsum': best_rsum,
                    'options': options,
                },
                need_save,
                is_best,
                filename='ckpt_{}_{}_{:.2f}.pth.tar'.format(options['model']['name'], epoch, rsum),
                prefix=options['logs']['ckpt_save_path'],
                model_name=options['model']['name']
            )

            print("Current fold: {}".format(options['k_fold']['current_num']))
            print("Now score:")
            print(all_scores)
            print("Best score:")
            print(best_score)

            utils.log_to_txt(
                contexts="Epoch:{} ".format(epoch + 1) + all_scores,
                filename=options['logs']['ckpt_save_path'] + options['model']['name'] + "_" + options['dataset']['datatype'] + ".txt"
            )
            utils.log_to_txt(
                contexts="Best: " + best_score,
                filename=options['logs']['ckpt_save_path'] + options['model']['name'] + "_" + options['dataset']['datatype'] + ".txt"
            )
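

# Note on the k-fold setup: generate_random_samples() is called once per fold and draws a
# fresh random 80/20 train/val split from the full training annotations, so the folds are
# independent random re-splits rather than disjoint partitions. The *_verify.txt files it
# writes are presumably what data.get_loaders() reads for that fold.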


def generate_random_samples(options):
    # load all anns
    caps = utils.load_from_txt(options['dataset']['data_path'] + 'train_caps.txt')
    fnames = utils.load_from_txt(options['dataset']['data_path'] + 'train_filename.txt')

    # merge: every image is expected to have exactly 5 captions
    assert len(caps) // 5 == len(fnames)
    all_infos = []
    for img_id in range(len(fnames)):
        cap_id = [img_id * 5, (img_id + 1) * 5]
        all_infos.append([caps[cap_id[0]:cap_id[1]], fnames[img_id]])

    # shuffle
    random.shuffle(all_infos)

    # split into train/val (80/20)
    percent = 0.8
    train_infos = all_infos[:int(len(all_infos) * percent)]
    val_infos = all_infos[int(len(all_infos) * percent):]

    # save to txt (one filename per caption)
    train_caps = []
    train_fnames = []
    for item in train_infos:
        for cap in item[0]:
            train_caps.append(cap)
            train_fnames.append(item[1])
    utils.log_to_txt(train_caps, options['dataset']['data_path'] + 'train_caps_verify.txt', mode='w')
    utils.log_to_txt(train_fnames, options['dataset']['data_path'] + 'train_filename_verify.txt', mode='w')

    val_caps = []
    val_fnames = []
    for item in val_infos:
        for cap in item[0]:
            val_caps.append(cap)
            val_fnames.append(item[1])
    utils.log_to_txt(val_caps, options['dataset']['data_path'] + 'val_caps_verify.txt', mode='w')
    utils.log_to_txt(val_fnames, options['dataset']['data_path'] + 'val_filename_verify.txt', mode='w')

    print("Generated random samples to {}.".format(options['dataset']['data_path']))


def update_options_savepath(options, k):
    updated_options = copy.deepcopy(options)

    updated_options['k_fold']['current_num'] = k
    updated_options['logs']['ckpt_save_path'] = options['logs']['ckpt_save_path'] + \
                                                options['k_fold']['experiment_name'] + "/" + str(k) + "/"
    return updated_options
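
# e.g. with ckpt_save_path "checkpoint/", experiment_name "sydney" and k = 2 (values here
# are only illustrative), fold 2 saves its checkpoints under "checkpoint/sydney/2/".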


if __name__ == '__main__':
    options = parser_options()

    # make logger
    tb_logger.configure(options['logs']['logger_name'], flush_secs=5)
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

    # k-fold verification loop
    for k in range(options['k_fold']['nums']):
        print("=========================================")
        print("Start fold {}".format(k))

        # generate random train and val samples
        generate_random_samples(options)

        # update save path
        update_options = update_options_savepath(options, k)

        # run experiment
        main(update_options)
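
# Typical invocation (the script's file name here is an assumption, not given by the source):
#   python train.py --path_opt option/SYDNEY_GAC.yaml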