633 lines
23 KiB
Python
633 lines
23 KiB
Python
# -----------------------------------------------------------
# "BCAN++: Cross-modal Retrieval With Bidirectional Correct Attention Network"
# Yang Liu, Hong Liu, Huaqiu Wang, Fanyang Meng, Mengyuan Liu*
#
# ---------------------------------------------------------------
"""BCAN model"""
|
|
import copy
|
|
import time
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.init
|
|
import torchtext
|
|
from torch.autograd import Variable
|
|
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
|
import numpy as np
|
|
from transformers import BertTokenizer, BertModel, BertConfig
|
|
|
|
|
|
def l1norm(X, dim, eps=1e-8):
    """Scale X so its absolute values sum to (roughly) one along `dim`.

    Args:
        X: input tensor.
        dim: dimension along which the L1 norm is accumulated.
        eps: constant added to the norm so an all-zero slice does not
            divide by zero.

    Returns:
        A tensor with the same shape as X.
    """
    denom = X.abs().sum(dim=dim, keepdim=True) + eps
    return X / denom
|
|
|
|
|
|
def l2norm(X, dim, eps=1e-8):
    """Scale X to (roughly) unit Euclidean length along `dim`.

    Args:
        X: input tensor.
        dim: dimension along which the squared values are summed.
        eps: stabiliser, added once inside the square root and once to the
            resulting norm.

    Returns:
        A tensor with the same shape as X.
    """
    squared_sum = (X ** 2).sum(dim=dim, keepdim=True)
    denom = (squared_sum + eps).sqrt() + eps
    return X / denom
|
|
|
|
|
|
def cosine_similarity(x1, x2, dim=1, eps=1e-8, keep_dim=False):
    """Return the cosine similarity between x1 and x2 along `dim`.

    Args:
        x1, x2: tensors whose element-wise product is defined.
        dim: dimension that is reduced.
        eps: lower bound applied to the product of the two norms.
        keep_dim: keep the reduced dimension (with size one) in the result.

    Returns:
        The similarity tensor. When `keep_dim` is False the result is also
        passed through `squeeze(-1)`, so a trailing dimension of size one
        is dropped as well.
    """
    dot = (x1 * x2).sum(dim, keepdim=keep_dim)
    norm_product = (torch.norm(x1, 2, dim, keepdim=keep_dim)
                    * torch.norm(x2, 2, dim, keepdim=keep_dim))
    sim = dot / norm_product.clamp(min=eps)
    return sim if keep_dim else sim.squeeze(-1)
|
|
|
|
|
|
def func_attention(query, context, g_sim, opt, eps=1e-8):
    """Attend from every query fragment over the context fragments.

    Args:
        query: (batch, queryL, d)
        context: (batch, sourceL, d)
        g_sim: global similarity values, forwarded unchanged to the
            correction function.
        opt: parameters; `opt.lambda_softmax` scales the logits before the
            softmax and `opt.correct_type` ('equal' or 'prob') selects the
            correction function.
        eps: unused; kept for interface compatibility.

    Returns:
        (weightedContext, re_attn): the attended context of shape
        (batch, queryL, d) and the corrected attention of shape
        (batch, queryL, sourceL).

    Raises:
        ValueError: if `opt.correct_type` is neither 'equal' nor 'prob'.
            (Previously this surfaced later as an UnboundLocalError.)
    """
    if opt.correct_type == 'equal':
        correct = correct_equal
    elif opt.correct_type == 'prob':
        correct = correct_prob
    else:
        raise ValueError(
            "unknown correct_type {!r}; expected 'equal' or 'prob'".format(
                opt.correct_type))

    sourceL = context.size(1)

    # Step 1: preassign attention
    # (batch, sourceL, d)(batch, d, queryL) --> (batch, sourceL, queryL)
    attn = torch.bmm(context, torch.transpose(query, 1, 2))
    attn = nn.LeakyReLU(0.1)(attn)
    attn = l2norm(attn, 2)

    # --> (batch, queryL, sourceL)
    attn = torch.transpose(attn, 1, 2).contiguous()
    # Softmax over the context fragments; equivalent to flattening to
    # (batch*queryL, sourceL), applying softmax on dim 1 and reshaping back.
    attn = torch.softmax(attn * opt.lambda_softmax, dim=-1)

    # Step 2: identify irrelevant fragments and re-weight the rest
    re_attn = correct(attn, query, context, sourceL, g_sim)

    # (batch, d, sourceL)(batch, sourceL, queryL) --> (batch, d, queryL)
    contextT = torch.transpose(context, 1, 2)
    re_attnT = torch.transpose(re_attn, 1, 2).contiguous()
    weightedContext = torch.bmm(contextT, re_attnT)

    # --> (batch, queryL, d)
    weightedContext = torch.transpose(weightedContext, 1, 2)

    # Debug marker kept from the original code: signals NaNs in the output.
    if torch.isnan(weightedContext).any():
        print('ddd')
    return weightedContext, re_attn
