import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.nn.init
import torchvision.models as models
from torch.autograd import Variable
from torch.nn.utils.clip_grad import clip_grad_norm
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
from collections import OrderedDict
from model.utils import *
import copy
import ast
# from .mca import SA, SGA
from .pyramid_vig import DeepGCN, pvig_ti_224_gelu
from .GAT import GAT, GATopt, GAT_T
from vocab import deserialize_vocab


def l2norm(X, dim=-1, eps=1e-8):
    """L2-normalize columns of X."""
    norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
    X = torch.div(X, norm)
    return X


def cosine_similarity(x1, x2, dim=1, eps=1e-8, keep_dim=False):
    """Returns cosine similarity between x1 and x2, computed along dim."""
    w12 = torch.sum(x1 * x2, dim, keepdim=keep_dim)
    w1 = torch.norm(x1, 2, dim, keepdim=keep_dim)
    w2 = torch.norm(x2, 2, dim, keepdim=keep_dim)
    if keep_dim:
        return w12 / (w1 * w2).clamp(min=eps)
    else:
        return (w12 / (w1 * w2).clamp(min=eps)).squeeze(-1)


def func_attention(query, context, g_sim, opt, eps=1e-8):
    """
    query: (batch, queryL, d)
    context: (batch, sourceL, d)
    opt: parameters
    """
    batch_size, queryL, sourceL = context.size(0), query.size(1), context.size(1)

    # Step 1: preassign attention
    # --> (batch, d, queryL)
    queryT = torch.transpose(query, 1, 2)

    # (batch, sourceL, d)(batch, d, queryL)
    attn = torch.bmm(context, queryT)
    attn = nn.LeakyReLU(0.1)(attn)
    attn = l2norm(attn, 2)

    # --> (batch, queryL, sourceL)
    attn = torch.transpose(attn, 1, 2).contiguous()
    # --> (batch*queryL, sourceL)
    attn = attn.view(batch_size * queryL, sourceL)
    attn = nn.Softmax(dim=1)(attn * opt.lambda_softmax)
    # --> (batch, queryL, sourceL)
    attn = attn.view(batch_size, queryL, sourceL)

    # Step 2: identify irrelevant fragments
    # Learning an indicator function H, one for relevant, zero for irrelevant
    if opt.correct_type == 'equal':
        re_attn = correct_equal(attn, query, context, sourceL, g_sim)
    elif opt.correct_type == 'prob':
        re_attn = correct_prob(attn, query, context, sourceL, g_sim)

    # --> (batch, d, sourceL)
    contextT = torch.transpose(context, 1, 2)
    # --> (batch, sourceL, queryL)
    re_attnT = torch.transpose(re_attn, 1, 2).contiguous()
    # (batch x d x sourceL)(batch x sourceL x queryL)
    # --> (batch, d, queryL)
    weightedContext = torch.bmm(contextT, re_attnT)
    # --> (batch, queryL, d)
    weightedContext = torch.transpose(weightedContext, 1, 2)
    if torch.isnan(weightedContext).any():
        print('Warning: NaN detected in weightedContext')

    return weightedContext, re_attn


def correct_equal(attn, query, context, sourceL, g_sim):
    """
    consider the confidence g(x) for each fragment as equal
    sigma_{j} (xi - xj) = sigma_{j} xi - sigma_{j} xj
    attn: (batch, queryL, sourceL)
    """
    # GCU process
    d = g_sim - 0.3
    d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d)
    re_attn = d.unsqueeze(1).unsqueeze(2) * attn
    attn_sum = torch.sum(re_attn, dim=-1, keepdim=True)
    re_attn = re_attn / attn_sum
    cos1 = cosine_similarity(torch.bmm(re_attn, context), query, dim=-1, keep_dim=True)
    cos1 = torch.where(cos1 == 0, cos1.new_full(cos1.shape, 1e-8), cos1)
    re_attn1 = focal_equal(re_attn, query, context, sourceL)

    # LCU process
    cos = cosine_similarity(torch.bmm(re_attn1, context), query, dim=-1, keep_dim=True)
    cos = torch.where(cos == 0, cos.new_full(cos.shape, 1e-8), cos)
    delta = cos - cos1
    delta = torch.where(delta == 0, delta.new_full(delta.shape, 1e-8), delta)
    re_attn2 = delta * re_attn1
    attn_sum = torch.sum(re_attn2, dim=-1, keepdim=True)
    re_attn2 = re_attn2 / attn_sum
    re_attn2 = focal_equal(re_attn2, query, context, sourceL)

    return re_attn2
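
# Illustrative usage sketch (added for clarity, not part of the original pipeline):
# a hypothetical call to func_attention, assuming an `opt` namespace that carries
# lambda_softmax and correct_type; shapes follow the docstring above.
#
#   from types import SimpleNamespace
#   opt = SimpleNamespace(lambda_softmax=9.0, correct_type='equal')   # assumed values
#   query = torch.randn(8, 20, 512)     # (batch, queryL, d), e.g. word features
#   context = torch.randn(8, 36, 512)   # (batch, sourceL, d), e.g. region features
#   g_sim = torch.rand(8)               # global image-text similarity per pair
#   weightedContext, re_attn = func_attention(query, context, g_sim, opt)
#   # weightedContext: (8, 20, 512), re_attn: (8, 20, 36)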

def focal_equal(attn, query, context, sourceL):
    funcF = attn * sourceL - torch.sum(attn, dim=-1, keepdim=True)
    fattn = torch.where(funcF > 0, torch.ones_like(attn), torch.zeros_like(attn))

    # Step 3: reassign attention
    tmp_attn = fattn * attn
    attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
    re_attn = tmp_attn / attn_sum

    return re_attn


def correct_prob(attn, query, context, sourceL, g_sim):
    """
    consider the confidence g(x) for each fragment as the sqrt
    of their similarity probability to the query fragment
    sigma_{j} (xi - xj)gj = sigma_{j} xi*gj - sigma_{j} xj*gj
    attn: (batch, queryL, sourceL)
    """
    # GCU process
    d = g_sim - 0.3
    d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d)
    re_attn = d.unsqueeze(1).unsqueeze(2) * attn
    attn_sum = torch.sum(re_attn, dim=-1, keepdim=True)
    re_attn = re_attn / attn_sum
    cos1 = cosine_similarity(torch.bmm(re_attn, context), query, dim=-1, keep_dim=True)
    cos1 = torch.where(cos1 == 0, cos1.new_full(cos1.shape, 1e-8), cos1)
    re_attn1 = focal_prob(re_attn, query, context, sourceL)

    # LCU process
    cos = cosine_similarity(torch.bmm(re_attn1, context), query, dim=-1, keep_dim=True)
    cos = torch.where(cos == 0, cos.new_full(cos.shape, 1e-8), cos)
    delta = cos - cos1
    delta = torch.where(delta == 0, delta.new_full(delta.shape, 1e-8), delta)
    re_attn2 = delta * re_attn1
    attn_sum = torch.sum(re_attn2, dim=-1, keepdim=True)
    re_attn2 = re_attn2 / attn_sum
    re_attn2 = focal_prob(re_attn2, query, context, sourceL)

    return re_attn2


def focal_prob(attn, query, context, sourceL):
    batch_size, queryL, sourceL = context.size(0), query.size(1), context.size(1)

    # -> (batch, queryL, sourceL, 1)
    xi = attn.unsqueeze(-1).contiguous()
    # -> (batch, queryL, 1, sourceL)
    xj = attn.unsqueeze(2).contiguous()
    # -> (batch, queryL, 1, sourceL)
    xj_confi = torch.sqrt(xj)

    xi = xi.view(batch_size * queryL, sourceL, 1)
    xj = xj.view(batch_size * queryL, 1, sourceL)
    xj_confi = xj_confi.view(batch_size * queryL, 1, sourceL)

    # -> (batch*queryL, sourceL, sourceL)
    term1 = torch.bmm(xi, xj_confi).clamp(min=1e-8)
    term2 = xj * xj_confi
    funcF = torch.sum(term1 - term2, dim=-1)  # -> (batch*queryL, sourceL)
    funcF = funcF.view(batch_size, queryL, sourceL)

    fattn = torch.where(funcF > 0, torch.ones_like(attn), torch.zeros_like(attn))

    # Step 3: reassign attention
    tmp_attn = fattn * attn
    attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
    re_attn = tmp_attn / attn_sum
    if torch.isnan(re_attn).any():
        print("Warning: NaN detected in re_attn")

    return re_attn
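
# Illustrative sketch (added for clarity): the focal re-weighting keeps only the
# attention weights whose evidence exceeds that of the other fragments and then
# renormalizes.  A toy, hypothetical example for focal_equal with one query fragment:
#
#   attn = torch.tensor([[[0.1, 0.5, 0.4]]])   # (batch=1, queryL=1, sourceL=3)
#   # funcF = attn * 3 - 1.0 = [-0.7, 0.5, 0.2]  -> the 0.1 weight is zeroed out
#   re_attn = focal_equal(attn, query=None, context=None, sourceL=3)
#   # re_attn ≈ [[[0.0, 0.5556, 0.4444]]]
#
# focal_prob follows the same pattern but weights each pairwise comparison by sqrt(xj).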

# cross attention for MGAN
class CrossAttention(nn.Module):

    def __init__(self, opt={}):
        super(CrossAttention, self).__init__()

        self.att_type = opt['cross_attention']['att_type']
        dim = opt['embed']['embed_dim']
        channel_size = 512

        self.visual_conv = nn.Conv2d(in_channels=384, out_channels=channel_size, kernel_size=1, stride=1)
        # visual attention
        self.visual_attention = GAT(GATopt(512, 1))
        # self.bn_out = nn.BatchNorm1d(512)
        # self.dropout_out = nn.Dropout(0.5)

        if self.att_type == "soft_att":
            self.cross_attention = nn.Sequential(
                nn.Linear(dim, dim),
                nn.Sigmoid()
            )
        elif self.att_type == "fusion_att":
            self.cross_attention_fc1 = nn.Sequential(
                nn.Linear(2 * dim, dim),
                nn.Sigmoid()
            )
            self.cross_attention_fc2 = nn.Sequential(
                nn.Linear(2 * dim, dim),
            )
            self.cross_attention = lambda x: self.cross_attention_fc1(x) * self.cross_attention_fc2(x)
        elif self.att_type == "similarity_att":
            self.fc_visual = nn.Sequential(
                nn.Linear(dim, dim),
            )
            self.fc_text = nn.Sequential(
                nn.Linear(dim, dim),
            )
        else:
            raise Exception

    def forward(self, visual, text):
        batch_v = visual.shape[0]
        batch_t = text.shape[0]

        visual_feature = self.visual_conv(visual)
        visual_feature = self.visual_attention(visual_feature)
        # [b x 512 x 1 x 1]
        visual_feature = F.adaptive_avg_pool2d(visual_feature, 1)
        # [b x 512]
        visual_feature = visual_feature.squeeze(-1).squeeze(-1)

        if self.att_type == "soft_att":
            visual_gate = self.cross_attention(visual_feature)
            # mm
            visual_gate = visual_gate.unsqueeze(dim=1).expand(-1, batch_t, -1)
            text = text.unsqueeze(dim=0).expand(batch_v, -1, -1)
            return visual_gate * text
        elif self.att_type == "fusion_att":
            visual = visual_feature.unsqueeze(dim=1).expand(-1, batch_t, -1)
            text = text.unsqueeze(dim=0).expand(batch_v, -1, -1)
            fusion_vec = torch.cat([visual, text], dim=-1)
            return self.cross_attention(fusion_vec)
        elif self.att_type == "similarity_att":
            visual = self.fc_visual(visual_feature)
            text = self.fc_text(text)
            visual = visual.unsqueeze(dim=1).expand(-1, batch_t, -1)
            text = text.unsqueeze(dim=0).expand(batch_v, -1, -1)
            sims = visual * text
            text_feature = F.sigmoid(sims) * text
            return text_feature
            # return l2norm(text_feature, -1)


class MGAM(nn.Module):

    def __init__(self, opt={}):
        super(MGAM, self).__init__()

        # extract value
        channel_size = 256
        # sub sample
        self.LF_conv = nn.Conv2d(in_channels=240, out_channels=channel_size, kernel_size=2, stride=2)  # 240
        self.HF_conv = nn.Conv2d(in_channels=384, out_channels=channel_size, kernel_size=1, stride=1)  # 512

        # visual attention
        self.concat_attention = GAT(GATopt(512, 1))

    def forward(self, lower_feature, higher_feature, solo_feature):
        # b x channel_size x 16 x 16
        lower_feature = self.LF_conv(lower_feature)
        higher_feature = self.HF_conv(higher_feature)

        # concat -> [b x 512 x 7 x 7]
        concat_feature = torch.cat([lower_feature, higher_feature], dim=1)
        # residual -> [b x 512 x 7 x 7]
        concat_feature = higher_feature.mean(dim=1, keepdim=True).expand_as(concat_feature) + concat_feature
        # attention -> [b x 512 x 7 x 7]
        attent_feature = self.concat_attention(concat_feature)
        # [b x 512 x 1 x 1]
        attent_feature = F.adaptive_avg_pool2d(attent_feature, 1)
        # [b x 512]
        attent_feature = attent_feature.squeeze(-1).squeeze(-1)

        # solo attention
        solo_att = torch.sigmoid(attent_feature)
        final_feature = solo_feature * solo_att

        return l2norm(final_feature, -1)
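
# Illustrative MGAM shape sketch (added for clarity; the input sizes are assumptions
# inferred from the conv configurations above, not fixed by this module):
#
#   mgam = MGAM(opt={})
#   lower = torch.randn(2, 240, 32, 32)    # LF_conv (k=2, s=2) -> (2, 256, 16, 16)
#   higher = torch.randn(2, 384, 16, 16)   # HF_conv (k=1, s=1) -> (2, 256, 16, 16)
#   solo = torch.randn(2, 512)             # global visual feature
#   out = mgam(lower, higher, solo)        # (2, 512), L2-normalized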

class EncoderText(nn.Module):

    def __init__(self, word2idx, vocab_size, word_dim=300, embed_size=512, num_layers=1,
                 use_bi_gru=True, no_txtnorm=False):
        super(EncoderText, self).__init__()
        self.embed_size = embed_size
        self.no_txtnorm = no_txtnorm

        # word embedding
        self.embed = nn.Embedding(vocab_size, word_dim)
        # nn.Embedding.from_pretrained()

        # caption embedding
        self.use_bi_gru = use_bi_gru
        self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru)

        self.init_weights(word2idx)

    def init_weights_uniform(self):
        # unused fallback: random uniform init (superseded by the GloVe-based init
        # below, which previously shadowed this method under the same name)
        self.embed.weight.data.uniform_(-0.1, 0.1)

    def init_weights(self, word2idx):
        embeddings_glove = np.load('/home/wzm/crossmodal/embs_npa.npy')
        vocab_glove = deserialize_vocab('/home/wzm/crossmodal/GAC/vocab/glove_precomp_vocab.json')
        glove_w2i = vocab_glove.word2idx

        vocab, embeddings = [], []
        # quick-and-dirty trick to improve word-hit rate
        missing_words = []
        for word, idx in word2idx.items():
            if word in glove_w2i:
                self.embed.weight.data[idx] = torch.FloatTensor(embeddings_glove[glove_w2i[word]])
                # embeddings.append(embeddings_glove[glove_w2i[word]])
            else:
                missing_words.append(word)
        # print(missing_words)
        print('Words: {}/{} found in vocabulary; {} words missing'.format(
            len(word2idx) - len(missing_words), len(word2idx), len(missing_words)))

        for param in self.embed.parameters():
            param.requires_grad = False

        # vocab_npa = np.array(vocab)
        # embs_npa = np.array(embeddings)
        # self.embed = torch.nn.Embedding.from_pretrained(torch.from_numpy(embs_npa).float(), freeze=True)
        # self.embed.weight.data.uniform_(-0.1, 0.1)

    def forward(self, x, lengths):
        """Handles variable size captions"""
        # Embed word ids to vectors
        x = self.embed(x)
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        # Forward propagate RNN
        out, _ = self.rnn(packed)

        # Reshape *final* output to (batch_size, hidden_size)
        padded = pad_packed_sequence(out, batch_first=True)
        cap_emb, cap_len = padded

        if self.use_bi_gru:
            cap_emb = (cap_emb[:, :, :cap_emb.size(2) // 2] + cap_emb[:, :, cap_emb.size(2) // 2:]) / 2

        # normalization in the joint embedding space
        if not self.no_txtnorm:
            cap_emb = l2norm(cap_emb, dim=-1)

        return cap_emb, cap_len


class BaseModel(nn.Module):

    def __init__(self, vocab, opt={}):
        super(BaseModel, self).__init__()

        # img feature
        # self.extract_feature = ExtractFeature(opt=opt)
        self.extract_feature = pvig_ti_224_gelu()
        self.HF_conv = nn.Conv2d(in_channels=384, out_channels=512, kernel_size=1, stride=1)  # 512

        # vsa feature
        self.mgam = MGAM(opt=opt)

        # text feature
        # self.text_feature = Skipthoughts_Embedding_Module(
        #     vocab=vocab_words,
        #     opt=opt
        # )
        self.text_feature = EncoderText(word2idx=vocab.word2idx, vocab_size=len(vocab))
        self.gat_cap = GAT_T(GATopt(512, 1))

        self.cross_attention_s = CrossAttention(opt=opt)
        self.Eiters = 0

    def forward(self, img, captions, lengths=None):
        # extract features
        lower_feature, higher_feature, solo_feature = self.extract_feature(img)

        # mvsa features
        visual_feature = self.mgam(lower_feature, higher_feature, solo_feature)
        # visual_feature = solo_feature

        # text features
        # print("captions shape :", captions.shape)
        # text_feature = self.text_feature(captions)
        text_feature, cap_lengths = self.text_feature(captions, lengths)
        text_feature = self.gat_cap(text_feature)
        # print("text_feature shape :", text_feature.shape)
        text_feature = l2norm(torch.mean(text_feature, dim=1))
        # print("text_feature shape :", text_feature.shape)

        # VGMF
        # text_feature = self.aff(higher_feature, text_feature)
        # dual_text = self.cross_attention_s(higher_feature, text_feature)
        # dual_text = self.cross_attention_s(solo_feature, text_feature)
        # Ft = text_feature
        # print("Ft size : ", Ft.shape)

        # sim dual path
        # dual_visual = visual_feature.unsqueeze(dim=1).expand(-1, dual_text.shape[1], -1)
        # print("mvsa_feature size : ", mvsa_feature.shape)
        # sims = cosine_similarity(dual_visual, dual_text)
        sims = cosine_sim(visual_feature, text_feature)

        return sims


def factory(opt, vocab, cuda=True, data_parallel=True):
    opt = copy.copy(opt)

    model = BaseModel(vocab, opt)
    # print(model)

    if data_parallel:
        model = nn.DataParallel(model).cuda()
        if not cuda:
            raise ValueError

    if cuda:
        model.cuda()

    return model
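
# Illustrative end-to-end sketch (added for clarity; paths, sizes and the opt keys
# shown here are assumptions, not values fixed by this file):
#
#   vocab = deserialize_vocab('path/to/vocab.json')            # hypothetical path
#   opt = {'embed': {'embed_dim': 512},
#          'cross_attention': {'att_type': 'similarity_att'}}
#   model = factory(opt, vocab, cuda=True, data_parallel=False)
#   images = torch.randn(4, 3, 224, 224).cuda()                # RS image batch
#   captions = torch.randint(1, len(vocab), (4, 20)).cuda()    # padded word ids
#   lengths = [20, 18, 15, 12]                                 # true caption lengths
#   sims = model(images, captions, lengths)                    # image-text similarity scores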