# Graduation_Project/QN/RecipeRetrieval/models/loss.py
#
# (File-viewer metadata captured with this paste: 13 lines, 488 B, Python.)

import torch
class ContrastiveLoss(torch.nn.Module):
    """Margin-based triplet hinge loss on squared Euclidean distances.

    Despite its name, this module computes a *triplet* loss
    (anchor / positive / negative), not the pairwise contrastive loss of
    Hadsell et al. The class name is kept so existing callers keep working.

    For each row ``i``::

        d_pos_i  = sum_j (anchor[i, j] - positive[i, j]) ** 2
        d_neg_i  = sum_j (anchor[i, j] - negative[i, j]) ** 2
        loss_i   = max(0, d_pos_i - d_neg_i + margin)

    and :meth:`forward` returns the mean of ``loss_i``.

    Args:
        margin: Required gap (in squared distance) between the
            anchor-negative and anchor-positive distances. A triplet only
            stops contributing to the loss once
            ``d_neg - d_pos >= margin``.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean triplet hinge loss for a batch of embeddings.

        Args:
            anchor: Anchor embeddings.
            positive: Embeddings that should be close to ``anchor``.
            negative: Embeddings that should be far from ``anchor``.
                All three must be broadcast-compatible. Squared differences
                are summed over dimension 1, so these are presumably
                ``(batch, embedding_dim)`` tensors — TODO confirm against
                callers.

        Returns:
            A 0-dim tensor: the mean of the per-row hinge terms.
        """
        # Squared L2 distance per row; no square root is taken.
        positive_distance = (anchor - positive).pow(2).sum(1)
        negative_distance = (anchor - negative).pow(2).sum(1)
        # Hinge: only triplets that violate the margin contribute (> 0).
        loss = torch.relu(positive_distance - negative_distance + self.margin)
        return loss.mean()