Graduation_Project/WZM/model/GAC.py

326 lines
12 KiB
Python
Raw Normal View History

2024-06-24 19:41:48 +08:00
import torch
import torch.nn as nn
import torch.nn.init
import numpy as np
import copy
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from vocab import deserialize_vocab
from model.utils import *
from .pyramid_vig import pvig_ti_224_gelu
from .GAT import GAT, GATopt, GAT_T
def l2norm(X, dim=-1, eps=1e-8):
    """Scale ``X`` to unit L2 norm along ``dim``.

    ``eps`` is *added* to the norm (not used as a clamp), so an all-zero slice
    maps to zeros rather than NaN.

    Args:
        X: input tensor.
        dim: axis along which the norm is computed (default: last axis).
        eps: small constant added to the norm before dividing.

    Returns:
        Tensor with the same shape as ``X``.
    """
    magnitude = X.pow(2).sum(dim=dim, keepdim=True).sqrt().add(eps)
    return X / magnitude
def cosine_similarity(x1, x2, dim=1, eps=1e-8, keep_dim=False):
    """Cosine similarity between ``x1`` and ``x2`` along ``dim``.

    The product of the two L2 norms is clamped to at least ``eps`` so the
    division is always defined.

    Args:
        x1, x2: tensors that broadcast against each other.
        dim: axis that is reduced.
        eps: lower bound for the norm product.
        keep_dim: if True, the reduced axis is kept with size 1. If False it is
            removed, and the resulting last axis is additionally squeezed when
            it has size 1.

    Returns:
        Similarity tensor.
    """
    dot = (x1 * x2).sum(dim=dim, keepdim=keep_dim)
    norm_product = x1.norm(p=2, dim=dim, keepdim=keep_dim) * x2.norm(p=2, dim=dim, keepdim=keep_dim)
    sim = dot / norm_product.clamp(min=eps)
    # NOTE(review): squeeze(-1) also drops a trailing size-1 axis that was not
    # the reduced one, e.g. a (B, 1) result collapses to (B,). Callers that
    # index that axis afterwards should pass keep_dim=True instead.
    return sim if keep_dim else sim.squeeze(-1)
def func_attention(query, context, g_sim, opt, eps=1e-8):
    """Attend from every query fragment over the context fragments.

    Steps: clipped-L2 attention logits -> temperature-scaled softmax over the
    context axis -> ``correct_equal`` (rescaling with ``g_sim``) ->
    ``focal_equal`` (drop context fragments far below the best match) ->
    attention-weighted sum of ``context``.

    Args:
        query: (batch, queryL, d) tensor.
        context: (batch, sourceL, d) tensor.
        g_sim: global similarity forwarded to ``correct_equal``; in the
            ``BaseModel.xattn_score`` call path it is a (batch,) vector.
        opt: config mapping. Reads ``opt['cross_attention']['threshold_p']``,
            ``opt['cross_attention']['threshold_q']`` and, if present,
            ``opt['cross_attention']['lambda_softmax']`` (softmax inverse
            temperature, default 20 - the value previously hard-coded).
        eps: unused; kept for interface compatibility.

    Returns:
        (weightedContext, re_attn): the (batch, queryL, d) attended context and
        the (batch, queryL, sourceL) attention weights.
    """
    cross_opt = opt['cross_attention']
    try:
        lambda_softmax = cross_opt['lambda_softmax']
    except KeyError:
        lambda_softmax = 20
    sourceL = context.size(1)

    # (batch, sourceL, d) x (batch, d, queryL) -> (batch, sourceL, queryL)
    attn = torch.bmm(context, query.transpose(1, 2))
    # clipped l2norm: LeakyReLU, then normalise each context row across the queries (dim 2)
    attn = nn.functional.leaky_relu(attn, negative_slope=0.1)
    attn = l2norm(attn, 2)
    # -> (batch, queryL, sourceL); softmax over the context fragments of each query
    attn = torch.softmax(attn.transpose(1, 2) * lambda_softmax, dim=-1)

    # use g_sim to correct the local attention
    re_attn = correct_equal(attn, query, context, sourceL, g_sim, cross_opt['threshold_p'])
    # zero out fragments that are far from the best-matching one
    re_attn = focal_equal(re_attn, query, context, sourceL, cross_opt['threshold_q'])

    # (batch, queryL, sourceL) x (batch, sourceL, d) -> (batch, queryL, d)
    weightedContext = torch.bmm(re_attn, context)
    return weightedContext, re_attn