|
|
|
|
|
|
def correct_equal(attn, query, context, sourceL, g_sim):
    """Correct the attention, treating every fragment's confidence as equal.

    sigma_{j} (xi - xj) = sigma_{j} xi - sigma_{j} xj

    Args:
        attn: (batch, queryL, sourceL)
        query, context: fragment features, used for the cosine checks and
            forwarded to `focal_equal`.
        sourceL: number of context fragments, forwarded to `focal_equal`.
        g_sim: global similarity; broadcast over the last two dims of attn.

    Returns:
        The corrected attention after the two stages labelled "GCU" and
        "LCU" in the original comments.
    """
    def _replace_zero(t):
        # Exact zeros are swapped for 1e-8 so later divisions stay defined.
        return torch.where(t == 0, torch.full_like(t, 1e-8), t)

    # GCU process
    offset = _replace_zero(g_sim - 0.3)
    scaled = offset[:, None, None] * attn
    global_attn = scaled / scaled.sum(dim=-1, keepdim=True)
    cos_before = _replace_zero(cosine_similarity(
        torch.bmm(global_attn, context), query, dim=-1, keep_dim=True))
    focal_attn = focal_equal(global_attn, query, context, sourceL)

    # LCU process
    cos_after = _replace_zero(cosine_similarity(
        torch.bmm(focal_attn, context), query, dim=-1, keep_dim=True))
    delta = _replace_zero(cos_after - cos_before)
    local_attn = delta * focal_attn
    local_attn = local_attn / local_attn.sum(dim=-1, keepdim=True)
    return focal_equal(local_attn, query, context, sourceL)
|
|
|
|
|
|
def focal_equal(attn, query, context, sourceL):
    """Keep the fragments whose attention is above average and renormalise.

    A fragment i is kept when `attn_i * sourceL - sum_j attn_j > 0`, i.e.
    (when sourceL is the size of the last dimension) when its weight is
    strictly above the row mean.

    Rows in which no fragment passes the test — uniform attention, or a
    single context fragment — used to produce 0/0 = NaN. Such rows now keep
    all of their fragments.

    Args:
        attn: (batch, queryL, sourceL) attention weights.
        query: unused; kept for interface compatibility.
        context: unused; kept for interface compatibility.
        sourceL: number of context fragments.

    Returns:
        Re-assigned attention with the same shape as `attn`.
    """
    funcF = attn * sourceL - torch.sum(attn, dim=-1, keepdim=True)
    keep = funcF > 0
    # Fallback: if nothing stands out in a row, treat nothing as irrelevant.
    keep = keep | ~keep.any(dim=-1, keepdim=True)
    fattn = keep.to(attn.dtype)

    # Step 3: reassign attention
    tmp_attn = fattn * attn
    attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
    re_attn = tmp_attn / attn_sum

    return re_attn
