import time
import logging

import numpy as np
import torch
import tensorboard_logger as tb_logger
from torch.nn.utils import clip_grad_norm_

import utils


def train(train_loader, model, optimizer, epoch, opt={}):
    # extract options
    grad_clip = opt['optim']['grad_clip']
    max_violation = opt['optim']['max_violation']
    margin = float(opt['optim']['margin'])
    loss_name = opt['model']['name'] + "_" + opt['dataset']['datatype']
    print_freq = opt['logs']['print_freq']

    # switch to train mode
    model.train()
    batch_time = utils.AverageMeter()
    data_time = utils.AverageMeter()
    train_logger = utils.LogCollector()

    end = time.time()
    params = list(model.parameters())
    for i, train_data in enumerate(train_loader):
        images, captions, lengths, ids = train_data

        # measure data loading time
        data_time.update(time.time() - end)
        model.logger = train_logger

        input_visual = images
        input_text = captions
        if torch.cuda.is_available():
            input_visual = input_visual.cuda()
            input_text = input_text.cuda()

        # similarity scores between every image and caption in the batch
        scores = model(input_visual, input_text, lengths)
        torch.cuda.synchronize()

        loss = utils.calcul_loss(scores, input_visual.size(0), margin,
                                 max_violation=max_violation)
        train_logger.update('L', loss.item())

        optimizer.zero_grad()
        loss.backward()
        # clip *after* backward so the freshly computed gradients are clipped
        if grad_clip > 0:
            clip_grad_norm_(params, grad_clip)
        torch.cuda.synchronize()
        optimizer.step()
        torch.cuda.synchronize()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            message = ('Epoch: [{0}][{1}/{2}]\t'
                       'Time {batch_time.val:.3f}\t'
                       '{elog}\t'
                       .format(epoch, i, len(train_loader),
                               batch_time=batch_time,
                               elog=str(train_logger)))
            logging.info(message)
            utils.log_to_txt(
                message,
                opt['logs']['ckpt_save_path'] + loss_name + ".txt"
            )

        tb_logger.log_value('epoch', epoch)
        tb_logger.log_value('step', i)
        tb_logger.log_value('batch_time', batch_time.val)
        train_logger.tb_log(tb_logger)


def validate(val_loader, model):
    model.eval()
    val_logger = utils.LogCollector()
    model.logger = val_logger

    start = time.time()
    input_visual = np.zeros((len(val_loader.dataset), 3, 224, 224))
    input_text = np.zeros((len(val_loader.dataset), 47), dtype=np.int64)
    input_text_length = [0] * len(val_loader.dataset)

    # gather the whole validation set into contiguous arrays, indexed by id
    for i, val_data in enumerate(val_loader):
        images, captions, lengths, ids = val_data
        for (idx, img, cap, cap_len) in zip(ids, images.numpy().copy(),
                                            captions.numpy().copy(), lengths):
            input_visual[idx] = img
            input_text[idx, :captions.size(1)] = cap
            input_text_length[idx] = cap_len

    # the loader is assumed to yield 5 captions per image, so every image
    # appears 5 times; keep every 5th visual entry to deduplicate
    input_visual = np.array([input_visual[i]
                             for i in range(0, len(input_visual), 5)])

    d = utils.shard_dis_GAC(input_visual, input_text, model,
                            lengths=input_text_length)

    end = time.time()
    print("calculate similarity time:", end - start)
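
    # NOTE (assumption): acc_i2t2 / acc_t2i2 are taken to return a tuple of
    # (R@1, R@5, R@10, median rank, mean rank) plus a per-query ranks array,
    # computed over the similarity matrix d; the *i names below score
    # image-to-text retrieval and the *t names score text-to-image retrieval.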
    (r1i, r5i, r10i, medri, meanri), _ = utils.acc_i2t2(d)
    logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" %
                 (r1i, r5i, r10i, medri, meanri))
    (r1t, r5t, r10t, medrt, meanrt), _ = utils.acc_t2i2(d)
    logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" %
                 (r1t, r5t, r10t, medrt, meanrt))

    # mean recall over both retrieval directions
    currscore = (r1t + r5t + r10t + r1i + r5i + r10i) / 6.0
    all_score = ("r1i:{} r5i:{} r10i:{} medri:{} meanri:{}\n"
                 " r1t:{} r5t:{} r10t:{} medrt:{} meanrt:{}\n"
                 " sum:{}\n ------\n".format(
                     r1i, r5i, r10i, medri, meanri,
                     r1t, r5t, r10t, medrt, meanrt,
                     currscore))

    tb_logger.log_value('r1i', r1i)
    tb_logger.log_value('r5i', r5i)
    tb_logger.log_value('r10i', r10i)
    tb_logger.log_value('medri', medri)
    tb_logger.log_value('meanri', meanri)
    tb_logger.log_value('r1t', r1t)
    tb_logger.log_value('r5t', r5t)
    tb_logger.log_value('r10t', r10t)
    tb_logger.log_value('medrt', medrt)
    tb_logger.log_value('meanrt', meanrt)
    tb_logger.log_value('rsum', currscore)

    return currscore, all_score


def validate_test(val_loader, model):
    model.eval()
    val_logger = utils.LogCollector()
    model.logger = val_logger

    start = time.time()
    input_visual = np.zeros((len(val_loader.dataset), 3, 224, 224))
    input_text = np.zeros((len(val_loader.dataset), 47), dtype=np.int64)
    input_text_length = [0] * len(val_loader.dataset)

    embed_start = time.time()
    for i, val_data in enumerate(val_loader):
        images, captions, lengths, ids = val_data
        for (idx, img, cap, cap_len) in zip(ids, images.numpy().copy(),
                                            captions.numpy().copy(), lengths):
            input_visual[idx] = img
            input_text[idx, :captions.size(1)] = cap
            input_text_length[idx] = cap_len

    # same 5-captions-per-image deduplication as in validate()
    input_visual = np.array([input_visual[i]
                             for i in range(0, len(input_visual), 5)])
    embed_end = time.time()
    print("embedding time: {}".format(embed_end - embed_start))

    d = utils.shard_dis_GAC(input_visual, input_text, model,
                            lengths=input_text_length)

    end = time.time()
    print("calculate similarity time:", end - start)

    return d
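

# ---------------------------------------------------------------------------
# Illustrative driver (not part of the original script): a minimal sketch of
# how train() and validate() above are typically combined. The function name
# fit, the 'lr' and 'epochs' option keys, the Adam choice, and the
# 'model_best.pth' filename are assumptions for the sake of the example;
# train() itself only reads the opt keys listed at its top.
# ---------------------------------------------------------------------------
def fit(model, train_loader, val_loader, opt):
    # hypothetical optimizer setup; the repository may configure its own
    optimizer = torch.optim.Adam(model.parameters(), lr=opt['optim']['lr'])
    best_score = 0.0
    for epoch in range(opt['optim']['epochs']):
        train(train_loader, model, optimizer, epoch, opt=opt)
        currscore, all_score = validate(val_loader, model)
        # keep the checkpoint with the best mean recall over both directions
        if currscore > best_score:
            best_score = currscore
            torch.save(model.state_dict(),
                       opt['logs']['ckpt_save_path'] + 'model_best.pth')
    return best_score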