import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from vocab import deserialize_vocab
from model.utils import *
from .pyramid_vig import pvig_ti_224_gelu
from .GAT import GAT, GATopt, GAT_T

# Default locations of the pre-extracted GloVe embedding matrix and of the
# vocabulary that maps words to its rows (used by EncoderText.init_weights).
_GLOVE_EMB_PATH = '/home/wzm/embs_npa.npy'
_GLOVE_VOCAB_PATH = '/home/wzm/GAC/vocab/glove_precomp_vocab.json'


def l2norm(X, dim=-1, eps=1e-8):
    """L2-normalize X along ``dim``.

    ``eps`` is added to the norm (not clamped), so an all-zero slice maps to
    zeros instead of NaN.
    """
    norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
    return torch.div(X, norm)


def cosine_similarity(x1, x2, dim=1, eps=1e-8, keep_dim=False):
    """Return the cosine similarity between x1 and x2, computed along ``dim``.

    The product of the norms is clamped to ``eps`` to avoid division by zero.
    When ``keep_dim`` is False the reduced axis is dropped and the result is
    additionally ``squeeze(-1)``-ed, so a trailing axis of size 1 disappears
    as well; pass ``keep_dim=True`` when the caller needs a stable rank.
    """
    w12 = torch.sum(x1 * x2, dim, keepdim=keep_dim)
    w1 = torch.norm(x1, 2, dim, keepdim=keep_dim)
    w2 = torch.norm(x2, 2, dim, keepdim=keep_dim)
    sim = w12 / (w1 * w2).clamp(min=eps)
    if keep_dim:
        return sim
    return sim.squeeze(-1)


def func_attention(query, context, g_sim, opt, eps=1e-8, lambda_softmax=20):
    """Attend from every query fragment over the context fragments.

    Args:
        query: (batch, queryL, d) tensor.
        context: (batch, sourceL, d) tensor.
        g_sim: global similarity with one value per batch row (broadcast as
            (batch, 1, 1) inside ``correct_equal``).
        opt: config mapping; ``opt['cross_attention']['threshold_p']`` and
            ``opt['cross_attention']['threshold_q']`` are read.
        eps: unused; kept for interface compatibility.
        lambda_softmax: inverse temperature applied before the softmax over
            the context axis.

    Returns:
        (weightedContext, re_attn): the attended context, shape
        (batch, queryL, d), and the final attention weights, shape
        (batch, queryL, sourceL).
    """
    batch_size, queryL, sourceL = context.size(0), query.size(1), context.size(1)

    # Step 1: pre-assign attention.
    # (batch, sourceL, d) x (batch, d, queryL) -> (batch, sourceL, queryL)
    attn = torch.bmm(context, query.transpose(1, 2))
    attn = F.leaky_relu(attn, negative_slope=0.1)
    # Normalize across the query axis before the softmax over the context axis.
    attn = l2norm(attn, 2)
    # -> (batch, queryL, sourceL), flattened so the softmax runs per query row.
    attn = attn.transpose(1, 2).contiguous().view(batch_size * queryL, sourceL)
    attn = torch.softmax(attn * lambda_softmax, dim=1)
    attn = attn.view(batch_size, queryL, sourceL)

    cross_opt = opt['cross_attention']
    # Step 2: use g_sim to correct the local attention.
    re_attn = correct_equal(attn, query, context, sourceL, g_sim, cross_opt['threshold_p'])
    # Step 3: drop fragments that are far from the best-matching one.
    re_attn = focal_equal(re_attn, query, context, sourceL, cross_opt['threshold_q'])

    # Step 4: weighted sum of the context fragments.
    # (batch, d, sourceL) x (batch, sourceL, queryL) -> (batch, d, queryL)
    weightedContext = torch.bmm(context.transpose(1, 2), re_attn.transpose(1, 2).contiguous())
    # -> (batch, queryL, d)
    weightedContext = weightedContext.transpose(1, 2)
    return weightedContext, re_attn