def correct_equal(attn, query, context, sourceL, g_sim, threshold_p):
    """Rescale attention by the global-similarity margin, then renormalise.

    The factor ``g_sim - threshold_p`` multiplies ``attn``, and the product is
    divided by its sum over the last axis. Exact zeros in the factor are
    replaced by 1e-8 so that the division stays defined.

    NOTE(review): for a 1-D ``g_sim`` (one value per batch entry, broadcast as
    (batch, 1, 1)) the factor is constant along the last axis. It therefore
    cancels in the renormalisation, and the result equals
    ``attn / attn.sum(-1, keepdim=True)``, which is ``attn`` itself for softmax
    weights. Confirm this is the intended GCU behaviour.

    Args:
        attn: (batch, queryL, sourceL) attention weights.
        query, context, sourceL: unused.
        g_sim: global similarity values.
        threshold_p: margin subtracted from ``g_sim``.

    Returns:
        Tensor shaped like ``attn``.
    """
    margin = g_sim - threshold_p
    margin = torch.where(margin == 0, margin.new_full(margin.shape, 1e-8), margin)
    weighted = attn * margin[:, None, None]
    return weighted / weighted.sum(dim=-1, keepdim=True)
def focal_equal(attn, query, context, sourceL, threshold_q):
    """Keep only the fragments whose attention is close to the row maximum.

    An entry survives when ``max(row) - attn < threshold_q``. The surviving
    weights are renormalised to sum to 1 along the last axis.

    Args:
        attn: (batch, queryL, sourceL) attention weights.
        query, context, sourceL: unused.
        threshold_q: largest allowed gap to the row maximum.

    Returns:
        Tensor shaped like ``attn``. If the kept weights of a row sum to zero
        (e.g. ``threshold_q <= 0`` or an all-zero row), that row becomes NaN.
    """
    # TODO(original author): the row maximum is used as a hard-coded reference.
    row_max = attn.max(dim=-1, keepdim=True).values
    keep = (row_max - attn < threshold_q).to(attn.dtype)
    kept = attn * keep
    return kept / kept.sum(dim=-1, keepdim=True)
class MGAM(nn.Module):
    """Fuse a lower and a higher backbone feature map and gate a global feature.

    Both maps are projected to ``channel_size`` channels, concatenated,
    refined by a GAT, and then used twice:

    * flattened into per-location region embeddings;
    * average-pooled into a sigmoid gate applied to ``solo_feature``.
    """

    def __init__(self, opt=None):
        # ``opt`` is accepted for interface compatibility but not used.
        super(MGAM, self).__init__()
        # channels produced by each projection; the concatenation has twice this
        channel_size = 256
        self.embed_dim = 2 * channel_size
        # 240-channel lower map: 2x2 / stride-2 conv halves the spatial size
        self.LF_conv = nn.Conv2d(in_channels=240, out_channels=channel_size, kernel_size=2, stride=2)
        # 384-channel higher map: 1x1 projection, spatial size unchanged
        self.HF_conv = nn.Conv2d(in_channels=384, out_channels=channel_size, kernel_size=1, stride=1)
        # visual attention over the concatenated map
        self.concat_attention = GAT(GATopt(self.embed_dim, 1))

    def forward(self, lower_feature, higher_feature, solo_feature):
        """Fuse the two feature maps and gate ``solo_feature``.

        Args:
            lower_feature: 240-channel map. After the stride-2 conv its spatial
                size must equal that of ``higher_feature`` for the concatenation.
            higher_feature: 384-channel map.
            solo_feature: global feature; assumed to broadcast against
                (b, 512) - TODO confirm.

        Returns:
            (visual_embs, gated): region embeddings of shape (b, H*W, 512) and
            the L2-normalised gated global feature.
        """
        lower_feature = self.LF_conv(lower_feature)
        higher_feature = self.HF_conv(higher_feature)
        concat_feature = torch.cat([lower_feature, higher_feature], dim=1)
        # residual: channel-mean of the projected higher map, broadcast to all channels
        concat_feature = higher_feature.mean(dim=1, keepdim=True).expand_as(concat_feature) + concat_feature
        # assumed to keep the (b, 512, H, W) layout (original comments: b x 512 x 7 x 7)
        attent_feature = self.concat_attention(concat_feature)
        # (b, 512, H, W) -> (b, H*W, 512): one embedding per spatial location
        visual_embs = attent_feature.view(attent_feature.size(0), self.embed_dim, -1).transpose(1, 2)
        # global average pool -> (b, 512), used as a sigmoid gate on the solo feature
        pooled = nn.functional.adaptive_avg_pool2d(attent_feature, 1).squeeze(-1).squeeze(-1)
        final_feature = solo_feature * torch.sigmoid(pooled)
        return visual_embs, l2norm(final_feature, -1)
