13 lines
488 B
Python
13 lines
488 B
Python
import torch
|
|
|
|
class ContrastiveLoss(torch.nn.Module):
    """Margin-based loss over (anchor, positive, negative) triplets.

    Despite the class name, the computation is a triplet-style margin loss:
    for each row it penalises the anchor being closer (in squared Euclidean
    distance) to the negative than to the positive by less than ``margin``.
    """

    def __init__(self, margin):
        """Store the margin used in the hinge term.

        Args:
            margin: Value added to ``positive_distance - negative_distance``
                before the ReLU clamp.
        """
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean hinge loss over all triplets.

        Args:
            anchor: Reference tensor.
            positive: Tensor that should lie close to ``anchor``.
            negative: Tensor that should lie far from ``anchor``.

        Returns:
            Scalar tensor: the mean of
            ``relu(d(anchor, positive) - d(anchor, negative) + margin)``,
            where ``d`` is the squared difference summed over dimension 1.

        Note:
            Inputs need at least two dimensions because the reduction is over
            dim 1 (presumably shaped ``(batch, features)`` -- confirm against
            the caller).
        """
        # Squared Euclidean distances (no sqrt), reduced over dim 1.
        positive_distance = (anchor - positive).pow(2).sum(1)
        negative_distance = (anchor - negative).pow(2).sum(1)
        # Hinge: zero loss once the negative is at least `margin` farther away.
        loss = torch.relu(positive_distance - negative_distance + self.margin)
        return loss.mean()