def correct_equal(attn, query, context, sourceL, g_sim, threshold_p):
    """Rescale attention by the global-similarity confidence and renormalize.

    Treats the confidence g(x) of every fragment as equal:
    sigma_j (x_i - x_j) = sigma_j x_i - sigma_j x_j.

    Args:
        attn: (batch, queryL, sourceL) attention weights.
        g_sim: 1-D per-row global similarity, broadcast as (batch, 1, 1).
        threshold_p: value subtracted from ``g_sim``.
        query, context, sourceL: unused; kept for interface compatibility.

    Note:
        The factor ``d`` is constant along the last axis, so it cancels in the
        renormalization; the result equals ``attn / attn.sum(-1)`` up to
        floating-point error. Exact zeros in ``d`` are replaced by 1e-8 so the
        division never becomes 0/0.
    """
    d = g_sim - threshold_p
    d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d)
    re_attn = d.unsqueeze(1).unsqueeze(2) * attn
    attn_sum = torch.sum(re_attn, dim=-1, keepdim=True)
    return re_attn / attn_sum


def focal_equal(attn, query, context, sourceL, threshold_q):
    """Keep only weights within ``threshold_q`` of the per-row maximum.

    Weights with ``max(attn) - attn >= threshold_q`` are zeroed, and the
    remaining weights are renormalized to sum to 1 along the last axis.

    Args:
        attn: (batch, queryL, sourceL) attention weights.
        threshold_q: largest gap below the row maximum that is still kept.
        query, context, sourceL: unused; kept for interface compatibility.

    Note:
        With ``threshold_q > 0`` the row maximum is always kept, so for positive
        attention the renormalizing sum is positive. With ``threshold_q <= 0``
        every weight is masked and the result is NaN.
    """
    # TODO: the max-based criterion is a hard-coded choice; an alternative is
    # funcF = attn * sourceL - torch.sum(attn, dim=-1, keepdim=True).
    max_attn = torch.max(attn, dim=-1, keepdim=True)[0]
    funcF = max_attn - attn
    keep = torch.where(funcF < threshold_q, torch.ones_like(attn), torch.zeros_like(attn))
    tmp_attn = keep * attn
    attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
    return tmp_attn / attn_sum


class MGAM(nn.Module):
    """Fuse a lower- and a higher-level feature map with graph attention and
    use the pooled result to gate a global ("solo") feature."""

    def __init__(self, opt=None):
        # ``opt`` is accepted for interface compatibility and is not used.
        super(MGAM, self).__init__()
        channel_size = 256
        # Kernel 2 / stride 2 halves the lower map's H and W; both maps are
        # projected to ``channel_size`` channels.
        self.LF_conv = nn.Conv2d(in_channels=240, out_channels=channel_size, kernel_size=2, stride=2)
        self.HF_conv = nn.Conv2d(in_channels=384, out_channels=channel_size, kernel_size=1, stride=1)
        # Graph attention over the concatenated 2 * channel_size = 512 channels.
        self.concat_attention = GAT(GATopt(512, 1))

    def forward(self, lower_feature, higher_feature, solo_feature):
        """Fuse the feature maps and gate ``solo_feature``.

        Args:
            lower_feature: (b, 240, H, W) map. After LF_conv its spatial size
                must match the projected higher map for the concatenation.
            higher_feature: (b, 384, h, w) map.
            solo_feature: global feature multiplied elementwise by a (b, 512)
                sigmoid gate; assumed (b, 512) — TODO confirm against backbone.

        Returns:
            visual_embs: (b, h*w, 512) region embeddings.
            The l2-normalized gated solo feature.
        """
        lower_feature = self.LF_conv(lower_feature)
        higher_feature = self.HF_conv(higher_feature)
        # -> (b, 512, h, w)
        concat_feature = torch.cat([lower_feature, higher_feature], dim=1)
        # Residual: add the channel-mean of the projected higher map to every channel.
        concat_feature = higher_feature.mean(dim=1, keepdim=True).expand_as(concat_feature) + concat_feature
        # Assumed (b, 512, h, w), as implied by the view / pooling below.
        attent_feature = self.concat_attention(concat_feature)
        # (b, 512, h, w) -> (b, h*w, 512)
        visual_embs = attent_feature.view(attent_feature.size(0), 512, -1).transpose(1, 2)
        # Global average pool -> (b, 512)
        attent_feature = F.adaptive_avg_pool2d(attent_feature, 1).squeeze(-1).squeeze(-1)
        solo_att = torch.sigmoid(attent_feature)
        final_feature = solo_feature * solo_att
        return visual_embs, l2norm(final_feature, -1)