class EncoderText(nn.Module):
    """Caption encoder: GloVe-initialised, frozen word embeddings followed by a (bi-)GRU."""

    # Default locations of the pre-computed GloVe matrix and its vocabulary.
    GLOVE_EMB_PATH = '/home/wzm/embs_npa.npy'
    GLOVE_VOCAB_PATH = '/home/wzm/GAC/vocab/glove_precomp_vocab.json'

    def __init__(self, word2idx, vocab_size, word_dim=300, embed_size=512, num_layers=1,
                 use_bi_gru=True, no_txtnorm=False, glove_emb_path=None, glove_vocab_path=None):
        """
        Args:
            word2idx: mapping word -> row index of the embedding table.
            vocab_size: number of embedding rows.
            word_dim: embedding width (must match the GloVe vectors).
            embed_size: GRU hidden size, which is also the output feature size.
            num_layers: GRU depth.
            use_bi_gru: bidirectional GRU; the two directions are averaged.
            no_txtnorm: skip the final L2 normalisation when True.
            glove_emb_path / glove_vocab_path: override the default GloVe files.
        """
        super(EncoderText, self).__init__()
        self.embed_size = embed_size
        self.no_txtnorm = no_txtnorm
        self.embed = nn.Embedding(vocab_size, word_dim)
        self.use_bi_gru = use_bi_gru
        self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru)
        self.init_weights(word2idx, glove_emb_path, glove_vocab_path)

    def init_weights(self, word2idx=None, glove_emb_path=None, glove_vocab_path=None):
        """Initialise the embedding table.

        With ``word2idx=None`` the table is filled from U(-0.1, 0.1) and stays
        trainable. This is the behaviour of the former no-argument overload,
        which was shadowed and unreachable.

        Otherwise, every word found in the GloVe vocabulary gets its GloVe
        vector, a hit/miss summary is printed, and the table is frozen. Words
        not found keep their default ``nn.Embedding`` initialisation.
        """
        if word2idx is None:
            self.embed.weight.data.uniform_(-0.1, 0.1)
            return
        if glove_emb_path is None:
            glove_emb_path = self.GLOVE_EMB_PATH
        if glove_vocab_path is None:
            glove_vocab_path = self.GLOVE_VOCAB_PATH
        embeddings_glove = np.load(glove_emb_path)
        glove_w2i = deserialize_vocab(glove_vocab_path).word2idx
        weight = self.embed.weight.data
        missing_words = []
        for word, idx in word2idx.items():
            if word in glove_w2i:
                weight[idx] = torch.as_tensor(embeddings_glove[glove_w2i[word]], dtype=weight.dtype)
            else:
                missing_words.append(word)
        print('Words: {}/{} found in vocabulary; {} words missing'.format(
            len(word2idx) - len(missing_words), len(word2idx), len(missing_words)))
        # freeze the embedding table
        for param in self.embed.parameters():
            param.requires_grad = False

    def forward(self, x, lengths):
        """Encode variable-length captions.

        Args:
            x: (batch, max_len) tensor of word indices.
            lengths: true caption lengths, as a list or a 1-D tensor (any device).

        Returns:
            (cap_emb, cap_len): (batch, T, embed_size) word features, where
            T = max(lengths), and the per-caption lengths.
        """
        x = self.embed(x)
        # pack_padded_sequence requires tensor lengths to live on the CPU
        if torch.is_tensor(lengths):
            lengths = lengths.cpu()
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        out, _ = self.rnn(packed)
        cap_emb, cap_len = pad_packed_sequence(out, batch_first=True)
        if self.use_bi_gru:
            half = cap_emb.size(2) // 2
            # average the forward and backward directions -> embed_size channels
            cap_emb = (cap_emb[:, :, :half] + cap_emb[:, :, half:]) / 2
        # normalisation in the joint embedding space
        if not self.no_txtnorm:
            cap_emb = l2norm(cap_emb, dim=-1)
        return cap_emb, cap_len
