# ----------------------------------------------------------- # "BCAN++: Cross-modal Retrieval With Bidirectional Correct Attention Network" # Yang Liu, Hong Liu, Huaqiu Wang, Fanyang Meng, Mengyuan Liu* # # --------------------------------------------------------------- """BCAN model""" import copy import time import torch import torch.nn as nn import torch.nn.init import torchtext from torch.autograd import Variable from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence import numpy as np from transformers import BertTokenizer, BertModel, BertConfig def l1norm(X, dim, eps=1e-8): """L1-normalize columns of X """ norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps X = torch.div(X, norm) return X def l2norm(X, dim, eps=1e-8): """L2-normalize columns of X """ norm = torch.pow(X, 2).sum(dim=dim, keepdim=True) norm = torch.sqrt(norm + eps) + eps X = torch.div(X, norm) return X def cosine_similarity(x1, x2, dim=1, eps=1e-8, keep_dim=False): """Returns cosine similarity between x1 and x2, computed along dim.""" w12 = torch.sum(x1 * x2, dim, keepdim=keep_dim) w1 = torch.norm(x1, 2, dim, keepdim=keep_dim) w2 = torch.norm(x2, 2, dim, keepdim=keep_dim) if keep_dim: return w12 / (w1 * w2).clamp(min=eps) else: return (w12 / (w1 * w2).clamp(min=eps)).squeeze(-1) def func_attention(query, context, g_sim, opt, eps=1e-8): """ query: (batch, queryL, d) context: (batch, sourceL, d) opt: parameters """ batch_size, queryL, sourceL = context.size( 0), query.size(1), context.size(1) # Step 1: preassign attention # --> (batch, d, queryL) queryT = torch.transpose(query, 1, 2) # (batch, sourceL, d)(batch, d, queryL) attn = torch.bmm(context, queryT) attn = nn.LeakyReLU(0.1)(attn) attn = l2norm(attn, 2) # --> (batch, queryL, sourceL) attn = torch.transpose(attn, 1, 2).contiguous() # --> (batch*queryL, sourceL) attn = attn.view(batch_size * queryL, sourceL) attn = nn.Softmax(dim=1)(attn * opt.lambda_softmax) # --> (batch, queryL, sourceL) attn = attn.view(batch_size, queryL, sourceL) # Step 2: identify irrelevant fragments # Learning an indicator function H, one for relevant, zero for irrelevant if opt.correct_type == 'equal': re_attn = correct_equal(attn, query, context, sourceL, g_sim) elif opt.correct_type == 'prob': re_attn = correct_prob(attn, query, context, sourceL, g_sim) # --> (batch, d, sourceL) contextT = torch.transpose(context, 1, 2) # --> (batch, sourceL, queryL) re_attnT = torch.transpose(re_attn, 1, 2).contiguous() # (batch x d x sourceL)(batch x sourceL x queryL) # --> (batch, d, queryL) weightedContext = torch.bmm(contextT, re_attnT) # --> (batch, queryL, d) weightedContext = torch.transpose(weightedContext, 1, 2) if torch.isnan(weightedContext).any(): print('ddd') return weightedContext, re_attn def correct_equal(attn, query, context, sourceL, g_sim): """ consider the confidence g(x) for each fragment as equal sigma_{j} (xi - xj) = sigma_{j} xi - sigma_{j} xj attn: (batch, queryL, sourceL) """ # GCU process d = g_sim - 0.3 d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d) re_attn = d.unsqueeze(1).unsqueeze(2) * attn attn_sum = torch.sum(re_attn, dim=-1, keepdim=True) re_attn = re_attn / attn_sum cos1 = cosine_similarity(torch.bmm(re_attn, context), query, dim=-1, keep_dim=True) cos1 = torch.where(cos1 == 0, cos1.new_full(cos1.shape, 1e-8), cos1) re_attn1 = focal_equal(re_attn, query, context, sourceL) # LCU process cos = cosine_similarity(torch.bmm(re_attn1, context), query, dim=-1, keep_dim=True) cos = torch.where(cos == 0, cos.new_full(cos.shape, 1e-8), cos) delta = cos - cos1 delta = torch.where(delta == 0, delta.new_full(delta.shape, 1e-8), delta) re_attn2 = delta * re_attn1 attn_sum = torch.sum(re_attn2, dim=-1, keepdim=True) re_attn2 = re_attn2 / attn_sum re_attn2 = focal_equal(re_attn2, query, context, sourceL) return re_attn2 def focal_equal(attn, query, context, sourceL): funcF = attn * sourceL - torch.sum(attn, dim=-1, keepdim=True) fattn = torch.where(funcF > 0, torch.ones_like(attn), torch.zeros_like(attn)) # Step 3: reassign attention tmp_attn = fattn * attn attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True) re_attn = tmp_attn / attn_sum return re_attn def correct_prob(attn, query, context, sourceL, g_sim): """ consider the confidence g(x) for each fragment as the sqrt of their similarity probability to the query fragment sigma_{j} (xi - xj)gj = sigma_{j} xi*gj - sigma_{j} xj*gj attn: (batch, queryL, sourceL) """ # GCU process d = g_sim - 0.3 d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d) re_attn = d.unsqueeze(1).unsqueeze(2) * attn attn_sum = torch.sum(re_attn, dim=-1, keepdim=True) re_attn = re_attn / attn_sum cos1 = cosine_similarity(torch.bmm(re_attn, context), query, dim=-1, keep_dim=True) cos1 = torch.where(cos1 == 0, cos1.new_full(cos1.shape, 1e-8), cos1) re_attn1 = focal_prob(re_attn, query, context, sourceL) # LCU process cos = cosine_similarity(torch.bmm(re_attn1, context), query, dim=-1, keep_dim=True) cos = torch.where(cos == 0, cos.new_full(cos.shape, 1e-8), cos) delta = cos - cos1 delta = torch.where(delta == 0, delta.new_full(delta.shape, 1e-8), delta) re_attn2 = delta * re_attn1 attn_sum = torch.sum(re_attn2, dim=-1, keepdim=True) re_attn2 = re_attn2 / attn_sum re_attn2 = focal_prob(re_attn2, query, context, sourceL) return re_attn2 def focal_prob(attn, query, context, sourceL): batch_size, queryL, sourceL = context.size( 0), query.size(1), context.size(1) # -> (batch, queryL, sourceL, 1) xi = attn.unsqueeze(-1).contiguous() # -> (batch, queryL, 1, sourceL) xj = attn.unsqueeze(2).contiguous() # -> (batch, queryL, 1, sourceL) xj_confi = torch.sqrt(xj) xi = xi.view(batch_size * queryL, sourceL, 1) xj = xj.view(batch_size * queryL, 1, sourceL) xj_confi = xj_confi.view(batch_size * queryL, 1, sourceL) # -> (batch*queryL, sourceL, sourceL) term1 = torch.bmm(xi, xj_confi).clamp(min=1e-8) term2 = xj * xj_confi funcF = torch.sum(term1 - term2, dim=-1) # -> (batch*queryL, sourceL) funcF = funcF.view(batch_size, queryL, sourceL) fattn = torch.where(funcF > 0, torch.ones_like(attn), torch.zeros_like(attn)) # Step 3: reassign attention tmp_attn = fattn * attn attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True) re_attn = tmp_attn / attn_sum if torch.isnan(re_attn).any(): print("ddd") return re_attn def EncoderImage(data_name, img_dim, embed_size, precomp_enc_type='basic', no_imgnorm=False): """A wrapper to image encoders. Chooses between an different encoders that uses precomputed image features. """ img_enc = EncoderImagePrecomp(img_dim, embed_size, no_imgnorm) return img_enc class EncoderImagePrecomp(nn.Module): def __init__(self, img_dim, embed_size, no_imgnorm=False): super(EncoderImagePrecomp, self).__init__() self.embed_size = embed_size self.no_imgnorm = no_imgnorm self.fc = nn.Linear(img_dim, embed_size) self.init_weights() def init_weights(self): """Xavier initialization for the fully connected layer """ r = np.sqrt(6.) / np.sqrt(self.fc.in_features + self.fc.out_features) self.fc.weight.data.uniform_(-r, r) self.fc.bias.data.fill_(0) def forward(self, images): """Extract image feature vectors.""" # assuming that the precomputed features are already l2-normalized #print(images, images.shape) features = self.fc(images) features_mean = torch.mean(features, 1) # normalize in the joint embedding space if not self.no_imgnorm: features = l2norm(features, dim=-1) return features, features_mean class EncoderTextBERT(nn.Module): def __init__(self, opt, order_embeddings=False, mean=True, post_transformer_layers=0): super().__init__() self.preextracted = opt.text_model_pre_extracted bert_config = BertConfig.from_pretrained(opt.text_model_pretrain, output_hidden_states=True, num_hidden_layers=opt.text_model_extraction_hidden_layer) bert_model = BertModel.from_pretrained(opt.text_model_pretrain, config=bert_config) self.order_embeddings = order_embeddings self.vocab_size = bert_model.config.vocab_size self.hidden_layer = opt.text_model_extraction_hidden_layer if not self.preextracted: self.tokenizer = BertTokenizer.from_pretrained(opt.text_model_pretrain) self.bert_model = bert_model self.word_embeddings = self.bert_model.get_input_embeddings() if post_transformer_layers > 0: transformer_layer = nn.TransformerEncoderLayer(d_model=opt.text_model_word_dim, nhead=4, dim_feedforward=2048, dropout=opt.text_model_dropout, activation='relu') self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=post_transformer_layers) self.post_transformer_layers = post_transformer_layers self.map = nn.Linear(opt.text_model_word_dim, opt.embed_size) self.mean = mean def forward(self, x, lengths): ''' x: tensor of indexes (LongTensor) obtained with tokenizer.encode() of size B x ? lengths: tensor of lengths (LongTensor) of size B ''' # print(x, x.shape) # print(lengths) if not self.preextracted or self.post_transformer_layers > 0: max_len = max(lengths) attention_mask = torch.ones(x.shape[0], max_len) for e, l in zip(attention_mask, lengths): e[l:] = 0 attention_mask = attention_mask.to(x.device) if self.preextracted: outputs = x else: outputs = self.bert_model(x, attention_mask=attention_mask) outputs = outputs[2][-1] if self.post_transformer_layers > 0: outputs = outputs.permute(1, 0, 2) outputs = self.transformer_encoder(outputs, src_key_padding_mask=(attention_mask - 1).bool()) outputs = outputs.permute(1, 0, 2) if self.mean: #x = outputs.mean(dim=1) x = torch.mean(outputs, 1) else: x = outputs[:, 0, :] # from the last layer take only the first word out = self.map(x) outputs = self.map(outputs) # normalization in the joint embedding space # out = l2norm(out) # take absolute value, used by order embeddings if self.order_embeddings: out = torch.abs(out) #print(outputs.shape, out.shape) return outputs, lengths, out def get_finetuning_params(self): return list(self.bert_model.parameters()) def encoder_text(word2idx, vocab_size, word_dim, embed_size, num_layers, use_bi_gru=False, no_txtnorm=False): txt_enc = EncoderText(word2idx, vocab_size, word_dim, embed_size, num_layers, use_bi_gru, no_txtnorm) return txt_enc class EncoderText(nn.Module): def __init__(self, word2idx, vocab_size, word_dim, embed_size, num_layers, use_bi_gru=False, no_txtnorm=False): super(EncoderText, self).__init__() self.embed_size = embed_size self.no_txtnorm = no_txtnorm # word embedding self.embed = nn.Embedding(vocab_size, word_dim) # caption embedding self.use_bi_gru = use_bi_gru self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru) self.init_weights(word2idx) def init_weights(self, word2idx): # self.embed.weight.data.uniform_(-0.1, 0.1) wemb = torchtext.vocab.GloVe(cache=".vector_cache") # quick-and-dirty trick to improve word-hit rate missing_words = [] for word, idx in word2idx.items(): if word not in wemb.stoi: word = word.replace('-', '').replace('.', '').replace("'", '') if '/' in word: word = word.split('/')[0] if word in wemb.stoi: self.embed.weight.data[idx] = wemb.vectors[wemb.stoi[word]] else: missing_words.append(word) print('Words: {}/{} found in vocabulary; {} words missing'.format( len(word2idx) - len(missing_words), len(word2idx), len(missing_words))) def forward(self, x, lengths): """Handles variable size captions """ # Embed word ids to vectors #print(x, x.shape, lengths, len(lengths)) x = self.embed(x) packed = pack_padded_sequence(x, lengths, batch_first=True) if torch.cuda.device_count() > 1: self.rnn.flatten_parameters() # Forward propagate RNN out, _ = self.rnn(packed) #print(out.dtype, out.shape) #print("---") # Reshape *final* output to (batch_size, hidden_size) padded = pad_packed_sequence(out, batch_first=True) cap_emb, cap_len = padded if self.use_bi_gru: cap_emb = (cap_emb[:, :, :int(cap_emb.size(2) / 2)] + cap_emb[:, :, int(cap_emb.size(2) / 2):]) / 2 cap_emb_mean = torch.mean(cap_emb, 1) # normalization in the joint embedding space if not self.no_txtnorm: cap_emb = l2norm(cap_emb, dim=-1) cap_emb_mean = l2norm(cap_emb_mean, dim=1) #print(cap_emb.shape, cap_emb_mean.shape) return cap_emb, cap_len, cap_emb_mean ''' Visual self-attention module ''' class V_single_modal_atten(nn.Module): """ Single Visual Modal Attention Network. """ def __init__(self, image_dim, embed_dim, dropout_rate=0.4, img_region_num=36): """ param image_dim: dim of visual feature param embed_dim: dim of embedding space """ super(V_single_modal_atten, self).__init__() self.fc1 = nn.Linear(image_dim, embed_dim) # embed visual feature to common space self.fc2 = nn.Linear(image_dim, embed_dim) # embed memory to common space self.fc2_2 = nn.Linear(embed_dim, embed_dim) self.fc3 = nn.Linear(embed_dim, 1) # turn fusion_info to attention weights self.fc4 = nn.Linear(image_dim, embed_dim) # embed attentive feature to common space self.embedding_1 = nn.Sequential(self.fc1, nn.BatchNorm1d(img_region_num), nn.Tanh(), nn.Dropout(dropout_rate)) self.embedding_2 = nn.Sequential(self.fc2, nn.BatchNorm1d(embed_dim), nn.Tanh(), nn.Dropout(dropout_rate)) self.embedding_2_2 = nn.Sequential(self.fc2_2, nn.BatchNorm1d(embed_dim), nn.Tanh(), nn.Dropout(dropout_rate)) self.embedding_3 = nn.Sequential(self.fc3) self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights def forward(self, v_t, m_v): """ Forward propagation. :param v_t: encoded images, shape: (batch_size, num_regions, image_dim) :param m_v: previous visual memory, shape: (batch_size, image_dim) :return: attention weighted encoding, weights """ W_v = self.embedding_1(v_t) if m_v.size()[-1] == v_t.size()[-1]: W_v_m = self.embedding_2(m_v) else: W_v_m = self.embedding_2_2(m_v) W_v_m = W_v_m.unsqueeze(1).repeat(1, W_v.size()[1], 1) h_v = W_v.mul(W_v_m) a_v = self.embedding_3(h_v) a_v = a_v.squeeze(2) weights = self.softmax(a_v) v_att = ((weights.unsqueeze(2) * v_t)).sum(dim=1) # l2 norm v_att = l2norm(v_att, -1) return v_att, weights class T_single_modal_atten(nn.Module): """ Single Textual Modal Attention Network. """ def __init__(self, embed_dim, dropout_rate=0.4): """ param image_dim: dim of visual feature param embed_dim: dim of embedding space """ super(T_single_modal_atten, self).__init__() self.fc1 = nn.Linear(embed_dim, embed_dim) # embed visual feature to common space self.fc2 = nn.Linear(embed_dim, embed_dim) # embed memory to common space self.fc3 = nn.Linear(embed_dim, 1) # turn fusion_info to attention weights self.embedding_1 = nn.Sequential(self.fc1, nn.Tanh(), nn.Dropout(dropout_rate)) self.embedding_2 = nn.Sequential(self.fc2, nn.Tanh(), nn.Dropout(dropout_rate)) self.embedding_3 = nn.Sequential(self.fc3) self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights def forward(self, u_t, m_u): """ Forward propagation. :param v_t: encoded images, shape: (batch_size, num_regions, image_dim) :param m_v: previous visual memory, shape: (batch_size, image_dim) :return: attention weighted encoding, weights """ W_u = self.embedding_1(u_t) W_u_m = self.embedding_2(m_u) W_u_m = W_u_m.unsqueeze(1).repeat(1, W_u.size()[1], 1) h_u = W_u.mul(W_u_m) a_u = self.embedding_3(h_u) a_u = a_u.squeeze(2) weights = self.softmax(a_u) u_att = ((weights.unsqueeze(2) * u_t)).sum(dim=1) # l2 norm u_att = l2norm(u_att, -1) return u_att, weights class ContrastiveLoss(nn.Module): """ Compute contrastive loss """ def __init__(self, margin=0): super(ContrastiveLoss, self).__init__() self.margin = margin def forward(self, scores): # compute image-sentence score matrix diagonal = scores.diag().view(-1, 1) d1 = diagonal.expand_as(scores) d2 = diagonal.t().expand_as(scores) # compare every diagonal score to scores in its column # caption retrieval cost_s = (self.margin + scores - d1).clamp(min=0) # compare every diagonal score to scores in its row # image retrieval cost_im = (self.margin + scores - d2).clamp(min=0) # clear diagonals mask = torch.eye(scores.size(0)) > .5 I = Variable(mask) if torch.cuda.is_available(): I = I.cuda() cost_s = cost_s.masked_fill_(I, 0) cost_im = cost_im.masked_fill_(I, 0) # keep the maximum violating negative for each query cost_s = cost_s.max(1)[0] cost_im = cost_im.max(0)[0] return cost_s.sum() + cost_im.sum() class SCAN2(nn.Module): """ Stacked Cross Attention Network (SCAN) model """ def __init__(self, word2idx, opt): super(SCAN2, self).__init__() # Build Models self.grad_clip = opt.grad_clip self.img_enc = EncoderImage(opt.data_name, opt.img_dim, opt.embed_size, precomp_enc_type=opt.precomp_enc_type, no_imgnorm=opt.no_imgnorm) self.txt_enc = encoder_text(word2idx, opt.vocab_size, opt.word_dim, opt.embed_size, opt.num_layers, use_bi_gru=True, no_txtnorm=opt.no_txtnorm) #self.txt_enc = EncoderTextBERT(opt, post_transformer_layers=opt.text_model_layers) self.V_self_atten_enhance = V_single_modal_atten(opt.embed_size, opt.embed_size) self.T_self_atten_enhance = T_single_modal_atten(opt.embed_size) self.opt = opt self.Eiters = 0 def forward_emb(self, images, captions, lengths): """Compute the image and caption embeddings """ # Forward img_emb, img_mean = self.img_enc(images) #print(img_emb.shape,img_mean.shape) cap_emb, cap_lens, cap_mean = self.txt_enc(captions, lengths) img_mean, _ = self.V_self_atten_enhance(img_emb, img_mean) cap_mean, _ = self.T_self_atten_enhance(cap_emb, cap_mean) return img_emb, img_mean, cap_emb, cap_lens, cap_mean def txt_emb(self, captions, lengths): """Compute the caption embeddings """ # Forward cap_emb, cap_lens, cap_mean = self.txt_enc(captions, lengths) cap_mean, _ = self.T_self_atten_enhance(cap_emb, cap_mean) return cap_emb, cap_lens, cap_mean def image_emb(self, images): """Compute the image embeddings """ # Forward img_emb, img_mean = self.img_enc(images) img_mean, _ = self.V_self_atten_enhance(img_emb, img_mean) return img_emb, img_mean def forward_sim(self, img_emb, img_mean, cap_emb, cap_len, cap_mean, **kwargs): """Compute the loss given pairs of image and caption embeddings """ scores = self.xattn_score(img_emb, img_mean, cap_emb, cap_len, cap_mean) return scores def forward(self, images, captions, lengths, ids=None, *args): # compute the embeddings lengths = lengths.cpu().numpy().tolist() img_emb, img_mean, cap_emb, cap_lens, cap_mean = self.forward_emb(images, captions, lengths) scores = self.forward_sim(img_emb, img_mean, cap_emb, cap_lens, cap_mean) return scores def xattn_score(self, images, img_mean, captions, cap_lens, cap_mean): similarities = [] n_image = images.size(0) n_caption = captions.size(0) g_sims = cap_mean.mm(img_mean.t()) now = time.time() for i in range(n_caption): # Get the i-th text description n_word = cap_lens[i] g_sim = g_sims[i] cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous() # --> (n_image, n_word, d) cap_i_expand = cap_i.repeat(n_image, 1, 1) # t2i process # weiContext: (n_image, n_word, d) weiContext, _ = func_attention(cap_i_expand, images, g_sim, self.opt) t2i_sim = cosine_similarity(cap_i_expand, weiContext, dim=2) t2i_sim = t2i_sim.mean(dim=1, keepdim=True) # i2t process # weiContext: (n_image, n_word, d) weiContext, _ = func_attention(images, cap_i_expand, g_sim, self.opt) i2t_sim = cosine_similarity(images, weiContext, dim=2) i2t_sim = i2t_sim.mean(dim=1, keepdim=True) # Overall similarity for image and text sim = t2i_sim + i2t_sim similarities.append(sim) # (n_image, n_caption) similarities = torch.cat(similarities, 1) if self.training: similarities = similarities.transpose(0, 1) #print('Time:{:.4f}'.format(time.time() - now)) return similarities