59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torch.autograd import Variable
|
|
|
|
|
|
class ContrastiveLoss(nn.Module):
    """
    Compute contrastive loss (max-margin based).

    The loss is computed from a square score matrix produced by
    ``get_sim(im, s)``; entry ``[i, i]`` is treated as the score of the
    matching (positive) pair and every off-diagonal entry as a negative.

    Args:
        opt: Option object; only stored on the instance as ``self.opt``.
        margin: Value added to every ``negative - positive`` difference
            before clamping at zero.
        max_violation: If True, keep only the largest violation per
            row / column instead of summing over all negatives.
    """

    def __init__(self, opt, margin=0, max_violation=False):
        super(ContrastiveLoss, self).__init__()
        self.opt = opt
        self.margin = margin
        self.max_violation = max_violation

    def max_violation_on(self):
        """Switch to the hardest-negative objective and report it."""
        self.max_violation = True
        print('Use VSE++ objective.')

    def max_violation_off(self):
        """Switch to the sum-over-negatives objective and report it."""
        self.max_violation = False
        print('Use VSE0 objective.')

    def forward(self, im, s):
        """
        Return the scalar hinge loss for a batch of paired embeddings.

        Args:
            im: Image embeddings, passed as first argument to ``get_sim``.
                Presumably shaped (batch, dim) -- confirm against caller.
            s: Caption embeddings, passed as second argument to
                ``get_sim``. Presumably shaped (batch, dim) -- confirm.

        Returns:
            A 0-dim tensor: sum of the caption-retrieval and
            image-retrieval costs.
        """
        # compute image-sentence score matrix
        scores = get_sim(im, s)
        diagonal = scores.diag().view(im.size(0), 1)
        d1 = diagonal.expand_as(scores)      # d1[i, j] == scores[i, i]
        d2 = diagonal.t().expand_as(scores)  # d2[i, j] == scores[j, j]

        # caption retrieval:
        # cost_s[i, j] = max(0, margin + scores[i, j] - scores[i, i])
        cost_s = (self.margin + scores - d1).clamp(min=0)
        # image retrieval:
        # cost_im[i, j] = max(0, margin + scores[i, j] - scores[j, j])
        cost_im = (self.margin + scores - d2).clamp(min=0)

        # clear diagonals: a positive pair must not count as its own
        # negative. Build the mask on the same device as the scores rather
        # than keying on torch.cuda.is_available(), which breaks when the
        # inputs live on the CPU of a CUDA-capable machine.
        mask = torch.eye(scores.size(0), device=scores.device) > .5
        cost_s = cost_s.masked_fill(mask, 0)
        cost_im = cost_im.masked_fill(mask, 0)

        # keep the maximum violating negative for each query
        if self.max_violation:
            cost_s = cost_s.max(1)[0]
            cost_im = cost_im.max(0)[0]

        return cost_s.sum() + cost_im.sum()
|
|
|
|
|
|
def get_sim(images, captions):
    """
    Return the matrix product of ``images`` with ``captions`` transposed.

    Entry ``[i, j]`` of the result is the dot product of row ``i`` of
    ``images`` and row ``j`` of ``captions``.
    """
    return torch.mm(images, captions.t())
|
|
|