|
|
|
|
|
|
def correct_prob(attn, query, context, sourceL, g_sim):
    """Correct the attention using sqrt-probability fragment confidence.

    The confidence g(x) of each fragment is taken as the square root of its
    similarity probability to the query fragment:

    sigma_{j} (xi - xj)gj = sigma_{j} xi*gj - sigma_{j} xj*gj

    Args:
        attn: (batch, queryL, sourceL)
        query, context: fragment features, used for the cosine checks and
            forwarded to `focal_prob`.
        sourceL: number of context fragments, forwarded to `focal_prob`.
        g_sim: global similarity; broadcast over the last two dims of attn.

    Returns:
        The corrected attention after the two stages labelled "GCU" and
        "LCU" in the original comments.
    """
    def _replace_zero(t):
        # Exact zeros are swapped for 1e-8 so later divisions stay defined.
        return torch.where(t == 0, torch.full_like(t, 1e-8), t)

    # GCU process
    offset = _replace_zero(g_sim - 0.3)
    scaled = offset[:, None, None] * attn
    global_attn = scaled / scaled.sum(dim=-1, keepdim=True)
    cos_before = _replace_zero(cosine_similarity(
        torch.bmm(global_attn, context), query, dim=-1, keep_dim=True))
    focal_attn = focal_prob(global_attn, query, context, sourceL)

    # LCU process
    cos_after = _replace_zero(cosine_similarity(
        torch.bmm(focal_attn, context), query, dim=-1, keep_dim=True))
    delta = _replace_zero(cos_after - cos_before)
    local_attn = delta * focal_attn
    local_attn = local_attn / local_attn.sum(dim=-1, keepdim=True)
    return focal_prob(local_attn, query, context, sourceL)
|
|
|
|
|
|
def focal_prob(attn, query, context, sourceL):
    """Keep fragments by a confidence-weighted comparison and renormalise.

    With g_j = sqrt(attn_j), fragment i is kept when
    `sum_j (max(attn_i * g_j, 1e-8) - attn_j * g_j) > 0`.

    Rows in which no fragment passes the test (for example when there is a
    single context fragment) used to produce 0/0 = NaN. Such rows now keep
    all of their fragments, matching `focal_equal`.

    Args:
        attn: (batch, queryL, sourceL) attention weights.
        query: only its size along dim 1 (queryL) is read.
        context: only its sizes along dims 0 and 1 are read.
        sourceL: ignored; it is re-derived from `context.size(1)`.

    Returns:
        Re-assigned attention with the same shape as `attn`.
    """
    batch_size, queryL, sourceL = context.size(
        0), query.size(1), context.size(1)

    # -> (batch*queryL, sourceL, 1)
    xi = attn.unsqueeze(-1).contiguous().view(batch_size * queryL, sourceL, 1)
    # -> (batch*queryL, 1, sourceL)
    xj = attn.unsqueeze(2).contiguous().view(batch_size * queryL, 1, sourceL)
    # Confidence of each fragment: square root of its attention weight.
    xj_confi = torch.sqrt(xj)

    # -> (batch*queryL, sourceL, sourceL)
    term1 = torch.bmm(xi, xj_confi).clamp(min=1e-8)
    term2 = xj * xj_confi
    funcF = torch.sum(term1 - term2, dim=-1)  # -> (batch*queryL, sourceL)
    funcF = funcF.view(batch_size, queryL, sourceL)

    keep = funcF > 0
    # Fallback: if nothing stands out in a row, treat nothing as irrelevant.
    keep = keep | ~keep.any(dim=-1, keepdim=True)
    fattn = keep.to(attn.dtype)

    # Step 3: reassign attention
    tmp_attn = fattn * attn
    attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
    re_attn = tmp_attn / attn_sum

    # Debug marker kept from the original code: signals NaNs in the output.
    if torch.isnan(re_attn).any():
        print("ddd")
    return re_attn
|
|
|
|
|
|
def EncoderImage(data_name, img_dim, embed_size, precomp_enc_type='basic',
                 no_imgnorm=False):
    """Factory for the image encoder working on precomputed features.

    Only one encoder exists at the moment, so `data_name` and
    `precomp_enc_type` are accepted but not used.

    Returns:
        An `EncoderImagePrecomp(img_dim, embed_size, no_imgnorm)` instance.
    """
    return EncoderImagePrecomp(img_dim, embed_size, no_imgnorm)
|
|
|
|
|
|
class EncoderImagePrecomp(nn.Module):
    """Linear projection of precomputed image features into the joint space."""

    def __init__(self, img_dim, embed_size, no_imgnorm=False):
        super().__init__()
        self.embed_size = embed_size
        self.no_imgnorm = no_imgnorm
        self.fc = nn.Linear(img_dim, embed_size)

        self.init_weights()

    def init_weights(self):
        """Xavier-style uniform init for the projection; the bias is zeroed."""
        fan_in = self.fc.in_features
        fan_out = self.fc.out_features
        bound = np.sqrt(6.) / np.sqrt(fan_in + fan_out)
        with torch.no_grad():
            self.fc.weight.uniform_(-bound, bound)
            self.fc.bias.zero_()

    def forward(self, images):
        """Project the image features and pool them over dim 1.

        Returns:
            (features, features_mean): the projected features (L2-normalised
            on the last dim unless `no_imgnorm` is set) and the max over
            dim 1 of the projected, *un-normalised* features. Despite its
            name, `features_mean` is a max-pool, not a mean.
        """
        # assuming that the precomputed features are already l2-normalized
        projected = self.fc(images)
        pooled = projected.max(dim=1)[0]

        # normalize in the joint embedding space
        if self.no_imgnorm:
            return projected, pooled
        return l2norm(projected, dim=-1), pooled
|
|
|
|
class EncoderTextBERT(nn.Module):
    """Text encoder built on BERT hidden states or pre-extracted features.

    When `opt.text_model_pre_extracted` is set, the input to `forward` is
    used directly as the token features and no BERT module is kept on the
    instance.
    """

    def __init__(self, opt, order_embeddings=False, mean=True, post_transformer_layers=0):
        super().__init__()
        self.preextracted = opt.text_model_pre_extracted
        bert_config = BertConfig.from_pretrained(opt.text_model_pretrain,
                                                 output_hidden_states=True,
                                                 num_hidden_layers=opt.text_model_extraction_hidden_layer)
        bert_model = BertModel.from_pretrained(opt.text_model_pretrain, config=bert_config)
        self.order_embeddings = order_embeddings
        self.vocab_size = bert_model.config.vocab_size
        self.hidden_layer = opt.text_model_extraction_hidden_layer
        if not self.preextracted:
            self.tokenizer = BertTokenizer.from_pretrained(opt.text_model_pretrain)
            self.bert_model = bert_model
            self.word_embeddings = self.bert_model.get_input_embeddings()
        if post_transformer_layers > 0:
            transformer_layer = nn.TransformerEncoderLayer(d_model=opt.text_model_word_dim, nhead=4,
                                                           dim_feedforward=2048,
                                                           dropout=opt.text_model_dropout, activation='relu')
            self.transformer_encoder = nn.TransformerEncoder(transformer_layer,
                                                             num_layers=post_transformer_layers)
        self.post_transformer_layers = post_transformer_layers
        self.map = nn.Linear(opt.text_model_word_dim, opt.embed_size)
        self.mean = mean

    def forward(self, x, lengths):
        '''
        x: tensor of indexes (LongTensor) obtained with tokenizer.encode() of size B x ?
           (or the pre-extracted token features when `preextracted` is set)
        lengths: tensor of lengths (LongTensor) of size B

        Returns (outputs, lengths, out): the per-token features mapped by
        `self.map`, the lengths unchanged, and the mapped pooled feature
        (mean over dim 1 if `self.mean`, otherwise the first token).
        '''
        # The mask is only needed by BERT and by the extra transformer layers.
        if not self.preextracted or self.post_transformer_layers > 0:
            max_len = max(lengths)
            attention_mask = torch.ones(x.shape[0], max_len)
            for e, l in zip(attention_mask, lengths):
                e[l:] = 0  # zero marks padding positions
            attention_mask = attention_mask.to(x.device)

        if self.preextracted:
            outputs = x
        else:
            outputs = self.bert_model(x, attention_mask=attention_mask)
            # index 2 holds the hidden states; take the last one
            outputs = outputs[2][-1]

        if self.post_transformer_layers > 0:
            # the transformer encoder is used sequence-first here
            outputs = outputs.permute(1, 0, 2)
            # (mask - 1) is non-zero exactly on padding -> True means "ignore"
            outputs = self.transformer_encoder(outputs, src_key_padding_mask=(attention_mask - 1).bool())
            outputs = outputs.permute(1, 0, 2)
        if self.mean:
            # NOTE(review): the mean runs over all positions, padding included.
            x = torch.mean(outputs, 1)
        else:
            x = outputs[:, 0, :]  # from the last layer take only the first word

        out = self.map(x)
        outputs = self.map(outputs)

        # normalization in the joint embedding space
        # out = l2norm(out)

        # take absolute value, used by order embeddings
        if self.order_embeddings:
            out = torch.abs(out)
        return outputs, lengths, out

    def get_finetuning_params(self):
        """Return the BERT parameters, or an empty list if no BERT is held.

        With pre-extracted features `self.bert_model` is never assigned, so
        there is nothing to fine-tune; previously this raised AttributeError.
        """
        bert_model = getattr(self, 'bert_model', None)
        if bert_model is None:
            return []
        return list(bert_model.parameters())
