# -*- coding:UTF-8 -*- # ----------------------------------------------------------- # "BCAN++: Cross-modal Retrieval With Bidirectional Correct Attention Network" # Yang Liu, Hong Liu, Huaqiu Wang, Fanyang Meng, Mengyuan Liu* # # --------------------------------------------------------------- """Training script""" import itertools import os import time import shutil import torch import torch.nn as nn import numpy from torch.nn.utils.clip_grad import clip_grad_norm_ import logging import argparse import numpy as np import random from torch.utils.tensorboard import SummaryWriter from transformers import BertTokenizer from data import get_loaders from lib.vse import VSEModel from vocab import deserialize_vocab from model import SCAN, ContrastiveLoss from evaluation import AverageMeter, encode_data, LogCollector, i2t, t2i, shard_xattn, t2i_shuffle, i2t_shuffle, \ encode_data_vse, compute_sim, i2t_vse, t2i_vse def setup_seed(seed): np.random.seed(seed) random.seed(seed) torch.manual_seed(seed) #cpu torch.cuda.manual_seed_all(seed) #并行gpu torch.backends.cudnn.deterministic = True #cpu/gpu结果一致 torch.backends.cudnn.benchmark = True #训练集变化不大时使训练加速 def logging_func(log_file, message): with open(log_file,'a') as f: f.write(message) f.close() def main(): #setup_seed(3245) setup_seed(3045) # Hyper Parameters parser = argparse.ArgumentParser() parser.add_argument('--data_path', default='../SCAN-master/data/', help='path to datasets') parser.add_argument('--data_name', default='f30k_precomp', help='{coco,f30k}_precomp') parser.add_argument('--vocab_path', default='../SCAN-master/vocab/', help='Path to saved vocabulary json files.') parser.add_argument('--margin', default=0.2, type=float, help='Rank loss margin.') parser.add_argument('--grad_clip', default=2.0, type=float, help='Gradient clipping threshold.') parser.add_argument('--num_epochs', default=25, type=int, help='Number of training epochs.') parser.add_argument('--batch_size', default=128, type=int, help='Size of a training mini-batch.') parser.add_argument('--word_dim', default=300, type=int, help='Dimensionality of the word embedding.') parser.add_argument('--embed_size', default=1024, type=int, help='Dimensionality of the joint embedding.') parser.add_argument('--num_layers', default=1, type=int, help='Number of GRU layers.') parser.add_argument('--optim', default='adam', type=str, help='the optimizer') parser.add_argument('--learning_rate', default=.0005, type=float, help='Initial learning rate.') parser.add_argument('--lr_update', default=15, type=int, help='Number of epochs to update the learning rate.') parser.add_argument('--workers', default=2, type=int, help='Number of data loader workers.') parser.add_argument('--val_step', default=500, type=int, help='Number of steps to run validation.') parser.add_argument('--log_step', default=100, type=int, help='Number of steps to print and record the log.') parser.add_argument('--logger_name', default='./runs/', help='Path to save Tensorboard log.') parser.add_argument('--model_name', default='./runs/bert_adam_bcan_gpo_vseinfty_ttt', help='Path to save the model.') parser.add_argument('--max_violation', action='store_true', help='Use max instead of sum in the rank loss.') parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('--img_dim', default=2048, type=int, help='Dimensionality of the image embedding.') parser.add_argument('--no_imgnorm', action='store_true', help='Do not normalize the image embeddings.') parser.add_argument('--no_txtnorm', action='store_true', help='Do not normalize the text embeddings.') parser.add_argument('--correct_type', default="prob", help='equal|prob') parser.add_argument('--precomp_enc_type', default="basic", help='basic|weight_norm') parser.add_argument('--bi_gru', action='store_true', default=True, help='Use bidirectional GRU.') parser.add_argument('--lambda_softmax', default=20., type=float, help='Attention softmax temperature.') parser.add_argument('--text_model_name', default="bert", type=str, help='') parser.add_argument('--text_model_pretrain', default="bert-base-uncased", type=str, help='') parser.add_argument('--text_model_word_dim', default=768, type=int, help='') parser.add_argument('--text_model_extraction_hidden_layer', default=6, type=int, help='') parser.add_argument('--text_model_pre_extracted', action='store_true', default=False, help='') parser.add_argument('--text_model_layers', default=0, type=int, help='') parser.add_argument('--text_model_dropout', default=0.1, type=float, help='') parser.add_argument('--backbone_path', type=str, default='', help='path to the pre-trained backbone net') parser.add_argument('--backbone_source', type=str, default='detector', help='the source of the backbone model, detector|imagenet') parser.add_argument('--vse_mean_warmup_epochs', type=int, default=1, help='The number of warmup epochs using mean vse loss') parser.add_argument('--reset_start_epoch', action='store_true', help='Whether restart the start epoch when load weights') parser.add_argument('--backbone_warmup_epochs', type=int, default=5, help='The number of epochs for warmup') parser.add_argument('--embedding_warmup_epochs', type=int, default=2, help='The number of epochs for warming up the embedding layers') parser.add_argument('--backbone_lr_factor', default=0.01, type=float, help='The lr factor for fine-tuning the backbone, it will be multiplied to the lr of ' 'the embedding layers') parser.add_argument('--input_scale_factor', type=float, default=1, help='The factor for scaling the input image') parser.add_argument("--trans_cfg", default='t_cfg.json', help="config file for image transformer") opt = parser.parse_known_args()[0] print(opt) logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO) logging.info('train') os.makedirs(opt.logger_name, exist_ok=True) writer = SummaryWriter(log_dir=os.path.join(opt.logger_name, 'bert_adam_bcan_gpo_vseinfty_ttt')) # Load Vocabulary Wrapper vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name)) word2idx = vocab.word2idx opt.vocab_size = len(vocab) print(opt.vocab_size) # Load data loaders train_loader, val_loader = get_loaders( opt.data_name, vocab, opt.batch_size, opt.workers, opt) # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # vocab = tokenizer.vocab # opt.vocab_size = len(vocab) # print(opt.vocab_size) # train_loader, val_loader = get_loaders( # opt.data_path, opt.data_name, tokenizer, opt.batch_size, opt.workers, opt) # Construct the model model = SCAN(word2idx, opt) model.cuda() # params = list(model.V_self_atten_enhance.parameters()) # params += list(model.T_self_atten_enhance.parameters()) model = nn.DataParallel(model) model2 = VSEModel(opt) criterion = ContrastiveLoss(margin=opt.margin) #optimizer = torch.optim.Adam(model.params, lr=opt.learning_rate) optimizer = torch.optim.Adam(itertools.chain(model.parameters(), model2.params), lr=opt.learning_rate) #optimizer = torch.optim.Adam(filter(lambda p:p.requires_grad, model.parameters()), lr=opt.learning_rate) #optimizer = torch.optim.SGD(model.parameters(), lr=opt.learning_rate, momentum=0.9) best_rsum = 0 start_epoch = 0 # optionally resume from a checkpoint if opt.resume: if os.path.isfile(opt.resume): print("=> loading checkpoint '{}'".format(opt.resume)) checkpoint = torch.load(opt.resume) start_epoch = checkpoint['epoch'] + 1 best_rsum = checkpoint['best_rsum'] model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) # Eiters is used to show logs as the continuation of another # training # model.Eiters = checkpoint['Eiters'] print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})" .format(opt.resume, start_epoch, best_rsum)) else: print("=> no checkpoint found at '{}'".format(opt.resume)) # Train the Model for epoch in range(start_epoch, opt.num_epochs): print(opt.logger_name) print(opt.model_name) if not os.path.exists(opt.model_name): os.makedirs(opt.model_name) message = "epoch: %d, model name: %s\n" % (epoch, opt.model_name) log_file = os.path.join(opt.logger_name, "performance.log") logging_func(log_file, message) adjust_learning_rate(opt, optimizer, epoch) if epoch >= opt.vse_mean_warmup_epochs: opt.max_violation = True model2.set_max_violation(opt.max_violation) #model.set_max_violation(opt.max_violation) #validate(opt, val_loader, model, epoch, writer) # train for one epoch train(opt, train_loader, model, model2, criterion, optimizer, epoch, val_loader, writer)#union #train_bcan(opt, train_loader, model, criterion, optimizer, epoch, val_loader, writer)#bcan #train_vse(opt, train_loader, model, optimizer, epoch, val_loader, writer)#vse # evaluate on validation set #if(epoch % 30 == 0): rsum = validate(opt, val_loader, model, epoch, writer) #rsum = validate_vse(opt, val_loader, model, epoch, writer) # remember best R@ sum and save checkpoint is_best = rsum > best_rsum best_rsum = max(rsum, best_rsum) if not os.path.exists(opt.model_name): os.mkdir(opt.model_name) save_checkpoint({ 'epoch': epoch, 'model': model.state_dict(), 'best_rsum': best_rsum, 'opt': opt, 'optimizer': optimizer.state_dict(), }, is_best, filename='checkpoint_{}.pth.tar'.format(epoch), prefix=opt.model_name + '/') class DataPrefetcher(): def __init__(self, loader): self.loader = iter(loader) self.stream = torch.cuda.Stream() #self.preload() def preload(self): try: self.images, self.img_lengths, self.captions, self.length, self.index = next(self.loader) except StopIteration: self.images, self.captions, self.length, self.index = None, None, None, None return with torch.cuda.stream(self.stream): self.images = self.images.cuda() self.captions = self.captions.cuda() self.img_lengths = self.img_lengths.cuda() def next(self): torch.cuda.current_stream().wait_stream(self.stream) self.preload() return self.images, self.img_lengths, self.captions, self.length, self.index def train(opt, train_loader, model, model2, criterion, optimizer, epoch, val_loader, writer): # average meters to record the training statistics batch_time = AverageMeter() data_time = AverageMeter() train_logger = LogCollector() run_time = 0 start_time = time.time() prefetcher = DataPrefetcher(train_loader) images, img_lengths, captions, lengths, index = prefetcher.next() i = 0 loss_all = 0 while images is not None: # switch to train mode model.train() model2.train_start() # measure data loading time model.logger = train_logger optimizer.zero_grad() # Update the model if torch.cuda.device_count() > 1: images = images.repeat(torch.cuda.device_count(), 1, 1) score = model(images, img_lengths, captions, lengths, index) loss = criterion(score) loss += model2.train_emb(images, captions, lengths, image_lengths=img_lengths) loss.backward() loss_all += loss.data.item() if opt.grad_clip > 0: clip_grad_norm_(model.parameters(), opt.grad_clip) clip_grad_norm_(model2.params, opt.grad_clip) optimizer.step() if (i + 1) % opt.log_step == 0: run_time += time.time() - start_time log = "epoch: %d; batch: %d/%d; loss: %.6f; time: %.4f" % (epoch, i, len(train_loader), loss.data.item(), run_time) print(log, flush=True) start_time = time.time() run_time = 0 #validate(opt, val_loader, model, epoch, writer) # validate at every val_step images, img_lengths, captions, lengths, index = prefetcher.next() i += 1 writer.add_scalar(f'Loss/train', loss_all/i, epoch) def train_bcan(opt, train_loader, model, criterion, optimizer, epoch, val_loader, writer): # average meters to record the training statistics batch_time = AverageMeter() data_time = AverageMeter() train_logger = LogCollector() run_time = 0 start_time = time.time() prefetcher = DataPrefetcher(train_loader) images, img_lengths, captions, lengths, index = prefetcher.next() i = 0 loss_all = 0 while images is not None: # switch to train mode model.train() #model2.train_start() # measure data loading time model.logger = train_logger optimizer.zero_grad() # Update the model if torch.cuda.device_count() > 1: images = images.repeat(torch.cuda.device_count(), 1, 1) score = model(images, img_lengths, captions, lengths, index) loss = criterion(score) #loss += model2.train_emb(images, captions, lengths, image_lengths=img_lengths) loss.backward() loss_all += loss.data.item() if opt.grad_clip > 0: clip_grad_norm_(model.parameters(), opt.grad_clip) #clip_grad_norm_(model2.params, opt.grad_clip) optimizer.step() if (i + 1) % opt.log_step == 0: run_time += time.time() - start_time log = "epoch: %d; batch: %d/%d; loss: %.6f; time: %.4f" % (epoch, i, len(train_loader), loss.data.item(), run_time) print(log, flush=True) start_time = time.time() run_time = 0 #validate(opt, val_loader, model, epoch, writer) # validate at every val_step images, img_lengths, captions, lengths, index = prefetcher.next() i += 1 writer.add_scalar(f'Loss/train', loss_all/i, epoch) def train_vse(opt, train_loader, model2, optimizer, epoch, val_loader, writer): # average meters to record the training statistics batch_time = AverageMeter() data_time = AverageMeter() train_logger = LogCollector() run_time = 0 start_time = time.time() prefetcher = DataPrefetcher(train_loader) images, img_lengths, captions, lengths, index = prefetcher.next() i = 0 loss_all = 0 while images is not None: # switch to train mode #model.train() model2.train_start() # measure data loading time model2.logger = train_logger optimizer.zero_grad() # Update the model # if torch.cuda.device_count() > 1: # images = images.repeat(torch.cuda.device_count(), 1, 1) # score = model(images, img_lengths, captions, lengths, index) # loss = criterion(score) loss = model2.train_emb(images, captions, lengths, image_lengths=img_lengths) loss.backward() loss_all += loss.data.item() if opt.grad_clip > 0: #clip_grad_norm_(model.parameters(), opt.grad_clip) clip_grad_norm_(model2.params, opt.grad_clip) optimizer.step() if (i + 1) % opt.log_step == 0: run_time += time.time() - start_time log = "epoch: %d; batch: %d/%d; loss: %.6f; time: %.4f" % (epoch, i, len(train_loader), loss.data.item(), run_time) print(log, flush=True) start_time = time.time() run_time = 0 #validate(opt, val_loader, model, epoch, writer) # validate at every val_step images, img_lengths, captions, lengths, index = prefetcher.next() i += 1 writer.add_scalar(f'Loss/train', loss_all/i, epoch) def validate(opt, val_loader, model, epoch, writer): # compute the encoding for all the validation images and captions img_embs, img_means, cap_embs, cap_lens, cap_means = encode_data( model, val_loader, opt.log_step, logging.info) print(img_embs.shape, cap_embs.shape) img_embs = numpy.array([img_embs[i] for i in range(0, len(img_embs), 5)]) #img_means = numpy.array([img_means[i] for i in range(0, len(img_means), 20)]) # tmp_cap_embs = [] # for i in range(0, len(cap_embs), 20): # for j in range(i,i+5,1): # tmp_cap_embs.append(cap_embs[j]) # cap_embs = numpy.array(tmp_cap_embs) # # tmp_cap_lens = [] # for i in range(0, len(cap_lens), 20): # for j in range(i, i + 5, 1): # tmp_cap_lens.append(cap_lens[j]) # cap_lens = numpy.array(tmp_cap_lens) # # tmp_cap_means = [] # for i in range(0, len(cap_means), 20): # for j in range(i, i + 5, 1): # tmp_cap_means.append(cap_means[j]) # cap_means = numpy.array(tmp_cap_means) print(img_embs.shape,img_means.shape) start = time.time() sims = shard_xattn(model, img_embs, img_means, cap_embs, cap_lens, cap_means, opt, shard_size=500) end = time.time() print("calculate similarity time:", end-start) print(sims.shape) # caption retrieval (r1, r5, r10, medr, meanr) = i2t(img_embs, cap_embs, cap_lens, sims) print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % (r1, r5, r10, medr, meanr)) # image retrieval (r1i, r5i, r10i, medri, meanr) = t2i( img_embs, cap_embs, cap_lens, sims) print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % (r1i, r5i, r10i, medri, meanr)) # sum of recalls to be used for early stopping currscore = r1 + r5 + r10 + r1i + r5i + r10i writer.add_scalar(f'i2t/r1', r1, epoch) writer.add_scalar(f'i2t/r10', r10, epoch) writer.add_scalar(f't2i/r1', r1i, epoch) writer.add_scalar(f't2i/r10', r10i, epoch) writer.add_scalar(f'rsum', currscore, epoch) return currscore def validate_vse(opt, val_loader, model, epoch, writer): logger = logging.getLogger(__name__) model.val_start() with torch.no_grad(): # compute the encoding for all the validation images and captions img_embs, cap_embs = encode_data_vse( model, val_loader, opt.log_step, logging.info, backbone=opt.precomp_enc_type == 'backbone') img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)]) start = time.time() sims = compute_sim(img_embs, cap_embs) end = time.time() logger.info("calculate similarity time: {}".format(end - start)) # caption retrieval npts = img_embs.shape[0] # (r1, r5, r10, medr, meanr) = i2t(img_embs, cap_embs, cap_lens, sims) (r1, r5, r10, medr, meanr) = i2t_vse(npts, sims) logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % (r1, r5, r10, medr, meanr)) # image retrieval # (r1i, r5i, r10i, medri, meanr) = t2i(img_embs, cap_embs, cap_lens, sims) (r1i, r5i, r10i, medri, meanr) = t2i_vse(npts, sims) logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % (r1i, r5i, r10i, medri, meanr)) # sum of recalls to be used for early stopping currscore = r1 + r5 + r10 + r1i + r5i + r10i logger.info('Current rsum is {}'.format(currscore)) currscore = r1 + r5 + r10 + r1i + r5i + r10i writer.add_scalar(f'i2t/r1', r1, epoch) writer.add_scalar(f'i2t/r10', r10, epoch) writer.add_scalar(f't2i/r1', r1i, epoch) writer.add_scalar(f't2i/r10', r10i, epoch) writer.add_scalar(f'rsum', currscore, epoch) return currscore def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', prefix=''): tries = 15 error = None # deal with unstable I/O. Usually not necessary. while tries: try: torch.save(state, prefix + filename) if is_best: message = "--------save best model at epoch %d---------\n" % (state["epoch"] - 1) print(message, flush=True) log_file = os.path.join(prefix, "performance.log") logging_func(log_file, message) shutil.copyfile(prefix + filename, prefix + 'model_best.pth.tar') except IOError as e: error = e tries -= 1 else: break print('model save {} failed, remaining {} trials'.format(filename, tries)) if not tries: raise error def adjust_learning_rate(opt, optimizer, epoch): """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" lr = opt.learning_rate * (0.1 ** (epoch // opt.lr_update)) for param_group in optimizer.param_groups: param_group['lr'] = lr def accuracy(output, target, topk=(1,)): """Computes the precision@k for the specified values of k""" maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0) res.append(correct_k.mul_(100.0 / batch_size)) return res if __name__ == '__main__': os.environ["CUDA_VISIBLE_DEVICES"] = "7" main()