class BaseModel(nn.Module):
    """Image-text matching model.

    Images go through a pyramid ViG backbone and MGAM fusion. Captions go
    through a GRU encoder and a GAT. Pairs are scored by bidirectional cross
    attention.
    """

    def __init__(self, vocab, opt=None):
        """
        Args:
            vocab: vocabulary exposing ``word2idx`` and ``__len__``.
            opt: config mapping; ``xattn_score`` reads ``opt['cross_attention']``.
        """
        super(BaseModel, self).__init__()
        # fresh dict per instance instead of a shared mutable default
        self.opt = opt if opt is not None else {}
        # --------------- img feature ------------
        self.extract_feature = pvig_ti_224_gelu()
        # NOTE(review): unused in forward. Removing it would change the
        # state_dict keys, so it is kept for checkpoint compatibility.
        self.HF_conv = nn.Conv2d(in_channels=384, out_channels=512, kernel_size=1, stride=1)
        self.mgam = MGAM(opt=self.opt)
        # --------------- text feature -----------
        self.text_feature = EncoderText(word2idx=vocab.word2idx, vocab_size=len(vocab))
        self.gat_cap = GAT_T(GATopt(512, 1))
        # iteration counter; not updated inside this class
        self.Eiters = 0

    def forward_emb(self, images, captions, lengths=None):
        """Encode images and captions.

        Returns:
            (img_embs, img_solo, cap_embs, cap_lens, cap_solo): region-level
            and global image embeddings, word-level caption embeddings with
            their lengths, and the L2-normalised mean caption embedding.
        """
        lower_feature, higher_feature, solo_feature = self.extract_feature(images)
        img_embs, img_solo = self.mgam(lower_feature, higher_feature, solo_feature)
        cap_embs, cap_lens = self.text_feature(captions, lengths)
        cap_embs = self.gat_cap(cap_embs)
        # NOTE(review): the mean runs over all padded positions, not only the
        # first cap_lens[i] words. Confirm that is intended.
        cap_solo = l2norm(torch.mean(cap_embs, dim=1))
        return img_embs, img_solo, cap_embs, cap_lens, cap_solo

    def forward_sim(self, img_emb, img_mean, cap_emb, cap_len, cap_mean, **kwargs):
        """Score every image-caption pair with cross attention (see ``xattn_score``)."""
        return self.xattn_score(img_emb, img_mean, cap_emb, cap_len, cap_mean)

    def forward(self, images, captions, lengths, ids=None, *args):
        """Embed the batch and return the pairwise similarity matrix."""
        img_emb, img_mean, cap_emb, cap_lens, cap_mean = self.forward_emb(images, captions, lengths)
        return self.forward_sim(img_emb, img_mean, cap_emb, cap_lens, cap_mean)

    def xattn_score(self, images, img_mean, captions, cap_lens, cap_mean):
        """Cross-attention similarity between every caption and every image.

        Args:
            images: (n_image, n_regions, d) region embeddings.
            img_mean: (n_image, d) global image embeddings.
            captions: (n_caption, max_words, d) word embeddings.
            cap_lens: number of valid words of each caption.
            cap_mean: (n_caption, d) global caption embeddings.

        Returns:
            (n_caption, n_image) similarities in training mode,
            (n_image, n_caption) otherwise.
        """
        n_image = images.size(0)
        # (n_caption, n_image) global similarities passed to func_attention
        g_sims = cap_mean.mm(img_mean.t())
        similarities = []
        for i in range(captions.size(0)):
            n_word = int(cap_lens[i])
            g_sim = g_sims[i]
            # (n_image, n_word, d): caption i paired with every image
            cap_i_expand = captions[i, :n_word, :].unsqueeze(0).repeat(n_image, 1, 1)
            # keep_dim=True keeps the word/region axis even when it has size 1.
            # Without it a one-word caption gave (n_image,) and mean(dim=1) raised.
            # t2i: words attend over regions -> (n_image, n_word, d)
            wei_context, _ = func_attention(cap_i_expand, images, g_sim, self.opt)
            t2i_sim = cosine_similarity(cap_i_expand, wei_context, dim=2, keep_dim=True).mean(dim=1)
            # i2t: regions attend over words -> (n_image, n_regions, d)
            wei_context, _ = func_attention(images, cap_i_expand, g_sim, self.opt)
            i2t_sim = cosine_similarity(images, wei_context, dim=2, keep_dim=True).mean(dim=1)
            # (n_image, 1) overall similarity for caption i
            similarities.append(t2i_sim + i2t_sim)
        # (n_image, n_caption)
        similarities = torch.cat(similarities, 1)
        if self.training:
            similarities = similarities.transpose(0, 1)
        return similarities
def factory(opt, vocab, cuda=True, data_parallel=True):
    """Build a ``BaseModel``, optionally wrapped in DataParallel and moved to the GPU.

    Args:
        opt: model config; shallow-copied before being handed to the model.
        vocab: vocabulary passed to ``BaseModel``.
        cuda: move the model to the GPU.
        data_parallel: wrap the model in ``nn.DataParallel`` (requires ``cuda``).

    Raises:
        ValueError: if ``data_parallel`` is True but ``cuda`` is False. This is
            checked before the model is built or moved to any device.
    """
    if data_parallel and not cuda:
        raise ValueError('factory(): data_parallel=True requires cuda=True')
    opt = copy.copy(opt)
    model = BaseModel(vocab, opt)
    if data_parallel:
        model = nn.DataParallel(model)
    if cuda:
        model.cuda()
    return model