|
|
|
|
def encoder_text(word2idx, vocab_size, word_dim, embed_size, num_layers, use_bi_gru=False, no_txtnorm=False):
    """Factory for the GRU caption encoder; forwards every argument to `EncoderText`."""
    return EncoderText(word2idx, vocab_size, word_dim, embed_size, num_layers, use_bi_gru, no_txtnorm)
|
|
|
|
|
|
class EncoderText(nn.Module):
    """GRU caption encoder whose word embeddings are seeded from GloVe."""

    def __init__(self, word2idx, vocab_size, word_dim, embed_size, num_layers,
                 use_bi_gru=False, no_txtnorm=False):
        super().__init__()
        self.embed_size = embed_size
        self.no_txtnorm = no_txtnorm

        # word embedding
        self.embed = nn.Embedding(vocab_size, word_dim)

        # caption embedding
        self.use_bi_gru = use_bi_gru
        self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru)

        self.init_weights(word2idx)

    def init_weights(self, word2idx):
        """Copy GloVe vectors into the embedding rows of the known words.

        Words missing from GloVe are retried after stripping '-', '.' and
        "'" and keeping only the part before the first '/'. Words that are
        still unknown keep their default initialisation and are counted in
        the printed summary.
        """
        wemb = torchtext.vocab.GloVe(cache=".vector_cache")

        # quick-and-dirty trick to improve word-hit rate
        missing_words = []
        for word, idx in word2idx.items():
            token = word
            if token not in wemb.stoi:
                token = token.replace('-', '').replace('.', '').replace("'", '')
                if '/' in token:
                    token = token.split('/')[0]
            if token not in wemb.stoi:
                missing_words.append(token)
                continue
            self.embed.weight.data[idx] = wemb.vectors[wemb.stoi[token]]

        n_total = len(word2idx)
        n_missing = len(missing_words)
        print('Words: {}/{} found in vocabulary; {} words missing'.format(
            n_total - n_missing, n_total, n_missing))

    def forward(self, x, lengths):
        """Encode variable-length captions.

        Returns:
            (cap_emb, cap_len, cap_emb_mean): per-word features, the lengths
            reported by `pad_packed_sequence`, and the max over dim 1 of the
            per-word features (a max-pool despite the name). Both features
            are L2-normalised unless `no_txtnorm` is set.
        """
        # Embed word ids to vectors
        embedded = self.embed(x)
        packed = pack_padded_sequence(embedded, lengths, batch_first=True)
        if torch.cuda.device_count() > 1:
            self.rnn.flatten_parameters()

        # Forward propagate RNN
        rnn_out, _ = self.rnn(packed)

        # Back to a padded (batch, max_len, hidden) tensor
        cap_emb, cap_len = pad_packed_sequence(rnn_out, batch_first=True)

        if self.use_bi_gru:
            # average the two halves of the last dim (the two GRU directions)
            half = cap_emb.size(2) // 2
            cap_emb = (cap_emb[:, :, :half] + cap_emb[:, :, half:]) / 2

        cap_emb_mean = cap_emb.max(dim=1)[0]

        # normalization in the joint embedding space
        if not self.no_txtnorm:
            cap_emb = l2norm(cap_emb, dim=-1)
            cap_emb_mean = l2norm(cap_emb_mean, dim=1)
        return cap_emb, cap_len, cap_emb_mean
