import argparse
import math
import os
import pickle
import sys

import torch
import torchvision
from torch import nn, optim
from tqdm import tqdm

import clip


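# CLIP-guided similarity loss: the generated image is resized to CLIP's 224x224
# input (7x upsampling followed by average pooling with kernel
# stylegan_size // 32), and the loss is 1 - logits/100, so lower values mean
# the image matches the text prompt better.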
class CLIPLoss(torch.nn.Module):

    def __init__(self, opts):
        super(CLIPLoss, self).__init__()
        self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
        self.upsample = torch.nn.Upsample(scale_factor=7)
        self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)

    def forward(self, image, text):
        image = self.avg_pool(self.upsample(image))
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity


sys.path.append('/home/ly/StyleCLIP-main/models/facial_recognition')
from model_irse import Backbone


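# Identity-preservation loss: face embeddings are extracted with a pretrained
# IR-SE50 ArcFace backbone, and the loss is 1 minus the dot product between the
# embeddings of the edited image and the original image.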
class IDLoss(nn.Module):
    def __init__(self, opts):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
        self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        self.facenet.cuda()
        self.opts = opts

    def extract_feats(self, x):
        if x.shape[2] != 256:
            x = self.pool(x)
        x = x[:, :, 35:223, 32:220]  # Crop interesting region
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, y_hat, y):
        n_samples = y.shape[0]
        y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        count = 0
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            loss += 1 - diff_target
            count += 1

        return loss / count, sim_improvement / count


sys.path.append('/home/ly/StyleCLIP-main/mapper/training')
from train_utils import STYLESPACE_DIMENSIONS

from model_3 import Generator
from model_3 import SynthesisNetwork
from model_3 import SynthesisLayer

sys.path.append('/home/ly/StyleCLIP-main')
from utils import ensure_checkpoint_exists

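# Every third style index (1, 4, 7, ...) corresponds to a tRGB layer; those are
# excluded from the indices optimized in StyleSpace mode.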
STYLESPACE_INDICES_WITHOUT_TORGB = [
    i for i in range(len(STYLESPACE_DIMENSIONS))
    if i not in list(range(1, len(STYLESPACE_DIMENSIONS), 3))
]


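# Learning-rate schedule: a linear warm-up over the first `rampup` fraction of
# the run combined with a cosine ramp-down over the final `rampdown` fraction.
# Note that --lr_rampup is parsed below but get_lr is always called with its
# default ramp values.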
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)

    return initial_lr * lr_ramp


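# Latent optimization (StyleCLIP adapted to StyleGAN3): starting from an initial
# latent code, minimize the CLIP loss between the generated image and the text
# description, optionally combined with an L2 penalty on the latent displacement
# and an ArcFace identity loss.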
def main(args):
    ensure_checkpoint_exists(args.ckpt)

    # Tokenize the text description for the pretrained CLIP model.
    text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
    # print('text_inputs: ', text_inputs)
    # clip.tokenize applies CLIP's rule-based BPE tokenizer over its fixed
    # vocabulary and pads the result to a 77-token context, e.g.:
    '''
    --description "a person with purple hair"
    tensor([[49406,   320,  2533,   593,  5496,  2225, 49407, 0, 0, ..., 0]],
           device='cuda:0', dtype=torch.int32)

    --description "a person with red hair"
    tensor([[49406,   320,  2533,   593,   736,  2225, 49407, 0, 0, ..., 0]],
           device='cuda:0', dtype=torch.int32)
    '''

    os.makedirs(args.results_dir, exist_ok=True)

    # Changed to the StyleGAN3 way of loading and driving the generator.
    # with open('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl', 'rb') as f:
    #     G = pickle.load(f)['G_ema'].cuda()  # torch.nn.Module
    # z = torch.randn([1, G.z_dim]).cuda()    # latent codes
    # c = None                                # class labels (not used in this example)
    # img = G(z, c)                           # NCHW, float32, dynamic range [-1, +1], no truncation

    # g_ema = Generator(512, 0, 512, args.stylegan_size, 3)  # 512, 0, 512, 1024, 3
    # with open('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl', 'rb') as f:

    # Notes on the FFHQ checkpoints that were tried:
    # stylegan3-r-ffhqu-1024x1024.pkl: poor image quality, avoid.
    # stylegan3-t-ffhq-1024x1024.pkl:  average image quality, better loss values.
    # stylegan3-r-ffhq-1024x1024.pkl:  a compromise between the two.
    # stylegan3-t-ffhqu-1024x1024.pkl: decent images, worse loss values.
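    # The StyleGAN3 checkpoint pickle stores the exponential-moving-average
    # generator under the 'G_ema' key; that copy is the one used for synthesis.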
    with open('/home/ly/StyleCLIP-main/pretrained_models/stylegan3-t-ffhq-1024x1024.pkl', 'rb') as f:
        # STYLESPACE_DIMENSIONS: [512]*15 + [256]*3 + [128]*3 + [64]*3 + [32]*2
        # new_p = pickle.load(f)
        # print(new_p)
        # print("new_p")
        # print(new_p.keys())
        # G_ema.load_state_dict(pickle.load(f)['G_ema'].cuda(), strict=False)  # the model cannot be loaded this way
        g_ema = pickle.load(f)['G_ema'].cuda()  # torch.nn.Module; with this generator, 300 optimization steps take roughly 4 minutes per image
    z = torch.randn([1, g_ema.z_dim]).cuda()  # latent codes
    c = None                                  # class labels (not used in this example)
    # g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)

    # Set the model to evaluation mode and move it to the CUDA device.
    g_ema.eval()
    g_ema = g_ema.cuda()
    # device = torch.cuda.current_device()
    # print('cuda:', device)

    mean_latent = torch.randn([1, g_ema.z_dim]).cuda()
    torch.save(mean_latent, '/home/ly/StyleCLIP-main/pretrained_models/latent_code/style3.pt')
    # print('mean_latent: ', mean_latent)
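    # Note: despite its name, `mean_latent` above is a single randomly sampled
    # z vector (not an averaged latent); it is saved to disk and used as the
    # starting point for the optimization when no --latent_path is given.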

    if args.latent_path:
        latent_code_init = torch.load(args.latent_path).cuda()
    # elif args.mode == "edit":
    #     latent_code_init_not_trunc = torch.randn(1, 512).cuda()
    #     with torch.no_grad():
    #         _, latent_code_init, _ = g_ema([latent_code_init_not_trunc], return_latents=True,
    #                                        truncation=args.truncation, truncation_latent=mean_latent)
    else:
        # latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)  # repeat 18 times along dim 1
        latent_code_init = mean_latent.detach().clone()

    # StyleGAN3 generator signature:
    # def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
    with torch.no_grad():
        print("mean_latent ", mean_latent.shape)
        # img_orig, _ = g_ema([latent_code_init], c, input_is_latent=True, randomize_noise=False)
        img_orig = g_ema(latent_code_init, c)

    if args.work_in_stylespace:
        with torch.no_grad():
            _, _, latent_code_init = g_ema([latent_code_init], input_is_latent=True, return_latents=True)
        latent = [s.detach().clone() for s in latent_code_init]
        # Use a separate loop variable so the class-label variable `c` is not overwritten.
        for idx, s in enumerate(latent):
            if idx in STYLESPACE_INDICES_WITHOUT_TORGB:
                s.requires_grad = True
    else:
        latent = latent_code_init.detach().clone()
        latent.requires_grad = True
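    # At this point `latent` holds the optimization variables: in W/Z space it
    # is a single tensor optimized directly; in StyleSpace it is a list of
    # per-layer style vectors, with gradients enabled only for non-tRGB layers.
    # Caveat (assumption): the StyleSpace branch above still calls g_ema with
    # StyleGAN2-style arguments (input_is_latent / return_latents), which the
    # StyleGAN3 pickle's forward() does not appear to accept, so that path
    # likely needs further adaptation.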

    clip_loss = CLIPLoss(args)
    id_loss = IDLoss(args)

    if args.work_in_stylespace:
        optimizer = optim.Adam(latent, lr=args.lr)
    else:
        optimizer = optim.Adam([latent], lr=args.lr)

    pbar = tqdm(range(args.step))

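    # Optimization loop: anneal the learning rate, synthesize an image from the
    # current latent, evaluate the losses, and take an Adam step.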
    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr

        img_gen = g_ema(latent, c)

        c_loss = clip_loss(img_gen, text_inputs)

        if args.id_lambda > 0:
            # Identity loss between the edited and the original image.
            i_loss = id_loss(img_gen, img_orig)[0]
        else:
            i_loss = 0

        if args.mode == "edit":
            if args.work_in_stylespace:
                l2_loss = sum([((latent_code_init[c] - latent[c]) ** 2).sum() for c in range(len(latent_code_init))])
            else:
                # L2 distance from the initial latent code.
                l2_loss = ((latent_code_init - latent) ** 2).sum()
            loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
        else:
            loss = c_loss

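        # In "edit" mode the objective is
        #   c_loss + l2_lambda * ||latent - latent_code_init||^2 + id_lambda * i_loss;
        # in "free_generation" mode only the CLIP term is used.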
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(f"loss: {loss.item():.4f};")

        if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0:
            with torch.no_grad():
                img_gen = g_ema(latent, c)

            # Make sure the hard-coded intermediate-results directory exists.
            os.makedirs("results/stygan3Clip", exist_ok=True)
            torchvision.utils.save_image(img_gen, f"results/stygan3Clip/{str(i).zfill(5)}.jpg", normalize=True, range=(-1, 1))

    if args.mode == "edit":
        final_result = torch.cat([img_orig, img_gen])
    else:
        final_result = img_gen

    return final_result


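# Illustrative invocation (the checkpoint and weight paths above are hard-coded
# for the author's machine and will likely need adjusting):
#   python <this_script>.py --description "a person with purple hair" --step 300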
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--description", type=str, default="a person with purple hair", help="the text that guides the editing/generation")
    parser.add_argument("--ckpt", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt", help="pretrained StyleGAN2 weights")
    parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution")
    parser.add_argument("--lr_rampup", type=float, default=0.05)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--step", type=int, default=300, help="number of optimization steps")
    parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"], help="choose between editing an image and free generation")
    parser.add_argument("--l2_lambda", type=float, default=0.008, help="weight of the latent distance (used for editing only)")
    parser.add_argument("--id_lambda", type=float, default=0.000, help="weight of the ID loss (used for editing only)")
    parser.add_argument("--latent_path", type=str, default=None, help="starts the optimization from the given latent code if provided. Otherwise, starts from "
                                                                      "the mean latent in a free generation, and from a random one in editing. "
                                                                      "Expects a .pt format")
    parser.add_argument("--truncation", type=float, default=1, help="used only for the initial latent vector, and only when a latent code path is "
                                                                    "not provided")
    parser.add_argument('--work_in_stylespace', default=False, action='store_true')
    parser.add_argument("--save_intermediate_image_every", type=int, default=20, help="if > 0, saves intermediate results during the optimization")
    parser.add_argument("--results_dir", type=str, default="results")
    parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str,
                        help="Path to the facial recognition network used in the ID loss")

    args = parser.parse_args()

    result_image = main(args)

    torchvision.utils.save_image(result_image.detach().cpu(), os.path.join(args.results_dir, "final_result.jpg"), normalize=True, scale_each=True, range=(-1, 1))