Graduation_Project/WZM/engine.py

185 lines
6.2 KiB
Python
Raw Permalink Normal View History

2024-06-24 18:15:10 +08:00
import time
import torch
import numpy as np
import sys
from torch.autograd import Variable
import tensorboard_logger as tb_logger
import logging
from torch.nn.utils.clip_grad import clip_grad_norm
from model.utils import cosine_sim, cosine_similarity
import utils
def train(train_loader, model, optimizer, epoch, opt=None):
    """Run one training epoch of ``model`` over ``train_loader``.

    For every batch, the model produces scores from ``(images, captions,
    lengths)``. ``utils.calcul_loss`` turns these into a loss, and ``optimizer``
    takes one step. Gradient-norm clipping, when enabled, is applied after
    ``backward()`` and before ``step()``. The original clipped before
    ``backward()``, so it only touched stale gradients that were then zeroed.

    Args:
        train_loader: sized iterable yielding ``(images, captions, lengths, ids)``
            batches (``ids`` is not used here).
        model: module called as ``model(images, captions, lengths)``; its
            ``logger`` attribute is set to this epoch's ``utils.LogCollector``.
        optimizer: torch optimizer over the model's parameters.
        epoch: epoch index; used only in log output.
        opt: nested options dict. Reads ``opt['optim']`` (``grad_clip``,
            ``max_violation``, ``margin``), ``opt['logs']`` (``print_freq``,
            ``ckpt_save_path``), ``opt['model']['name']`` and
            ``opt['dataset']['datatype']``. A missing key raises ``KeyError``.
    """
    if opt is None:
        opt = {}

    # Extract options once, outside the batch loop.
    grad_clip = opt['optim']['grad_clip']
    max_violation = opt['optim']['max_violation']
    margin = float(opt['optim']['margin'])
    print_freq = opt['logs']['print_freq']
    log_path = (opt['logs']['ckpt_save_path'] + opt['model']['name'] + "_"
                + opt['dataset']['datatype'] + ".txt")

    use_cuda = torch.cuda.is_available()

    # Switch to train mode (enables dropout / batch-norm updates).
    model.train()
    batch_time = utils.AverageMeter()
    train_logger = utils.LogCollector()
    model.logger = train_logger
    params = list(model.parameters())

    end = time.time()
    for i, (images, captions, lengths, _ids) in enumerate(train_loader):
        input_visual, input_text = images, captions
        if use_cuda:
            input_visual = input_visual.cuda()
            input_text = input_text.cuda()

        scores = model(input_visual, input_text, lengths)
        loss = utils.calcul_loss(scores, input_visual.size(0), margin,
                                 max_violation=max_violation)
        train_logger.update('L', loss.detach().cpu().numpy())

        optimizer.zero_grad()
        loss.backward()
        # Clip only after backward() has produced this batch's gradients.
        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(params, grad_clip)
        optimizer.step()
        # Wait for queued GPU work so batch_time measures real compute time;
        # skipped on CPU-only hosts, where synchronize() would raise.
        if use_cuda:
            torch.cuda.synchronize()

        # Measure elapsed time.
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            message = ('Epoch: [{0}][{1}/{2}]\t'
                       'Time {batch_time.val:.3f}\t'
                       '{elog}\t'.format(epoch, i, len(train_loader),
                                        batch_time=batch_time,
                                        elog=str(train_logger)))
            logging.info(message)
            utils.log_to_txt(message, log_path)

        tb_logger.log_value('epoch', epoch)
        tb_logger.log_value('step', i)
        tb_logger.log_value('batch_time', batch_time.val)
        train_logger.tb_log(tb_logger)