|
|
|
|
|
|
''' Visual self-attention module '''
|
|
|
|
|
|
class V_single_modal_atten(nn.Module):
    """
    Single Visual Modal Attention Network.
    """

    def __init__(self, image_dim, embed_dim, dropout_rate=0.4, img_region_num=36):
        """
        param image_dim: dim of visual feature
        param embed_dim: dim of embedding space
        param dropout_rate: dropout probability used in the embedding stacks
        param img_region_num: channel count of the batch norm applied to the
            region embedding (its dim 1)
        """
        super().__init__()

        # Linear layers, created in a fixed order.
        self.fc1 = nn.Linear(image_dim, embed_dim)  # regions -> common space
        self.fc2 = nn.Linear(image_dim, embed_dim)  # memory (image_dim wide) -> common space
        self.fc2_2 = nn.Linear(embed_dim, embed_dim)  # memory (any other width) -> common space
        self.fc3 = nn.Linear(embed_dim, 1)  # fused info -> attention logit
        self.fc4 = nn.Linear(image_dim, embed_dim)  # not used by forward()

        def _stack(linear, bn_channels):
            return nn.Sequential(linear, nn.BatchNorm1d(bn_channels), nn.Tanh(), nn.Dropout(dropout_rate))

        self.embedding_1 = _stack(self.fc1, img_region_num)
        self.embedding_2 = _stack(self.fc2, embed_dim)
        self.embedding_2_2 = _stack(self.fc2_2, embed_dim)
        self.embedding_3 = nn.Sequential(self.fc3)

        self.softmax = nn.Softmax(dim=1)  # normalises the weights over regions

    def forward(self, v_t, m_v):
        """
        Forward propagation.
        :param v_t: encoded images, shape: (batch_size, num_regions, image_dim)
        :param m_v: previous visual memory, shape: (batch_size, image_dim)
        :return: attention weighted encoding (L2-normalised), weights
        """
        region_emb = self.embedding_1(v_t)

        # Choose the memory branch by comparing the last-dim widths.
        widths_match = m_v.size()[-1] == v_t.size()[-1]
        memory_branch = self.embedding_2 if widths_match else self.embedding_2_2
        memory_emb = memory_branch(m_v)
        memory_emb = memory_emb.unsqueeze(1).repeat(1, region_emb.size()[1], 1)

        fused = region_emb.mul(memory_emb)

        logits = self.embedding_3(fused).squeeze(2)
        weights = self.softmax(logits)

        attended = (weights.unsqueeze(2) * v_t).sum(dim=1)

        # l2 norm
        return l2norm(attended, -1), weights
