import torch
import torch.nn as nn
import torch.nn.functional as F  # used by MGAM.forward (F.adaptive_avg_pool2d)
import torch.nn.init
import numpy as np
import copy
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from vocab import deserialize_vocab
from model.utils import *
from .pyramid_vig import pvig_ti_224_gelu
from .GAT import GAT, GATopt, GAT_T


def l2norm(X, dim=-1, eps=1e-8):
    """L2-normalize X along `dim` (default: the last dimension)."""
    norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
    X = torch.div(X, norm)
    return X

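# Usage sketch (illustrative only): l2norm keeps the input shape and rescales
# along `dim` to (approximately) unit L2 norm, e.g.
#   x = torch.randn(4, 512)
#   x = l2norm(x, dim=-1)   # each of the 4 rows now has norm ~1
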
def cosine_similarity(x1, x2, dim=1, eps=1e-8, keep_dim=False):
    """Returns cosine similarity between x1 and x2, computed along dim."""
    w12 = torch.sum(x1 * x2, dim, keepdim=keep_dim)
    w1 = torch.norm(x1, 2, dim, keepdim=keep_dim)
    w2 = torch.norm(x2, 2, dim, keepdim=keep_dim)
    if keep_dim:
        return w12 / (w1 * w2).clamp(min=eps)
    else:
        return (w12 / (w1 * w2).clamp(min=eps)).squeeze(-1)

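# Usage sketch (illustrative only): with dim=2 the feature axis is reduced, so for
# two (batch, seq_len, d) tensors a (batch, seq_len) score matrix is returned:
#   a = torch.randn(2, 5, 512)
#   b = torch.randn(2, 5, 512)
#   s = cosine_similarity(a, b, dim=2)   # shape (2, 5), values in [-1, 1]
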
def func_attention(query, context, g_sim, opt, eps=1e-8):
    """
    query: (batch, queryL, d)
    context: (batch, sourceL, d)
    opt: parameters
    """
    batch_size, queryL, sourceL = context.size(0), query.size(1), context.size(1)

    # Step 1: preassign attention
    # --> (batch, d, queryL)
    queryT = torch.transpose(query, 1, 2)

    # (batch, sourceL, d)(batch, d, queryL)
    attn = torch.bmm(context, queryT)
    attn = nn.LeakyReLU(0.1)(attn)
    attn = l2norm(attn, 2)

    # --> (batch, queryL, sourceL)
    attn = torch.transpose(attn, 1, 2).contiguous()
    # --> (batch*queryL, sourceL)
    attn = attn.view(batch_size * queryL, sourceL)
    # attn = nn.Softmax(dim=1)(attn * opt.lambda_softmax)
    attn = nn.Softmax(dim=1)(attn * 20)
    # --> (batch, queryL, sourceL)
    attn = attn.view(batch_size, queryL, sourceL)

    # Step 2: use g_sim to correct local_sim
    re_attn = correct_equal(attn, query, context, sourceL, g_sim, opt['cross_attention']['threshold_p'])

    # Step 3: identify irrelevant fragments
    re_attn = focal_equal(re_attn, query, context, sourceL, opt['cross_attention']['threshold_q'])

    # Step 4: get final local feature
    # --> (batch, d, sourceL)
    contextT = torch.transpose(context, 1, 2)
    # --> (batch, sourceL, queryL)
    re_attnT = torch.transpose(re_attn, 1, 2).contiguous()
    # (batch x d x sourceL)(batch x sourceL x queryL)
    # --> (batch, d, queryL)
    weightedContext = torch.bmm(contextT, re_attnT)
    # --> (batch, queryL, d)
    weightedContext = torch.transpose(weightedContext, 1, 2)

    return weightedContext, re_attn

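# Note (illustrative, not part of the original code): `opt` is expected to be a
# nested dict providing opt['cross_attention']['threshold_p'] and
# opt['cross_attention']['threshold_q'], and `g_sim` is a 1-D tensor of global
# image-text similarities with one entry per batch item; the softmax temperature
# is hard-coded to 20 above.
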
def correct_equal(attn, query, context, sourceL, g_sim, threshold_p):
    """
    consider the confidence g(x) for each fragment as equal
    sum_j (x_i - x_j) = sum_j x_i - sum_j x_j
    attn: (batch, queryL, sourceL)
    """
    # GCU process
    d = g_sim - threshold_p
    d = torch.where(d == 0, d.new_full(d.shape, 1e-8), d)
    re_attn = d.unsqueeze(1).unsqueeze(2) * attn
    attn_sum = torch.sum(re_attn, dim=-1, keepdim=True)
    re_attn = re_attn / attn_sum
    return re_attn

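# Note (illustrative): `d = g_sim - threshold_p` has shape (batch,); the two
# unsqueeze calls broadcast it to (batch, 1, 1) so each attention map is scaled
# by its global-similarity margin before being renormalized over sourceL.
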
def focal_equal(attn, query, context, sourceL, threshold_q):
    # TODO: currently the per-query max similarity is used as a hard-coded focal reference
    max_attn = torch.max(attn, dim=-1, keepdim=True)[0]
    # funcF = attn * sourceL - torch.sum(attn, dim=-1, keepdim=True)
    funcF = max_attn - attn

    fattn = torch.where(funcF < threshold_q, torch.ones_like(attn), torch.zeros_like(attn))

    # Step 3: reassign attention
    tmp_attn = fattn * attn
    attn_sum = torch.sum(tmp_attn, dim=-1, keepdim=True)
    re_attn = tmp_attn / attn_sum

    return re_attn

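# Note (illustrative): focal_equal keeps only the source fragments whose attention
# lies within `threshold_q` of the per-query maximum, zeroes out the rest, and
# renormalizes, so clearly irrelevant fragments no longer contribute to the
# weighted context.
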
class MGAM(nn.Module):
    def __init__(self, opt={}):
        super(MGAM, self).__init__()
        # extract value
        channel_size = 256
        # sub sample
        self.LF_conv = nn.Conv2d(in_channels=240, out_channels=channel_size, kernel_size=2, stride=2)  # 240
        self.HF_conv = nn.Conv2d(in_channels=384, out_channels=channel_size, kernel_size=1, stride=1)  # 512
        # visual attention
        self.concat_attention = GAT(GATopt(512, 1))

    def forward(self, lower_feature, higher_feature, solo_feature):
        # b x channel_size x 16 x 16
        lower_feature = self.LF_conv(lower_feature)
        higher_feature = self.HF_conv(higher_feature)

        # concat -> [b x 512 x 7 x 7]
        concat_feature = torch.cat([lower_feature, higher_feature], dim=1)

        # residual -> [b x 512 x 7 x 7]
        concat_feature = higher_feature.mean(dim=1, keepdim=True).expand_as(concat_feature) + concat_feature

        # attention -> [b x 512 x 7 x 7]
        attent_feature = self.concat_attention(concat_feature)
        visual_embs = attent_feature.view(attent_feature.size(0), 512, -1).transpose(1, 2)

        # [b x 512 x 1 x 1]
        attent_feature = F.adaptive_avg_pool2d(attent_feature, 1)

        # [b x 512]
        attent_feature = attent_feature.squeeze(-1).squeeze(-1)

        # solo attention
        solo_att = torch.sigmoid(attent_feature)
        final_feature = solo_feature * solo_att

        return visual_embs, l2norm(final_feature, -1)

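# Usage sketch (assumed shapes, for illustration only; the GAT block is assumed
# to preserve the spatial shape of its input):
#   mgam = MGAM()
#   lower = torch.randn(2, 240, 14, 14)    # low-level backbone features
#   higher = torch.randn(2, 384, 7, 7)     # high-level backbone features
#   solo = torch.randn(2, 512)             # global image feature
#   visual_embs, img_global = mgam(lower, higher, solo)
#   # visual_embs: (2, 49, 512) local embeddings; img_global: (2, 512), L2-normalized
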
class EncoderText(nn.Module):
    def __init__(self, word2idx, vocab_size, word_dim=300, embed_size=512, num_layers=1,
                 use_bi_gru=True, no_txtnorm=False):
        super(EncoderText, self).__init__()
        self.embed_size = embed_size
        self.no_txtnorm = no_txtnorm

        # word embedding
        self.embed = nn.Embedding(vocab_size, word_dim)
        # nn.Embedding.from_pretrained()

        # caption embedding
        self.use_bi_gru = use_bi_gru
        self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru)

        # self.init_weights()
        self.init_weights(word2idx)

    # The plain random initialisation below is shadowed (and therefore unused);
    # it is superseded by the GloVe-based init_weights(word2idx) that follows.
    # def init_weights(self):
    #     self.embed.weight.data.uniform_(-0.1, 0.1)

    def init_weights(self, word2idx):
        embeddings_glove = np.load('/home/wzm/embs_npa.npy')
        vocab_glove = deserialize_vocab('/home/wzm/GAC/vocab/glove_precomp_vocab.json')
        glove_w2i = vocab_glove.word2idx
        vocab, embeddings = [], []

        # quick-and-dirty trick to improve word-hit rate
        missing_words = []
        for word, idx in word2idx.items():
            # print('idx : ', idx)
            if word in glove_w2i:
                self.embed.weight.data[idx] = torch.FloatTensor(embeddings_glove[glove_w2i[word]])
                # embeddings.append(embeddings_glove[glove_w2i[word]])
            else:
                missing_words.append(word)
        # print(missing_words)
        print('Words: {}/{} found in vocabulary; {} words missing'.format(
            len(word2idx) - len(missing_words), len(word2idx), len(missing_words)))

        for param in self.embed.parameters():
            param.requires_grad = False

        # vocab_npa = np.array(vocab)
        # embs_npa = np.array(embeddings)

        # self.embed = torch.nn.Embedding.from_pretrained(torch.from_numpy(embs_npa).float(), freeze=True)
        # self.embed.weight.data.uniform_(-0.1, 0.1)

    def forward(self, x, lengths):
        """Handles variable size captions"""
        # Embed word ids to vectors
        x = self.embed(x)
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        # Forward propagate RNN
        out, _ = self.rnn(packed)

        # Reshape *final* output to (batch_size, hidden_size)
        padded = pad_packed_sequence(out, batch_first=True)
        cap_emb, cap_len = padded

        if self.use_bi_gru:
            cap_emb = (cap_emb[:, :, :cap_emb.size(2) // 2] + cap_emb[:, :, cap_emb.size(2) // 2:]) / 2

        # normalization in the joint embedding space
        if not self.no_txtnorm:
            cap_emb = l2norm(cap_emb, dim=-1)

        return cap_emb, cap_len

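# Usage sketch (illustrative; assumes a vocab object exposing word2idx and that the
# hard-coded GloVe files referenced in init_weights are available):
#   enc = EncoderText(word2idx=vocab.word2idx, vocab_size=len(vocab))
#   captions = torch.randint(0, len(vocab), (2, 12))   # padded word-id batch
#   lengths = torch.tensor([12, 9])                    # true caption lengths
#   cap_emb, cap_len = enc(captions, lengths)          # cap_emb: (2, 12, 512)
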
class BaseModel(nn.Module):
    def __init__(self, vocab, opt={}):
        super(BaseModel, self).__init__()
        self.opt = opt

        # --------------- img feature ------------
        # self.extract_feature = ExtractFeature(opt = opt)
        self.extract_feature = pvig_ti_224_gelu()
        self.HF_conv = nn.Conv2d(in_channels=384, out_channels=512, kernel_size=1, stride=1)  # 512
        # vsa feature
        self.mgam = MGAM(opt=opt)

        # --------------- text feature -----------
        self.text_feature = EncoderText(word2idx=vocab.word2idx, vocab_size=len(vocab))
        self.gat_cap = GAT_T(GATopt(512, 1))

        self.Eiters = 0

    def forward_emb(self, images, captions, lengths=None):
        # extract features
        lower_feature, higher_feature, solo_feature = self.extract_feature(images)

        # mvsa features
        img_embs, img_solo = self.mgam(lower_feature, higher_feature, solo_feature)
        # img_embs, img_solo = [], solo_feature
        # visual_feature = solo_feature

        # text features
        cap_embs, cap_lens = self.text_feature(captions, lengths)
        cap_embs = self.gat_cap(cap_embs)
        # print("cap_embs shape :", cap_embs.shape)
        cap_solo = l2norm(torch.mean(cap_embs, dim=1))
        # print("cap_solo shape :", cap_solo.shape)
        return img_embs, img_solo, cap_embs, cap_lens, cap_solo

    def forward_sim(self, img_emb, img_mean, cap_emb, cap_len, cap_mean, **kwargs):
        """Compute cross-attention similarity scores for pairs of image and caption embeddings."""
        scores = self.xattn_score(img_emb, img_mean, cap_emb, cap_len, cap_mean)
        # # # TODO: add global sims
        # sims = cosine_sim(img_mean, cap_mean)
        # return scores + sims
        return scores
        # ------- no cross attention ----------
        # sims = cosine_sim(img_mean, cap_mean)
        # return sims

    def forward(self, images, captions, lengths, ids=None, *args):
        # compute the embeddings
        # lengths = lengths.cpu().numpy().tolist()
        img_emb, img_mean, cap_emb, cap_lens, cap_mean = self.forward_emb(images, captions, lengths)
        scores = self.forward_sim(img_emb, img_mean, cap_emb, cap_lens, cap_mean)
        return scores

    def xattn_score(self, images, img_mean, captions, cap_lens, cap_mean):
        similarities = []
        n_image = images.size(0)
        n_caption = captions.size(0)
        g_sims = cap_mean.mm(img_mean.t())
        for i in range(n_caption):
            # Get the i-th text description
            n_word = cap_lens[i]
            g_sim = g_sims[i]
            cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous()
            # --> (n_image, n_word, d)
            cap_i_expand = cap_i.repeat(n_image, 1, 1)

            # t2i process
            # weiContext: (n_image, n_word, d)
            weiContext, _ = func_attention(cap_i_expand, images, g_sim, self.opt)
            t2i_sim = cosine_similarity(cap_i_expand, weiContext, dim=2)
            t2i_sim = t2i_sim.mean(dim=1, keepdim=True)

            # i2t process
            # weiContext: (n_image, n_region, d)
            weiContext, _ = func_attention(images, cap_i_expand, g_sim, self.opt)
            i2t_sim = cosine_similarity(images, weiContext, dim=2)
            i2t_sim = i2t_sim.mean(dim=1, keepdim=True)

            # Overall similarity for image and text
            sim = t2i_sim + i2t_sim

            similarities.append(sim)

        # (n_image, n_caption)
        similarities = torch.cat(similarities, 1)

        if self.training:
            similarities = similarities.transpose(0, 1)

        return similarities

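# Note (illustrative): xattn_score first computes global similarities
# g_sims = cap_mean @ img_mean.T, then refines each caption's scores with
# bidirectional cross-attention (t2i and i2t).  The returned matrix is
# (n_image, n_caption) at eval time and transposed to (n_caption, n_image)
# during training.
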
def factory(opt, vocab, cuda=True, data_parallel=True):
    opt = copy.copy(opt)
    model = BaseModel(vocab, opt)
    # print(model)
    if data_parallel:
        model = nn.DataParallel(model).cuda()
        if not cuda:
            raise ValueError('data_parallel requires cuda=True')
    if cuda:
        model.cuda()
    return model

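# End-to-end usage sketch (illustrative only; `opt` and `vocab` come from the
# project's config and vocabulary loading, the dummy tensors below assume
# 224x224 inputs for the pyramid ViG backbone, and a CUDA device is required):
#   model = factory(opt, vocab, cuda=True, data_parallel=False)
#   images = torch.randn(8, 3, 224, 224).cuda()
#   captions = torch.randint(0, len(vocab), (8, 20)).cuda()
#   lengths = [20] * 8
#   scores = model(images, captions, lengths)   # (n_caption, n_image) in training mode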