# Graduation_Project/WZM/find_best_pth.py

import os
import copy
import argparse
import logging

import numpy as np
import torch
import yaml

# project-local modules
import utils
import data
import engine
from vocab import deserialize_vocab
from model import GAC as models


def parser_options():
    # Hyper-parameter settings
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_opt', default='option/RSITMD_mca/RSITMD_GAC.yaml', type=str,
                        help='path to a yaml options file')
    opt = parser.parse_args()

    # load model options from the YAML file
    with open(opt.path_opt, 'r') as handle:
        options = yaml.safe_load(handle)
    return options
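
# For reference, a sketch of the YAML fields this script reads (key names are
# taken from the lookups below; the real config file may contain more):
#
#   dataset:
#     vocab_path: ...       # serialized vocabulary used by deserialize_vocab
#   model: ...              # options forwarded to models.factory
#   optim:
#     resume: ''            # overwritten per checkpoint in get_allpth_score
#   logs:
#     ckpt_save_path: ...   # root directory holding the k-fold checkpoints
#   k_fold:
#     experiment_name: ...
#     nums: ...             # number of folds to evaluate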


def main(options, vocab):
    # create the test dataset and the model
    test_loader = data.get_test_loader(vocab, options)
    model = models.factory(options['model'],
                           vocab,
                           cuda=True,
                           data_parallel=False)
    print('Model has {} parameters'.format(utils.params_count(model)))

    # optionally resume from a checkpoint
    if os.path.isfile(options['optim']['resume']):
        print("=> loading checkpoint '{}'".format(options['optim']['resume']))
        checkpoint = torch.load(options['optim']['resume'])
        start_epoch = checkpoint['epoch']
        best_rsum = checkpoint['best_rsum']
        model.load_state_dict(checkpoint['model'])
        print("=> loaded checkpoint (epoch {}, best_rsum {})".format(start_epoch, best_rsum))
    else:
        print("=> no checkpoint found at '{}'".format(options['optim']['resume']))
    # evaluate on the test set; sims holds image-caption similarity scores
    sims = engine.validate_test(test_loader, model)

    # compute retrieval metrics in both directions
    (r1i, r5i, r10i, medri, meanri), _ = utils.acc_i2t2(sims)
    logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" %
                 (r1i, r5i, r10i, medri, meanri))
    (r1t, r5t, r10t, medrt, meanrt), _ = utils.acc_t2i2(sims)
    logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" %
                 (r1t, r5t, r10t, medrt, meanrt))

    # mean recall over the six R@K values
    currscore = (r1t + r5t + r10t + r1i + r5i + r10i) / 6.0
    all_score = "r1i:{} r5i:{} r10i:{} medri:{} meanri:{}\n r1t:{} r5t:{} r10t:{} medrt:{} meanrt:{}\n sum:{}\n ------\n".format(
        r1i, r5i, r10i, medri, meanri, r1t, r5t, r10t, medrt, meanrt, currscore
    )
    # append the checkpoint path and its scores to the output file
    outputfile_path = "RSICD_GAC_decay0.5_m0.2_without_m4m5.txt"
    with open(outputfile_path, 'a') as file:
        file.write(options['optim']['resume'] + "\n")
        file.write(all_score)
    print(all_score)
    return [r1i, r5i, r10i, r1t, r5t, r10t, currscore]
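

# For orientation, a minimal sketch of Recall@K for image-to-text retrieval,
# assuming `sims` is an (n_images, n_captions) matrix with the 5 ground-truth
# captions of image i stored at columns 5*i..5*i+4 (a common convention for
# this task; the actual utils.acc_i2t2 implementation may differ):
def _recall_at_k_sketch(sims, k):
    n_images = sims.shape[0]
    hits = 0
    for i in range(n_images):
        ranking = np.argsort(sims[i])[::-1]         # caption indices, best first
        gt = set(range(5 * i, 5 * i + 5))           # ground-truth caption indices
        if any(int(c) in gt for c in ranking[:k]):  # hit if any GT caption in top-k
            hits += 1
    return 100.0 * hits / n_images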


def get_allpth_score(options, k, vocab):
    # evaluate every .tar checkpoint saved for fold k
    updated_options = copy.deepcopy(options)
    scores = []
    directory = options['logs']['ckpt_save_path'] + options['k_fold']['experiment_name'] + "/" + str(k)
    for root, dirs, files in os.walk(directory):
        for file in sorted(files):  # sorted for a deterministic evaluation order
            if file.endswith('.tar'):
                file_path = os.path.join(root, file)
                updated_options['optim']['resume'] = file_path
                # run the evaluation for this checkpoint
                one_score = main(updated_options, vocab)
                scores.append(one_score)
    return scores
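

# Hypothetical helper (not called by the original flow): given the rows
# returned by get_allpth_score, the best checkpoint of a fold is the one with
# the highest mean recall, stored at index 6 (`currscore`) of each row.
def _best_score_index_sketch(scores):
    return max(range(len(scores)), key=lambda i: scores[i][6])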


if __name__ == '__main__':
    options = parser_options()

    # build the vocabulary
    vocab = deserialize_vocab(options['dataset']['vocab_path'])
    vocab_word = sorted(vocab.word2idx.items(), key=lambda x: x[1], reverse=False)
    vocab_word = [tup[0] for tup in vocab_word]  # kept from the training script; unused below

    # evaluate all checkpoints of every fold
    last_score = []
    for k in range(options['k_fold']['nums']):
        print("=========================================")
        print("Start evaluating fold {}".format(k))
        scores = get_allpth_score(options, k, vocab)
        last_score.extend(scores)
        print("Finished evaluating fold {}".format(k))

    # average the scores over all evaluated checkpoints
    print("===================== Ave Score ({}-fold verify) =================".format(options['k_fold']['nums']))
    last_score = np.average(last_score, axis=0)
    names = ['r1i', 'r5i', 'r10i', 'r1t', 'r5t', 'r10t', 'mr']
    for name, score in zip(names, last_score):
        print("{}: {}".format(name, score))