|
|
|
|
|
|
class T_single_modal_atten(nn.Module):
    """
    Single Textual Modal Attention Network.
    """

    def __init__(self, embed_dim, dropout_rate=0.4):
        """
        param embed_dim: dim of embedding space
        param dropout_rate: dropout probability used in the embedding stacks
        """
        super().__init__()

        self.fc1 = nn.Linear(embed_dim, embed_dim)  # word features -> common space
        self.fc2 = nn.Linear(embed_dim, embed_dim)  # memory -> common space
        self.fc3 = nn.Linear(embed_dim, 1)  # fused info -> attention logit

        def _stack(linear):
            return nn.Sequential(linear, nn.Tanh(), nn.Dropout(dropout_rate))

        self.embedding_1 = _stack(self.fc1)
        self.embedding_2 = _stack(self.fc2)
        self.embedding_3 = nn.Sequential(self.fc3)

        self.softmax = nn.Softmax(dim=1)  # normalises the weights over dim 1

    def forward(self, u_t, m_u):
        """
        Forward propagation.
        :param u_t: per-fragment text features; dim 1 is attended over
        :param m_u: memory vector, one per batch item
        :return: attention weighted encoding (L2-normalised), weights
        """
        fragment_emb = self.embedding_1(u_t)

        memory_emb = self.embedding_2(m_u)
        memory_emb = memory_emb.unsqueeze(1).repeat(1, fragment_emb.size()[1], 1)

        fused = fragment_emb.mul(memory_emb)

        logits = self.embedding_3(fused).squeeze(2)
        weights = self.softmax(logits)

        attended = (weights.unsqueeze(2) * u_t).sum(dim=1)

        # l2 norm
        return l2norm(attended, -1), weights
|
|
|
|
|
|
class ContrastiveLoss(nn.Module):
    """
    Compute contrastive loss

    Hinge loss over a square score matrix whose diagonal holds the matching
    pairs; only the hardest (maximum-violation) negative is kept along each
    row and along each column.
    """

    def __init__(self, margin=0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, scores):
        """Return the summed hardest-negative hinge loss for `scores`.

        Args:
            scores: square (n, n) similarity matrix; entry (i, i) is the
                score of the i-th matching pair.

        Returns:
            A scalar tensor on the same device as `scores`.
        """
        diagonal = scores.diag().view(-1, 1)
        d1 = diagonal.expand_as(scores)      # d1[i, j] = scores[i, i]
        d2 = diagonal.t().expand_as(scores)  # d2[i, j] = scores[j, j]

        # violation of every entry against the positive of its row
        cost_s = (self.margin + scores - d1).clamp(min=0)
        # violation of every entry against the positive of its column
        cost_im = (self.margin + scores - d2).clamp(min=0)

        # clear diagonals; build the mask on the device of `scores` so CPU
        # scores keep working on a machine where CUDA is available
        mask = torch.eye(scores.size(0), device=scores.device) > .5
        cost_s = cost_s.masked_fill(mask, 0)
        cost_im = cost_im.masked_fill(mask, 0)

        # keep the maximum violating negative for each query
        cost_s = cost_s.max(1)[0]
        cost_im = cost_im.max(0)[0]

        return cost_s.sum() + cost_im.sum()
