Graduation_Project/LHL/lib/loss.py

59 lines
1.7 KiB
Python
Raw Permalink Normal View History

2024-06-25 11:50:04 +08:00
import torch
import torch.nn as nn
from torch.autograd import Variable
class ContrastiveLoss(nn.Module):
    """Compute the max-margin (hinge-based) contrastive loss for image/caption pairs.

    The loss is built from the similarity matrix ``scores = get_sim(im, s)``,
    where ``scores[i, j]`` is the similarity of image ``i`` and caption ``j``.
    Matching pairs are assumed to lie on the diagonal (image ``i`` goes with
    caption ``i``), so ``im`` and ``s`` must have the same number of rows.

    Two hinge terms are accumulated:

    * caption retrieval: ``max(0, margin + scores[i, j] - scores[i, i])``
    * image retrieval:   ``max(0, margin + scores[i, j] - scores[j, j])``

    With ``max_violation`` off, every negative contributes (the "VSE0" sum).
    With it on, only the hardest negative per query contributes ("VSE++").

    Args:
        opt: Options object. It is stored as ``self.opt`` but not read by this class.
        margin: Hinge margin added to every negative score.
        max_violation: If True, keep only the hardest negative per query.
    """

    def __init__(self, opt, margin=0, max_violation=False):
        super(ContrastiveLoss, self).__init__()
        self.opt = opt
        self.margin = margin
        self.max_violation = max_violation

    def max_violation_on(self):
        """Switch to the VSE++ objective (hardest negative per query only)."""
        self.max_violation = True
        print('Use VSE++ objective.')

    def max_violation_off(self):
        """Switch to the VSE0 objective (sum over all negatives)."""
        self.max_violation = False
        print('Use VSE0 objective.')

    def forward(self, im, s):
        """Return the summed contrastive loss for a batch of matching pairs.

        Args:
            im: Image embeddings, one row per sample.
            s: Caption embeddings; row ``i`` is the match of ``im[i]``.

        Returns:
            A scalar tensor: the caption-retrieval cost plus the image-retrieval cost.
        """
        # scores[i, j] = similarity(image i, caption j)
        scores = get_sim(im, s)
        diagonal = scores.diag().view(scores.size(0), 1)
        d1 = diagonal.expand_as(scores)      # d1[i, j] = scores[i, i]
        d2 = diagonal.t().expand_as(scores)  # d2[i, j] = scores[j, j]

        # Caption retrieval: for image i, penalize every caption j whose score
        # comes within `margin` of the matching caption's score.
        cost_s = (self.margin + scores - d1).clamp(min=0)
        # Image retrieval: for caption j, penalize every image i whose score
        # comes within `margin` of the matching image's score.
        cost_im = (self.margin + scores - d2).clamp(min=0)

        # Positive pairs must not count as their own negatives. Build the mask
        # on the same device as `scores` so that CPU tensors on a CUDA host,
        # and tensors on a non-default GPU, both work.
        mask = torch.eye(scores.size(0), device=scores.device) > .5
        cost_s = cost_s.masked_fill(mask, 0)
        cost_im = cost_im.masked_fill(mask, 0)

        if self.max_violation:
            # Hardest caption per image (row max) and hardest image per caption (column max).
            cost_s = cost_s.max(1)[0]
            cost_im = cost_im.max(0)[0]

        return cost_s.sum() + cost_im.sum()
def get_sim(images, captions):
    """Return the pairwise dot-product similarity matrix.

    Entry ``[i, j]`` is the inner product of ``images[i]`` and ``captions[j]``.
    Both inputs are 2-D matrices with the same number of columns.
    """
    captions_t = captions.t()
    return torch.mm(images, captions_t)