import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.nn.init
import torchvision.models as models
from torch.autograd import Variable
from torch.nn.utils.clip_grad import clip_grad_norm
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
from collections import OrderedDict
from model.utils import *
import copy
import ast

# from .mca import SA, SGA
from .pyramid_vig import DeepGCN, pvig_ti_224_gelu
from .GAT import GAT, GATopt, GAT_T

from vocab import deserialize_vocab

def l2norm(X, dim=-1, eps=1e-8):
    """L2-normalize columns of X."""
    norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
    X = torch.div(X, norm)
    return X

def cosine_similarity(x1, x2, dim=1, eps=1e-8, keep_dim=False):
    """Returns cosine similarity between x1 and x2, computed along dim."""
    w12 = torch.sum(x1 * x2, dim, keepdim=keep_dim)
    w1 = torch.norm(x1, 2, dim, keepdim=keep_dim)
    w2 = torch.norm(x2, 2, dim, keepdim=keep_dim)
    if keep_dim:
        return w12 / (w1 * w2).clamp(min=eps)
    else:
        return (w12 / (w1 * w2).clamp(min=eps)).squeeze(-1)

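
# --------------------------------------------------------------------------- #
# Added sanity-check sketch for the two helpers above (not part of the original
# pipeline; tensor shapes are arbitrary and chosen only for illustration).
# --------------------------------------------------------------------------- #
def _demo_norm_and_similarity():
    x = torch.randn(2, 5, 512)
    y = l2norm(x, dim=-1)
    # every vector of y should have (approximately) unit L2 norm
    assert torch.allclose(y.norm(p=2, dim=-1), torch.ones(2, 5), atol=1e-4)
    # cosine similarity of a tensor with itself is ~1
    sim = cosine_similarity(x, x, dim=-1, keep_dim=True)   # (2, 5, 1)
    assert torch.allclose(sim, torch.ones_like(sim), atol=1e-4)
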
def func_attention(query, context, g_sim, opt, eps=1e-8):
    """
    query: (batch, queryL, d)
    context: (batch, sourceL, d)
    opt: parameters
    """
    batch_size, queryL, sourceL = context.size(0), query.size(1), context.size(1)

    # Step 1: preassign attention
    # --> (batch, d, queryL)
    queryT = torch.transpose(query, 1, 2)

    # (batch, sourceL, d)(batch, d, queryL)
    attn = torch.bmm(context, queryT)
    attn = nn.LeakyReLU(0.1)(attn)
    attn = l2norm(attn, 2)

    # --> (batch, queryL, sourceL)
    attn = torch.transpose(attn, 1, 2).contiguous()
    # --> (batch*queryL, sourceL)
    attn = attn.view(batch_size * queryL, sourceL)
    attn = nn.Softmax(dim=1)(attn * opt.lambda_softmax)
    # --> (batch, queryL, sourceL)
    attn = attn.view(batch_size, queryL, sourceL)

    # Step 2: identify irrelevant fragments
    # Learn an indicator function H: one for relevant, zero for irrelevant
    if opt.correct_type == 'equal':
        re_attn = correct_equal(attn, query, context, sourceL, g_sim)
    elif opt.correct_type == 'prob':
        re_attn = correct_prob(attn, query, context, sourceL, g_sim)
    else:
        raise ValueError('Unknown correct_type: {}'.format(opt.correct_type))

    # --> (batch, d, sourceL)
    contextT = torch.transpose(context, 1, 2)
    # --> (batch, sourceL, queryL)
    re_attnT = torch.transpose(re_attn, 1, 2).contiguous()
    # (batch x d x sourceL)(batch x sourceL x queryL)
    # --> (batch, d, queryL)
    weightedContext = torch.bmm(contextT, re_attnT)

    # --> (batch, queryL, d)
    weightedContext = torch.transpose(weightedContext, 1, 2)

    if torch.isnan(weightedContext).any():
        print('Warning: NaN values in weightedContext (func_attention)')
    return weightedContext, re_attn

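
# --------------------------------------------------------------------------- #
# Added shape-check sketch for func_attention. `opt` only needs the two fields
# the function reads (lambda_softmax, correct_type); the values below are
# placeholders, not the project's configuration.
# --------------------------------------------------------------------------- #
def _demo_func_attention():
    from types import SimpleNamespace
    opt = SimpleNamespace(lambda_softmax=9.0, correct_type='equal')  # placeholder settings
    batch, queryL, sourceL, d = 4, 12, 36, 512
    query = torch.randn(batch, queryL, d)
    context = torch.randn(batch, sourceL, d)
    g_sim = torch.rand(batch)   # assumed: one global similarity score per pair
    weighted, re_attn = func_attention(query, context, g_sim, opt)
    # weighted context is aligned with the query fragments
    assert weighted.shape == (batch, queryL, d)
    assert re_attn.shape == (batch, queryL, sourceL)
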
def correct_equal(attn, query, context, sourceL, g_sim):
    """
    consider the confidence g(x) for each fragment as equal
    sigma_{j} (xi - xj) = sigma_{j} xi - sigma_{j} xj
    attn: (batch, queryL, sourceL)
    """
    # GCU process
    d = g_sim - 0.3
    d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d)
    re_attn = d.unsqueeze(1).unsqueeze(2) * attn
    attn_sum = torch.sum(re_attn, dim=-1, keepdim=True)
    re_attn = re_attn / attn_sum
    cos1 = cosine_similarity(torch.bmm(re_attn, context), query, dim=-1, keep_dim=True)
    cos1 = torch.where(cos1 == 0, cos1.new_full(cos1.shape, 1e-8), cos1)
    re_attn1 = focal_equal(re_attn, query, context, sourceL)

    # LCU process
    cos = cosine_similarity(torch.bmm(re_attn1, context), query, dim=-1, keep_dim=True)
    cos = torch.where(cos == 0, cos.new_full(cos.shape, 1e-8), cos)
    delta = cos - cos1
    delta = torch.where(delta == 0, delta.new_full(delta.shape, 1e-8), delta)
    re_attn2 = delta * re_attn1
    attn_sum = torch.sum(re_attn2, dim=-1, keepdim=True)
    re_attn2 = re_attn2 / attn_sum
    re_attn2 = focal_equal(re_attn2, query, context, sourceL)
    return re_attn2

def focal_equal(attn, query, context, sourceL):
    """Equal-confidence focal re-weighting: keep only fragments whose attention
    exceeds the row mean, then re-normalize."""
    funcF = attn * sourceL - torch.sum(attn, dim=-1, keepdim=True)
    fattn = torch.where(funcF > 0, torch.ones_like(attn),
                        torch.zeros_like(attn))

    # Step 3: reassign attention
    tmp_attn = fattn * attn
    attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
    re_attn = tmp_attn / attn_sum

    return re_attn

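
# --------------------------------------------------------------------------- #
# Added toy example of the focal re-weighting above. Values are made up purely
# for illustration: entries below the row mean are zeroed, the rest is
# re-normalised. query/context are unused by focal_equal, so None is passed.
# --------------------------------------------------------------------------- #
def _demo_focal_equal():
    attn = torch.tensor([[[0.1, 0.2, 0.3, 0.4]]])        # (batch=1, queryL=1, sourceL=4)
    re_attn = focal_equal(attn, query=None, context=None, sourceL=4)
    # 0.1 and 0.2 fall below the mean (0.25) and are dropped;
    # 0.3 and 0.4 are re-normalised to 3/7 and 4/7.
    expected = torch.tensor([[[0.0, 0.0, 3.0 / 7.0, 4.0 / 7.0]]])
    assert torch.allclose(re_attn, expected, atol=1e-6)
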
def correct_prob(attn, query, context, sourceL, g_sim):
    """
    consider the confidence g(x) for each fragment as the sqrt
    of their similarity probability to the query fragment
    sigma_{j} (xi - xj)gj = sigma_{j} xi*gj - sigma_{j} xj*gj
    attn: (batch, queryL, sourceL)
    """
    # GCU process
    d = g_sim - 0.3
    d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d)
    re_attn = d.unsqueeze(1).unsqueeze(2) * attn
    attn_sum = torch.sum(re_attn, dim=-1, keepdim=True)
    re_attn = re_attn / attn_sum
    cos1 = cosine_similarity(torch.bmm(re_attn, context), query, dim=-1, keep_dim=True)
    cos1 = torch.where(cos1 == 0, cos1.new_full(cos1.shape, 1e-8), cos1)
    re_attn1 = focal_prob(re_attn, query, context, sourceL)

    # LCU process
    cos = cosine_similarity(torch.bmm(re_attn1, context), query, dim=-1, keep_dim=True)
    cos = torch.where(cos == 0, cos.new_full(cos.shape, 1e-8), cos)
    delta = cos - cos1
    delta = torch.where(delta == 0, delta.new_full(delta.shape, 1e-8), delta)
    re_attn2 = delta * re_attn1
    attn_sum = torch.sum(re_attn2, dim=-1, keepdim=True)
    re_attn2 = re_attn2 / attn_sum
    re_attn2 = focal_prob(re_attn2, query, context, sourceL)
    return re_attn2

def focal_prob(attn, query, context, sourceL):
    """Probability-weighted focal re-weighting: fragment confidence is the sqrt
    of its attention probability."""
    batch_size, queryL, sourceL = context.size(0), query.size(1), context.size(1)

    # -> (batch, queryL, sourceL, 1)
    xi = attn.unsqueeze(-1).contiguous()
    # -> (batch, queryL, 1, sourceL)
    xj = attn.unsqueeze(2).contiguous()
    # -> (batch, queryL, 1, sourceL)
    xj_confi = torch.sqrt(xj)

    xi = xi.view(batch_size * queryL, sourceL, 1)
    xj = xj.view(batch_size * queryL, 1, sourceL)
    xj_confi = xj_confi.view(batch_size * queryL, 1, sourceL)

    # -> (batch*queryL, sourceL, sourceL)
    term1 = torch.bmm(xi, xj_confi).clamp(min=1e-8)
    term2 = xj * xj_confi
    funcF = torch.sum(term1 - term2, dim=-1)  # -> (batch*queryL, sourceL)
    funcF = funcF.view(batch_size, queryL, sourceL)

    fattn = torch.where(funcF > 0, torch.ones_like(attn),
                        torch.zeros_like(attn))

    # Step 3: reassign attention
    tmp_attn = fattn * attn
    attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
    re_attn = tmp_attn / attn_sum

    if torch.isnan(re_attn).any():
        print('Warning: NaN values in re_attn (focal_prob)')
    return re_attn

# cross attention for MGAN
class CrossAttention(nn.Module):

    def __init__(self, opt={}):
        super(CrossAttention, self).__init__()
        self.att_type = opt['cross_attention']['att_type']
        dim = opt['embed']['embed_dim']

        channel_size = 512

        self.visual_conv = nn.Conv2d(in_channels=384, out_channels=channel_size, kernel_size=1, stride=1)
        # visual attention
        self.visual_attention = GAT(GATopt(512, 1))

        # self.bn_out = nn.BatchNorm1d(512)
        # self.dropout_out = nn.Dropout(0.5)

        if self.att_type == "soft_att":
            self.cross_attention = nn.Sequential(
                nn.Linear(dim, dim),
                nn.Sigmoid()
            )
        elif self.att_type == "fusion_att":
            self.cross_attention_fc1 = nn.Sequential(
                nn.Linear(2 * dim, dim),
                nn.Sigmoid()
            )
            self.cross_attention_fc2 = nn.Sequential(
                nn.Linear(2 * dim, dim),
            )
            self.cross_attention = lambda x: self.cross_attention_fc1(x) * self.cross_attention_fc2(x)
        elif self.att_type == "similarity_att":
            self.fc_visual = nn.Sequential(
                nn.Linear(dim, dim),
            )
            self.fc_text = nn.Sequential(
                nn.Linear(dim, dim),
            )
        else:
            raise ValueError('Unknown cross-attention type: {}'.format(self.att_type))

    def forward(self, visual, text):
        batch_v = visual.shape[0]
        batch_t = text.shape[0]

        visual_feature = self.visual_conv(visual)
        visual_feature = self.visual_attention(visual_feature)

        # [b x 512 x 1 x 1]
        visual_feature = F.adaptive_avg_pool2d(visual_feature, 1)

        # [b x 512]
        visual_feature = visual_feature.squeeze(-1).squeeze(-1)

        if self.att_type == "soft_att":
            visual_gate = self.cross_attention(visual_feature)

            # expand to every (image, caption) pair
            visual_gate = visual_gate.unsqueeze(dim=1).expand(-1, batch_t, -1)
            text = text.unsqueeze(dim=0).expand(batch_v, -1, -1)

            return visual_gate * text

        elif self.att_type == "fusion_att":
            visual = visual_feature.unsqueeze(dim=1).expand(-1, batch_t, -1)
            text = text.unsqueeze(dim=0).expand(batch_v, -1, -1)

            fusion_vec = torch.cat([visual, text], dim=-1)

            return self.cross_attention(fusion_vec)

        elif self.att_type == "similarity_att":
            visual = self.fc_visual(visual_feature)
            text = self.fc_text(text)

            visual = visual.unsqueeze(dim=1).expand(-1, batch_t, -1)
            text = text.unsqueeze(dim=0).expand(batch_v, -1, -1)

            sims = visual * text
            text_feature = torch.sigmoid(sims) * text
            return text_feature
            # return l2norm(text_feature, -1)

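
# --------------------------------------------------------------------------- #
# Added usage sketch for CrossAttention. The nested `opt` dict mirrors the two
# fields read in __init__; embed_dim, the 7x7 map size and batch sizes are
# placeholders, and it assumes the repo's GAT block preserves the
# [b x 512 x h x w] feature-map shape (as its use in MGAM below suggests).
# --------------------------------------------------------------------------- #
def _demo_cross_attention():
    opt = {
        'cross_attention': {'att_type': 'similarity_att'},
        'embed': {'embed_dim': 512},
    }
    module = CrossAttention(opt)
    visual = torch.randn(4, 384, 7, 7)   # raw visual feature map (384 channels expected)
    text = torch.randn(6, 512)           # one embedding per caption
    out = module(visual, text)
    # one gated text vector for every (image, caption) pair
    assert out.shape == (4, 6, 512)
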
class MGAM(nn.Module):

    def __init__(self, opt={}):
        super(MGAM, self).__init__()
        # extract value
        channel_size = 256
        # sub sample
        self.LF_conv = nn.Conv2d(in_channels=240, out_channels=channel_size, kernel_size=2, stride=2)  # 240
        self.HF_conv = nn.Conv2d(in_channels=384, out_channels=channel_size, kernel_size=1, stride=1)  # 512
        # visual attention
        self.concat_attention = GAT(GATopt(512, 1))

    def forward(self, lower_feature, higher_feature, solo_feature):

        # b x channel_size x 16 x 16
        lower_feature = self.LF_conv(lower_feature)
        higher_feature = self.HF_conv(higher_feature)

        # concat -> [b x 512 x 7 x 7]
        concat_feature = torch.cat([lower_feature, higher_feature], dim=1)

        # residual -> [b x 512 x 7 x 7]
        concat_feature = higher_feature.mean(dim=1, keepdim=True).expand_as(concat_feature) + concat_feature

        # attention -> [b x 512 x 7 x 7]
        attent_feature = self.concat_attention(concat_feature)

        # [b x 512 x 1 x 1]
        attent_feature = F.adaptive_avg_pool2d(attent_feature, 1)

        # [b x 512]
        attent_feature = attent_feature.squeeze(-1).squeeze(-1)

        # solo attention
        solo_att = torch.sigmoid(attent_feature)
        final_feature = solo_feature * solo_att

        return l2norm(final_feature, -1)

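
# --------------------------------------------------------------------------- #
# Added shape sketch for MGAM. Batch size and spatial sizes are placeholders
# (in the real pipeline the maps come from the pyramid ViG backbone), and it
# assumes the GAT block preserves the [b x 512 x h x w] shape.
# --------------------------------------------------------------------------- #
def _demo_mgam():
    mgam = MGAM(opt={})
    lower = torch.randn(2, 240, 14, 14)   # low-level map, halved by the stride-2 conv
    higher = torch.randn(2, 384, 7, 7)    # high-level map
    solo = torch.randn(2, 512)            # global visual feature
    out = mgam(lower, higher, solo)
    assert out.shape == (2, 512)          # L2-normalised fused visual feature
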
class EncoderText(nn.Module):

    def __init__(self, word2idx, vocab_size, word_dim=300, embed_size=512, num_layers=1,
                 use_bi_gru=True, no_txtnorm=False):
        super(EncoderText, self).__init__()
        self.embed_size = embed_size
        self.no_txtnorm = no_txtnorm

        # word embedding
        self.embed = nn.Embedding(vocab_size, word_dim)
        # nn.Embedding.from_pretrained()

        # caption embedding
        self.use_bi_gru = use_bi_gru
        self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru)

        self.init_weights(word2idx)

    # shadowed variant kept for reference:
    # def init_weights(self):
    #     self.embed.weight.data.uniform_(-0.1, 0.1)

    def init_weights(self, word2idx):
        """Initialize the embedding table from pre-trained GloVe vectors and freeze it."""
        embeddings_glove = np.load('/home/wzm/crossmodal/embs_npa.npy')
        vocab_glove = deserialize_vocab('/home/wzm/crossmodal/GAC/vocab/glove_precomp_vocab.json')
        glove_w2i = vocab_glove.word2idx
        vocab, embeddings = [], []

        # quick-and-dirty trick to improve word-hit rate
        missing_words = []
        for word, idx in word2idx.items():
            if word in glove_w2i:
                self.embed.weight.data[idx] = torch.FloatTensor(embeddings_glove[glove_w2i[word]])
                # embeddings.append(embeddings_glove[glove_w2i[word]])
            else:
                missing_words.append(word)
        print('Words: {}/{} found in vocabulary; {} words missing'.format(
            len(word2idx) - len(missing_words), len(word2idx), len(missing_words)))

        for param in self.embed.parameters():
            param.requires_grad = False

        # vocab_npa = np.array(vocab)
        # embs_npa = np.array(embeddings)
        # self.embed = torch.nn.Embedding.from_pretrained(torch.from_numpy(embs_npa).float(), freeze=True)
        # self.embed.weight.data.uniform_(-0.1, 0.1)

    def forward(self, x, lengths):
        """Handles variable size captions."""
        # Embed word ids to vectors
        x = self.embed(x)
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        # Forward propagate RNN
        out, _ = self.rnn(packed)

        # Reshape *final* output to (batch_size, seq_len, hidden_size)
        cap_emb, cap_len = pad_packed_sequence(out, batch_first=True)

        # average the forward and backward GRU states
        if self.use_bi_gru:
            cap_emb = (cap_emb[:, :, :cap_emb.size(2) // 2] + cap_emb[:, :, cap_emb.size(2) // 2:]) / 2

        # normalization in the joint embedding space
        if not self.no_txtnorm:
            cap_emb = l2norm(cap_emb, dim=-1)

        return cap_emb, cap_len

class BaseModel(nn.Module):
    def __init__(self, vocab, opt={}):
        super(BaseModel, self).__init__()

        # img feature
        # self.extract_feature = ExtractFeature(opt = opt)
        self.extract_feature = pvig_ti_224_gelu()
        self.HF_conv = nn.Conv2d(in_channels=384, out_channels=512, kernel_size=1, stride=1)  # 512
        # vsa feature
        self.mgam = MGAM(opt=opt)
        # text feature
        # self.text_feature = Skipthoughts_Embedding_Module(
        #     vocab=vocab_words,
        #     opt=opt
        # )
        self.text_feature = EncoderText(word2idx=vocab.word2idx, vocab_size=len(vocab))
        self.gat_cap = GAT_T(GATopt(512, 1))

        self.cross_attention_s = CrossAttention(opt=opt)

        self.Eiters = 0

    def forward(self, img, captions, lengths=None):

        # extract features
        lower_feature, higher_feature, solo_feature = self.extract_feature(img)

        # mvsa features
        visual_feature = self.mgam(lower_feature, higher_feature, solo_feature)
        # visual_feature = solo_feature

        # text features
        # text_feature = self.text_feature(captions)
        text_feature, cap_lengths = self.text_feature(captions, lengths)
        text_feature = self.gat_cap(text_feature)
        text_feature = l2norm(torch.mean(text_feature, dim=1))

        # VGMF
        # text_feature = self.aff(higher_feature, text_feature)
        # dual_text = self.cross_attention_s(higher_feature, text_feature)
        # dual_text = self.cross_attention_s(solo_feature, text_feature)

        # sim dual path
        # dual_visual = visual_feature.unsqueeze(dim=1).expand(-1, dual_text.shape[1], -1)
        # sims = cosine_similarity(dual_visual, dual_text)

        # cosine_sim is expected to be provided by the wildcard import from model.utils
        sims = cosine_sim(visual_feature, text_feature)
        return sims

def factory(opt, vocab, cuda=True, data_parallel=True):
    opt = copy.copy(opt)

    model = BaseModel(vocab, opt)
    # print(model)
    if data_parallel:
        model = nn.DataParallel(model).cuda()
        if not cuda:
            raise ValueError('data_parallel=True requires cuda=True')

    if cuda:
        model.cuda()

    return model
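
# --------------------------------------------------------------------------- #
# Added usage sketch (kept as a comment so nothing runs on import). It assumes
# `options` follows the nested dict layout read by CrossAttention/MGAM above
# and that `vocab` is the object returned by deserialize_vocab; the path and
# batch contents are hypothetical.
#
#   vocab = deserialize_vocab('path/to/vocab.json')
#   model = factory(options, vocab, cuda=True, data_parallel=True)
#   sims = model(images, captions, lengths)   # similarity scores between the
#                                             # batch of images and captions
# --------------------------------------------------------------------------- #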