def _gather_eval_inputs(val_loader, captions_per_image=5, min_text_len=47,
                        default_image_shape=(3, 224, 224)):
    """Scatter every sample of ``val_loader`` into dense NumPy arrays by sample id.

    Each batch is unpacked as ``(images, captions, lengths, ids)``. ``ids`` are
    used as row indices into arrays sized ``len(val_loader.dataset)``, so the
    output follows dataset order rather than loader order.

    Only the image of every ``captions_per_image``-th sample (ids 0, k, 2k, ...)
    is kept (presumably one image per group of consecutive captions — confirm
    against the dataset). This is the same result as filling one image row per
    sample and subsampling afterwards, with ``captions_per_image`` times less
    memory.

    Args:
        val_loader: loader with a sized ``dataset`` attribute, yielding CPU
            tensors (``.numpy()`` is called on them).
        captions_per_image: stride used to select image rows.
        min_text_len: minimum width of the token matrix. If a batch is wider,
            the matrix grows with zero padding instead of failing.
        default_image_shape: per-image shape used only when the loader yields
            no batches; otherwise the first batch's shape is used.

    Returns:
        A tuple ``(images, texts, lengths)``, where N is
        ``len(val_loader.dataset)``:
        - ``images``: float64 array of shape
          ``(ceil(N / captions_per_image), *image_shape)``.
        - ``texts``: int64 array of shape ``(N, width)``.
        - ``lengths``: list of N caption lengths (0 for ids never seen).
    """
    n_samples = len(val_loader.dataset)
    n_images = -(-n_samples // captions_per_image)  # ceiling division
    images = None
    texts = np.zeros((n_samples, min_text_len), dtype=np.int64)
    lengths = [0] * n_samples

    for batch_images, batch_captions, batch_lengths, batch_ids in val_loader:
        if images is None:
            images = np.zeros((n_images,) + tuple(batch_images.shape[1:]))
        width = batch_captions.size(1)
        if width > texts.shape[1]:
            # Grow (zero-padded) instead of raising a broadcast error.
            texts = np.pad(texts, ((0, 0), (0, width - texts.shape[1])))
        for sample_id, image, caption, length in zip(
                batch_ids, batch_images.numpy(), batch_captions.numpy(),
                batch_lengths):
            sample_id = int(sample_id)
            if sample_id % captions_per_image == 0:
                images[sample_id // captions_per_image] = image
            texts[sample_id, :width] = caption
            lengths[sample_id] = length

    if images is None:
        images = np.zeros((n_images,) + tuple(default_image_shape))
    return images, texts, lengths


def validate(val_loader, model):
    """Evaluate retrieval on ``val_loader``, log the metrics, and return the score.

    Returns:
        A tuple ``(currscore, all_score)``:
        - ``currscore``: mean of the six R@1/R@5/R@10 values from
          ``utils.acc_i2t2`` and ``utils.acc_t2i2``.
        - ``all_score``: a multi-line text summary of every metric.
    """
    model.eval()
    val_logger = utils.LogCollector()
    model.logger = val_logger

    start = time.time()
    images, texts, text_lengths = _gather_eval_inputs(val_loader)
    # Evaluation never back-propagates. Disabling autograd avoids holding the
    # graph in memory; this is harmless if shard_dis_GAC already does so.
    with torch.no_grad():
        d = utils.shard_dis_GAC(images, texts, model, lengths=text_lengths)
    end = time.time()
    print("calculate similarity time:", end - start)

    (r1i, r5i, r10i, medri, meanri), _ = utils.acc_i2t2(d)
    logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f",
                 r1i, r5i, r10i, medri, meanri)
    (r1t, r5t, r10t, medrt, meanrt), _ = utils.acc_t2i2(d)
    logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f",
                 r1t, r5t, r10t, medrt, meanrt)

    # This is the mean (not the sum) of the six recalls, even though it is
    # labelled 'sum' / 'rsum' in the outputs below.
    currscore = (r1t + r5t + r10t + r1i + r5i + r10i) / 6.0
    all_score = "r1i:{} r5i:{} r10i:{} medri:{} meanri:{}\n r1t:{} r5t:{} r10t:{} medrt:{} meanrt:{}\n sum:{}\n ------\n".format(
        r1i, r5i, r10i, medri, meanri, r1t, r5t, r10t, medrt, meanrt, currscore
    )

    for tag, value in (('r1i', r1i), ('r5i', r5i), ('r10i', r10i),
                       ('medri', medri), ('meanri', meanri),
                       ('r1t', r1t), ('r5t', r5t), ('r10t', r10t),
                       ('medrt', medrt), ('meanrt', meanrt),
                       ('rsum', currscore)):
        tb_logger.log_value(tag, value)

    return currscore, all_score


def validate_test(val_loader, model):
    """Gather ``val_loader`` into arrays and return the raw result of ``utils.shard_dis_GAC``.

    Unlike ``validate``, no metrics are computed or logged. The returned ``d``
    is passed straight through from ``utils.shard_dis_GAC``; presumably it is
    an image-by-caption similarity matrix (verify against ``utils``).
    """
    model.eval()
    val_logger = utils.LogCollector()
    model.logger = val_logger

    start = time.time()
    embed_start = time.time()
    images, texts, text_lengths = _gather_eval_inputs(val_loader)
    embed_end = time.time()
    print("embedding time: {}".format(embed_end - embed_start))

    # No gradients are needed at test time.
    with torch.no_grad():
        d = utils.shard_dis_GAC(images, texts, model, lengths=text_lengths)
    end = time.time()
    print("calculate similarity time:", end - start)
    return d