Graduation_Project/WZM/utils.py

441 lines
14 KiB
Python
Raw Permalink Normal View History

2024-06-24 19:41:48 +08:00
import torch
import numpy as np
import sys
import math
from torch.autograd import Variable
from collections import OrderedDict
import torch.nn as nn
import shutil
import time
from model.utils import cosine_sim, cosine_similarity
# 从npy中读取
def load_from_npy(filename):
info = np.load(filename, allow_pickle=True)
return info
# 保存结果到txt文件
def log_to_txt( contexts=None,filename="save.txt", mark=False,encoding='UTF-8',mode='a'):
f = open(filename, mode,encoding=encoding)
if mark:
sig = "------------------------------------------------\n"
f.write(sig)
elif isinstance(contexts, dict):
tmp = ""
for c in contexts.keys():
tmp += str(c)+" | "+ str(contexts[c]) +"\n"
contexts = tmp
f.write(contexts)
else:
if isinstance(contexts,list):
tmp = ""
for c in contexts:
tmp += str(c)
contexts = tmp
else:
contexts = contexts + "\n"
f.write(contexts)
f.close()
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=0):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / (.0001 + self.count)
def __str__(self):
"""String representation for logging
"""
# for values that should be recorded exactly e.g. iteration number
if self.count == 0:
return str(self.val)
# for stats
return '%.4f (%.4f)' % (self.val, self.avg)
class LogCollector(object):
"""A collection of logging objects that can change from train to val"""
def __init__(self):
# to keep the order of logged variables deterministic
self.meters = OrderedDict()
def update(self, k, v, n=0):
# create a new meter if previously not recorded
if k not in self.meters:
self.meters[k] = AverageMeter()
self.meters[k].update(v, n)
def __str__(self):
"""Concatenate the meters in one log line
"""
s = ''
for i, (k, v) in enumerate(self.meters.items()):
if i > 0:
s += ' '
s += k + ' ' + str(v)
return s
def tb_log(self, tb_logger, prefix='', step=None):
"""Log using tensorboard
"""
for k, v in self.meters.items():
tb_logger.log_value(prefix + k, v.val, step=step)
def update_values(dict_from, dict_to):
for key, value in dict_from.items():
if isinstance(value, dict):
update_values(dict_from[key], dict_to[key])
elif value is not None:
dict_to[key] = dict_from[key]
return dict_to
def params_count(model):
count = 0
for p in model.parameters():
c = 1
for i in range(p.dim()):
c *= p.size(i)
count += c
return count
def collect_match(input):
"""change the model output to the match matrix"""
image_size = input.size(0)
text_size = input.size(1)
# match_v = torch.zeros(image_size, text_size, 1)
# match_v = match_v.view(image_size*text_size, 1)
input_ = nn.LogSoftmax(2)(input)
output = torch.index_select(input_, 2, Variable(torch.LongTensor([1])).cuda())
return output
def collect_neg(input):
""""collect the hard negative sample"""
if input.dim() != 2:
return ValueError
batch_size = input.size(0)
mask = Variable(torch.eye(batch_size)>0.5).cuda()
output = input.masked_fill_(mask, 0)
output_r = output.max(1)[0]
output_c = output.max(0)[0]
loss_n = torch.mean(output_r) + torch.mean(output_c)
return loss_n
def calcul_loss(scores, size, margin, max_violation=False):
diagonal = scores.diag().view(size, 1)
d1 = diagonal.expand_as(scores)
d2 = diagonal.t().expand_as(scores)
# compare every diagonal score to scores in its column
# caption retrieval
cost_s = (margin + scores - d1).clamp(min=0)
# compare every diagonal score to scores in its row
# image retrieval
cost_im = (margin + scores - d2).clamp(min=0)
mask = torch.eye(scores.size(0)) > .5
I = Variable(mask)
if torch.cuda.is_available():
I = I.cuda()
cost_s = cost_s.masked_fill_(I, 0)
cost_im = cost_im.masked_fill_(I, 0)
if max_violation:
cost_s = cost_s.max(1)[0]
cost_im = cost_im.max(0)[0]
return cost_s.sum() + cost_im.sum()
def acc_train(input):
predicted = input.squeeze().numpy()
batch_size = predicted.shape[0]
predicted[predicted > math.log(0.5)] = 1
predicted[predicted < math.log(0.5)] = 0
target = np.eye(batch_size)
recall = np.sum(predicted * target) / np.sum(target)
precision = np.sum(predicted * target) / np.sum(predicted)
acc = 1 - np.sum(abs(predicted - target)) / (target.shape[0] * target.shape[1])
return acc, recall, precision
def acc_i2t(input):
"""Computes the precision@k for the specified values of k of i2t"""
#input = collect_match(input).numpy()
image_size = input.shape[0]
ranks = np.zeros(image_size)
# ranks_ = np.zeros(image_size//5)
top1 = np.zeros(image_size)
for index in range(image_size):
inds = np.argsort(input[index])[::-1]
# Score
rank = 1e20
# index_ = index // 5
for i in range(5 * index, 5 * index + 5, 1):
tmp = np.where(inds == i)[0][0]
if tmp < rank:
rank = tmp
if rank == 1e20:
print('error')
ranks[index] = rank
top1[index] = inds[0]
# Compute metrics
r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
medr = np.floor(np.median(ranks)) + 1
meanr = ranks.mean() + 1
return (r1, r5, r10, medr, meanr), (ranks, top1)
def acc_t2i(input):
"""Computes the precision@k for the specified values of k of t2i"""
#input = collect_match(input).numpy()
image_size = input.shape[0]
ranks = np.zeros(5*image_size)
top1 = np.zeros(5*image_size)
# ranks_ = np.zeros(image_size // 5)
# --> (5N(caption), N(image))
input = input.T
for index in range(image_size):
for i in range(5):
inds = np.argsort(input[5 * index + i])[::-1]
ranks[5 * index + i] = np.where(inds == index)[0][0]
top1[5 * index + i] = inds[0]
# Compute metrics
r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
medr = np.floor(np.median(ranks)) + 1
meanr = ranks.mean() + 1
return (r1, r5, r10, medr, meanr), (ranks, top1)
def shard_dis(images, captions, model, shard_size=128, lengths=None):
"""compute image-caption pairwise distance during validation and test"""
n_im_shard = (len(images) - 1) // shard_size + 1
n_cap_shard = (len(captions) - 1) // shard_size + 1
d = np.zeros((len(images), len(captions)))
for i in range(n_im_shard):
im_start, im_end = shard_size*i, min(shard_size*(i+1), len(images))
# print("======================")
# print("im_start:",im_start)
# print("im_end:",im_end)
for j in range(n_cap_shard):
# sys.stdout.write('\r>> shard_distance batch (%d,%d)' % (i,j))
cap_start, cap_end = shard_size * j, min(shard_size * (j + 1), len(captions))
im = Variable(torch.from_numpy(images[im_start:im_end]), volatile=True).float().cuda()
s = Variable(torch.from_numpy(captions[cap_start:cap_end]), volatile=True).cuda()
l = lengths[cap_start:cap_end]
sim = model(im, s,l)
sim = sim.squeeze()
d[im_start:im_end, cap_start:cap_end] = sim.data.cpu().numpy()
sys.stdout.write('\n')
return d
def acc_i2t2(input):
"""Computes the precision@k for the specified values of k of i2t"""
#input = collect_match(input).numpy()
image_size = input.shape[0]
ranks = np.zeros(image_size)
top1 = np.zeros(image_size)
for index in range(image_size):
inds = np.argsort(input[index])[::-1]
# Score
rank = 1e20
for i in range(5 * index, 5 * index + 5, 1):
tmp = np.where(inds == i)[0][0]
if tmp < rank:
rank = tmp
ranks[index] = rank
top1[index] = inds[0]
# Compute metrics
r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
medr = np.floor(np.median(ranks)) + 1
meanr = ranks.mean() + 1
return (r1, r5, r10, medr, meanr), (ranks, top1)
def acc_t2i2(input):
"""Computes the precision@k for the specified values of k of t2i"""
#input = collect_match(input).numpy()
image_size = input.shape[0]
ranks = np.zeros(5*image_size)
top1 = np.zeros(5*image_size)
# --> (5N(caption), N(image))
input = input.T
for index in range(image_size):
for i in range(5):
inds = np.argsort(input[5 * index + i])[::-1]
ranks[5 * index + i] = np.where(inds == index)[0][0]
top1[5 * index + i] = inds[0]
# Compute metrics
r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
medr = np.floor(np.median(ranks)) + 1
meanr = ranks.mean() + 1
return (r1, r5, r10, medr, meanr), (ranks, top1)
def shard_dis_reg(images, captions, model, shard_size=128, lengths=None):
"""compute image-caption pairwise distance during validation and test"""
n_im_shard = (len(images) - 1) // shard_size + 1
n_cap_shard = (len(captions) - 1) // shard_size + 1
d = np.zeros((len(images), len(captions)))
for i in range(len(images)):
# im_start, im_end = shard_size*i, min(shard_size*(i+1), len(images))
im_index = i
for j in range(n_cap_shard):
# sys.stdout.write('\r>> shard_distance batch (%d,%d)' % (i,j))
cap_start, cap_end = shard_size * j, min(shard_size * (j + 1), len(captions))
s = Variable(torch.from_numpy(captions[cap_start:cap_end]), volatile=True).cuda()
im = Variable(torch.from_numpy(images[i]), volatile=True).float().unsqueeze(0).expand(len(s), 3, 256, 256).cuda()
l = lengths[cap_start:cap_end]
sim = model(im, s, l)[:, 1]
sim = sim.squeeze()
d[i, cap_start:cap_end] = sim.data.cpu().numpy()
sys.stdout.write('\n')
return d
def shard_dis_GAC(images, captions, model, shard_size=128, lengths=None):
"""compute image-caption pairwise distance during validation and test"""
# if torch.cuda.device_count() == 4:
# shard_size = 40
n_im_shard = (len(images) - 1) // shard_size + 1
n_cap_shard = (len(captions) - 1) // shard_size + 1
d = np.zeros((len(images), len(captions)))
all = []
for i in range(n_im_shard):
im_start, im_end = shard_size * i, min(shard_size*(i+1), len(images))
print("======================")
print("im_start:",im_start)
print("im_end:",im_end)
print("the len of captions:",len(captions))
for j in range(n_cap_shard):
# sys.stdout.write('\r>> shard_distance batch (%d,%d)' % (i,j))
cap_start, cap_end = shard_size * j, min(shard_size * (j + 1), len(captions))
# print("cap_start :", cap_start)
# print("cap_end :", cap_end)
with torch.no_grad():
im = Variable(torch.from_numpy(images[im_start:im_end])).float().cuda()
s = Variable(torch.from_numpy(captions[cap_start:cap_end])).cuda()
l = lengths[cap_start:cap_end]
t1 = time.time()
# calculate simularity
sim = model(im, s, l)
# visual_feature, text_feature = model(im, local_rep, local_adj, s, l)
# sim = cosine_sim(visual_feature, text_feature)
t2 = time.time()
all.append(t2-t1)
sim = sim.squeeze()
# print("sim shape : ", sim.shape)
d[im_start:im_end, cap_start:cap_end] = sim.data.cpu().numpy()
sys.stdout.write('\n')
print("infer time:",np.average(all))
return d
def save_checkpoint(state, need_save, is_best, filename, prefix='', model_name = None):
tries = 15
error = None
# deal with unstable I/O. Usually not necessary.
while tries:
try:
if need_save:
torch.save(state, prefix + filename)
if is_best:
torch.save(state, prefix + model_name +'_best.pth.tar')
except IOError as e:
error = e
tries -= 1
else:
break
print('model save {} failed, remaining {} trials'.format(filename, tries))
if not tries:
raise error
def adjust_learning_rate(options, optimizer, epoch):
"""Sets the learning rate to the initial LR
decayed by 10 every 30 epochs"""
for param_group in optimizer.param_groups:
lr = param_group['lr']
if epoch % options['optim']['lr_update_epoch'] == options['optim']['lr_update_epoch'] - 1:
lr = lr * options['optim']['lr_decay_param']
param_group['lr'] = lr
print("Current lr: {}".format(optimizer.state_dict()['param_groups'][0]['lr']))
def load_from_txt(filename, encoding="utf-8"):
f = open(filename,'r' ,encoding=encoding)
contexts = f.readlines()
return contexts