class EncoderText(nn.Module):
    """Word-embedding + (bi-)GRU caption encoder."""

    def __init__(self, word2idx, vocab_size, word_dim=300, embed_size=512, num_layers=1,
                 use_bi_gru=True, no_txtnorm=False,
                 glove_emb_path=_GLOVE_EMB_PATH, glove_vocab_path=_GLOVE_VOCAB_PATH):
        super(EncoderText, self).__init__()
        self.embed_size = embed_size
        self.no_txtnorm = no_txtnorm
        # word embedding
        self.embed = nn.Embedding(vocab_size, word_dim)
        # caption embedding
        self.use_bi_gru = use_bi_gru
        self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru)
        self.init_weights(word2idx, glove_emb_path, glove_vocab_path)

    def init_weights(self, word2idx=None, emb_path=_GLOVE_EMB_PATH, vocab_path=_GLOVE_VOCAB_PATH):
        """Initialize the word-embedding table.

        Args:
            word2idx: mapping word -> row of ``self.embed``. If None, the table
                is filled uniformly in [-0.1, 0.1] and left trainable.
            emb_path: ``.npy`` file holding the GloVe embedding matrix.
            vocab_path: serialized vocabulary whose ``word2idx`` indexes the
                rows of that matrix.

        When ``word2idx`` is given, each word found in the GloVe vocabulary is
        assigned its GloVe vector. Words not found keep their current values.
        The embedding table is then frozen.
        """
        if word2idx is None:
            self.embed.weight.data.uniform_(-0.1, 0.1)
            return

        embeddings_glove = np.load(emb_path)
        glove_w2i = deserialize_vocab(vocab_path).word2idx
        weight = self.embed.weight.data
        missing_words = []
        for word, idx in word2idx.items():
            if word in glove_w2i:
                weight[idx] = torch.as_tensor(embeddings_glove[glove_w2i[word]], dtype=torch.float32)
            else:
                missing_words.append(word)
        print('Words: {}/{} found in vocabulary; {} words missing'.format(
            len(word2idx) - len(missing_words), len(word2idx), len(missing_words)))
        for param in self.embed.parameters():
            param.requires_grad = False

    def forward(self, x, lengths):
        """Encode variable-length captions.

        Args:
            x: (batch, max_len) word ids.
            lengths: true caption lengths (list or 1-D tensor). Tensors are
                moved to CPU because ``pack_padded_sequence`` requires CPU
                lengths.

        Returns:
            cap_emb: (batch, T, embed_size) per-word features, T = max(lengths);
                l2-normalized unless ``no_txtnorm``.
            cap_len: lengths returned by ``pad_packed_sequence``.
        """
        if torch.is_tensor(lengths):
            lengths = lengths.cpu()
        x = self.embed(x)
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        out, _ = self.rnn(packed)
        cap_emb, cap_len = pad_packed_sequence(out, batch_first=True)
        if self.use_bi_gru:
            # Average the forward and backward directions -> embed_size features.
            half = cap_emb.size(2) // 2
            cap_emb = (cap_emb[:, :, :half] + cap_emb[:, :, half:]) / 2
        # normalization in the joint embedding space
        if not self.no_txtnorm:
            cap_emb = l2norm(cap_emb, dim=-1)
        return cap_emb, cap_len


