import torch

import clip
|
class CLIPLoss(torch.nn.Module):
    """CLIP-based similarity loss between a generated image and a text prompt.

    A square generator output of side ``opts.stylegan_size`` is resized to
    CLIP's 224x224 input by upsampling x7 and average-pooling with kernel
    ``stylegan_size // 32``: for power-of-two sizes >= 32,
    ``size * 7 / (size // 32) == 224`` exactly.
    """

    def __init__(self, opts):
        """Load the CLIP model and build the image-resizing pipeline.

        Args:
            opts: options object; must provide ``stylegan_size`` (int side
                length of the generator output). If ``opts.device`` exists it
                is used for CLIP; otherwise "cuda" (the original behavior).
        """
        super().__init__()
        # Honor an explicit device when the options carry one; the default
        # keeps the original hard-coded "cuda" for backward compatibility.
        device = getattr(opts, "device", "cuda")
        self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        # Upsample x7 then average-pool so the spatial size becomes exactly
        # 224x224, the input resolution of the ViT-B/32 visual encoder.
        self.upsample = torch.nn.Upsample(scale_factor=7)
        self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)

    def forward(self, image, text):
        """Return ``1 - similarity`` so that a better text/image match is a lower loss.

        Args:
            image: image batch tensor; assumed shape
                (N, 3, stylegan_size, stylegan_size) — TODO confirm caller
                also applies CLIP's expected normalization.
            text: tokenized text tensor (presumably from ``clip.tokenize``;
                verify against caller).

        Returns:
            Per-pair loss tensor. CLIP returns cosine similarities scaled by
            100 as logits, hence the division back into a ~[0, 1] range.
        """
        image = self.avg_pool(self.upsample(image))
        # model(image, text) -> (logits_per_image, logits_per_text); take the
        # image-to-text logits and undo CLIP's x100 logit scaling.
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity