Graduation_Project/LYZ/loss.py

25 lines
426 B
Python
Raw Normal View History

2024-06-29 14:23:18 +08:00
import torch
import torch.nn as nn
class Loss(nn.Module):
def __init__(self):
super().__init__()
self.loss = nn.BCELoss()
def forward(self, x, y):
return self.loss(x, y)
if __name__ == '__main__':
batch_size = 3
x = torch.rand(batch_size, 1).cuda()
y = torch.rand(batch_size, 1).cuda()
criterion = Loss()
loss = criterion(x, y)
print(loss.shape)
print(loss)