|
|
|
|
|
|
class SCAN3(nn.Module):
    """
    Cross-attention image-text matching model.

    Combines a precomputed-feature image encoder, a bi-GRU caption encoder
    and one self-attention module per modality; `xattn_score` scores every
    image against every caption with `func_attention` in both directions.
    (The class name and the original docstring refer to the Stacked Cross
    Attention Network, SCAN.)
    """

    def __init__(self, word2idx, opt):
        super(SCAN3, self).__init__()
        # Build Models
        self.grad_clip = opt.grad_clip
        self.img_enc = EncoderImage(opt.data_name, opt.img_dim, opt.embed_size,
                                    precomp_enc_type=opt.precomp_enc_type,
                                    no_imgnorm=opt.no_imgnorm)
        self.txt_enc = encoder_text(word2idx, opt.vocab_size, opt.word_dim,
                                    opt.embed_size, opt.num_layers,
                                    use_bi_gru=True,
                                    no_txtnorm=opt.no_txtnorm)
        # Alternative text encoder left by the original authors:
        # self.txt_enc = EncoderTextBERT(opt, post_transformer_layers=opt.text_model_layers)

        self.V_self_atten_enhance = V_single_modal_atten(opt.embed_size, opt.embed_size)
        self.T_self_atten_enhance = T_single_modal_atten(opt.embed_size)

        self.opt = opt
        self.Eiters = 0

    def forward_emb(self, images, captions, lengths):
        """Compute the image and caption embeddings

        Returns (img_emb, img_mean, cap_emb, cap_lens, cap_mean); the two
        `*_mean` vectors are the pooled encoder outputs after refinement by
        the modality's self-attention module.
        """
        # Forward
        img_emb, img_mean = self.img_enc(images)
        cap_emb, cap_lens, cap_mean = self.txt_enc(captions, lengths)

        img_mean, _ = self.V_self_atten_enhance(img_emb, img_mean)
        cap_mean, _ = self.T_self_atten_enhance(cap_emb, cap_mean)
        return img_emb, img_mean, cap_emb, cap_lens, cap_mean

    def txt_emb(self, captions, lengths):
        """Compute the caption embeddings
        """
        # Forward
        cap_emb, cap_lens, cap_mean = self.txt_enc(captions, lengths)

        cap_mean, _ = self.T_self_atten_enhance(cap_emb, cap_mean)
        return cap_emb, cap_lens, cap_mean

    def image_emb(self, images):
        """Compute the image embeddings
        """
        # Forward
        img_emb, img_mean = self.img_enc(images)
        img_mean, _ = self.V_self_atten_enhance(img_emb, img_mean)
        return img_emb, img_mean

    def forward_sim(self, img_emb, img_mean, cap_emb, cap_len, cap_mean, **kwargs):
        """Compute the similarity scores for the given image and caption
        embeddings (extra keyword arguments are ignored).
        """
        scores = self.xattn_score(img_emb, img_mean, cap_emb, cap_len, cap_mean)

        return scores

    def forward(self, images, captions, lengths, ids=None, *args):
        """Encode both modalities and return their similarity scores.

        `lengths` may be a tensor (converted to a Python list here) or
        already a plain sequence of ints. `ids` and `*args` are unused.
        """
        # compute the embeddings
        if torch.is_tensor(lengths):
            lengths = lengths.cpu().numpy().tolist()
        else:
            lengths = list(lengths)
        img_emb, img_mean, cap_emb, cap_lens, cap_mean = self.forward_emb(images, captions, lengths)
        scores = self.forward_sim(img_emb, img_mean, cap_emb, cap_lens, cap_mean)
        return scores

    def xattn_score(self, images, img_mean, captions, cap_lens, cap_mean):
        """Score every image against every caption.

        Returns a (n_image, n_caption) matrix in eval mode; in training
        mode the matrix is transposed to (n_caption, n_image).
        """
        similarities = []
        n_image = images.size(0)
        n_caption = captions.size(0)
        # (n_caption, n_image) global similarities of the pooled vectors
        g_sims = cap_mean.mm(img_mean.t())
        for i in range(n_caption):
            # Get the i-th text description
            n_word = cap_lens[i]
            g_sim = g_sims[i]
            cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
            # --> (n_image, n_word, d)
            cap_i_expand = cap_i.repeat(n_image, 1, 1)

            # t2i process: words attend over image regions
            # weiContext: (n_image, n_word, d)
            weiContext, _ = func_attention(cap_i_expand, images, g_sim, self.opt)
            t2i_sim = cosine_similarity(cap_i_expand, weiContext, dim=2)
            t2i_sim = t2i_sim.mean(dim=1, keepdim=True)

            # i2t process: image regions attend over words
            # weiContext: (n_image, n_region, d)
            weiContext, _ = func_attention(images, cap_i_expand, g_sim, self.opt)
            i2t_sim = cosine_similarity(images, weiContext, dim=2)
            i2t_sim = i2t_sim.mean(dim=1, keepdim=True)

            # Overall similarity for image and text
            sim = t2i_sim + i2t_sim

            similarities.append(sim)

        # (n_image, n_caption)
        similarities = torch.cat(similarities, 1)

        if self.training:
            similarities = similarities.transpose(0, 1)
        return similarities
|