class BaseModel(nn.Module):
    """Image-text matching model.

    Image branch: pyramid ViG backbone + MGAM. Text branch: GRU + GAT_T. Pairs
    are scored with bidirectional cross attention.
    """

    def __init__(self, vocab, opt=None):
        super(BaseModel, self).__init__()
        self.opt = {} if opt is None else opt
        # --------------- image features ---------------
        self.extract_feature = pvig_ti_224_gelu()
        # NOTE(review): not used by this class (MGAM has its own HF_conv);
        # presumably kept so existing checkpoints still load — confirm before removing.
        self.HF_conv = nn.Conv2d(in_channels=384, out_channels=512, kernel_size=1, stride=1)
        self.mgam = MGAM(opt=self.opt)
        # --------------- text features ----------------
        self.text_feature = EncoderText(word2idx=vocab.word2idx, vocab_size=len(vocab))
        self.gat_cap = GAT_T(GATopt(512, 1))
        self.Eiters = 0

    def forward_emb(self, images, captions, lengths=None):
        """Encode images and captions.

        Returns:
            img_embs: region embeddings from MGAM.
            img_solo: l2-normalized global image embedding.
            cap_embs: per-word embeddings after the text GAT.
            cap_lens: caption lengths from the text encoder.
            cap_solo: l2-normalized mean of ``cap_embs`` over the word axis.
        """
        lower_feature, higher_feature, solo_feature = self.extract_feature(images)
        img_embs, img_solo = self.mgam(lower_feature, higher_feature, solo_feature)
        cap_embs, cap_lens = self.text_feature(captions, lengths)
        cap_embs = self.gat_cap(cap_embs)
        # NOTE(review): this mean runs over every padded position, not only the
        # first cap_lens[i] words — confirm this is intended.
        cap_solo = l2norm(torch.mean(cap_embs, dim=1))
        return img_embs, img_solo, cap_embs, cap_lens, cap_solo

    def forward_sim(self, img_emb, img_mean, cap_emb, cap_len, cap_mean, **kwargs):
        """Compute similarity scores for all image/caption pairs (see xattn_score)."""
        return self.xattn_score(img_emb, img_mean, cap_emb, cap_len, cap_mean)

    def forward(self, images, captions, lengths, ids=None, *args):
        """Encode both modalities and return their pairwise similarity scores."""
        img_emb, img_mean, cap_emb, cap_lens, cap_mean = self.forward_emb(images, captions, lengths)
        return self.forward_sim(img_emb, img_mean, cap_emb, cap_lens, cap_mean)

    def xattn_score(self, images, img_mean, captions, cap_lens, cap_mean):
        """Compute bidirectional cross-attention similarity for every image/caption pair.

        Args:
            images: (n_image, n_region, d) region embeddings.
            img_mean: (n_image, d) global image embeddings.
            captions: (n_caption, max_words, d) word embeddings.
            cap_lens: number of valid words per caption.
            cap_mean: (n_caption, d) global caption embeddings.

        Returns:
            Scores of shape (n_caption, n_image) in training mode and
            (n_image, n_caption) otherwise.
        """
        n_image = images.size(0)
        n_caption = captions.size(0)
        # Global similarities, (n_caption, n_image); row i guides caption i.
        g_sims = cap_mean.mm(img_mean.t())
        similarities = []
        for i in range(n_caption):
            n_word = int(cap_lens[i])
            g_sim = g_sims[i]
            # (1, n_word, d) -> (n_image, n_word, d)
            cap_i_expand = captions[i, :n_word, :].unsqueeze(0).contiguous().repeat(n_image, 1, 1)

            # t2i: words attend over regions; weiContext is (n_image, n_word, d).
            weiContext, _ = func_attention(cap_i_expand, images, g_sim, self.opt)
            # keep_dim=True keeps the word axis even when n_word == 1; the
            # default squeeze(-1) in cosine_similarity would otherwise drop it.
            t2i_sim = cosine_similarity(cap_i_expand, weiContext, dim=2, keep_dim=True).mean(dim=1)

            # i2t: regions attend over words; weiContext is (n_image, n_region, d).
            weiContext, _ = func_attention(images, cap_i_expand, g_sim, self.opt)
            i2t_sim = cosine_similarity(images, weiContext, dim=2, keep_dim=True).mean(dim=1)

            # Overall similarity for image and text: (n_image, 1)
            similarities.append(t2i_sim + i2t_sim)

        # (n_image, n_caption)
        similarities = torch.cat(similarities, 1)
        if self.training:
            similarities = similarities.transpose(0, 1)
        return similarities


def factory(opt, vocab, cuda=True, data_parallel=True):
    """Build a BaseModel from a shallow copy of ``opt``.

    Args:
        opt: model configuration (shallow-copied; nested values are shared).
        vocab: vocabulary exposing ``word2idx`` and ``len()``.
        cuda: move the model to the GPU.
        data_parallel: wrap the model in ``nn.DataParallel`` (requires ``cuda``).

    Raises:
        ValueError: if ``data_parallel`` is requested without ``cuda``. This is
            checked before the model is built or moved to a GPU.
    """
    if data_parallel and not cuda:
        raise ValueError('data_parallel=True requires cuda=True')
    opt = copy.copy(opt)
    model = BaseModel(vocab, opt)
    if data_parallel:
        model = nn.DataParallel(model).cuda()
    elif cuda:
        model.cuda()
    return model