import argparse
import math
import os
import pickle
import sys

import torch
import torchvision
from torch import nn, optim
from tqdm import tqdm

import clip


class CLIPLoss(torch.nn.Module):
    def __init__(self, opts):
        super(CLIPLoss, self).__init__()
        self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
        # Resize the generated image to CLIP's 224x224 input:
        # upsample by 7, then average-pool with kernel stylegan_size // 32 (1024 // 32 = 32).
        self.upsample = torch.nn.Upsample(scale_factor=7)
        self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)

    def forward(self, image, text):
        image = self.avg_pool(self.upsample(image))
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity


sys.path.append('/home/ly/StyleCLIP-main/models/facial_recognition')
from model_irse import Backbone


class IDLoss(nn.Module):
    def __init__(self, opts):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
        self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        self.facenet.cuda()
        self.opts = opts

    def extract_feats(self, x):
        if x.shape[2] != 256:
            x = self.pool(x)
        x = x[:, :, 35:223, 32:220]  # Crop interesting region
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, y_hat, y):
        n_samples = y.shape[0]
        y_feats = self.extract_feats(y)  # Otherwise use the feature from there
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        count = 0
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            loss += 1 - diff_target
            count += 1

        return loss / count, sim_improvement / count


sys.path.append('/home/ly/StyleCLIP-main/mapper/training')
from train_utils import STYLESPACE_DIMENSIONS
from model_3 import Generator
from model_3 import SynthesisNetwork
from model_3 import SynthesisLayer
sys.path.append('/home/ly/StyleCLIP-main')
from utils import ensure_checkpoint_exists

STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS))
                                    if i not in list(range(1, len(STYLESPACE_DIMENSIONS), 3))]


def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)

    return initial_lr * lr_ramp


def main(args):
    ensure_checkpoint_exists(args.ckpt)
    # Tokenize the text description for the pretrained CLIP model.
    text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
    # print('text_inputs: ', text_inputs)
    # CLIP's tokenizer splits the text according to its BPE rules and vocabulary.
    '''
    --description "a person with purple hair"
    tensor([[49406, 320, 2533, 593, 5496, 2225, 49407, 0, ..., 0]],
           device='cuda:0', dtype=torch.int32)   (zero-padded to length 77)
    --description "a person with red hair"
    tensor([[49406, 320, 2533, 593, 736, 2225, 49407, 0, ..., 0]],
           device='cuda:0', dtype=torch.int32)   (zero-padded to length 77)
    '''
    os.makedirs(args.results_dir, exist_ok=True)

    # Switched to StyleGAN3 input:
    # with open('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl', 'rb') as f:
    #     G = pickle.load(f)['G_ema'].cuda()  # torch.nn.Module
    # z = torch.randn([1, G.z_dim]).cuda()    # latent codes
    # c = None                                # class labels (not used in this example)
    # img = G(z, c)  # NCHW, float32, dynamic range [-1, +1], no truncation
    # g_ema = Generator(512, 0, 512, args.stylegan_size, 3)  # 512, 0, 512, 1024, 3
    # with open('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl', 'rb') as f:

    # stylegan3-r-ffhqu-1024x1024.pkl: poor image quality, avoid
    # stylegan3-t-ffhq-1024x1024.pkl:  average image quality, better loss values
    # stylegan3-r-ffhq-1024x1024.pkl:  a compromise
    # stylegan3-t-ffhqu-1024x1024.pkl: decent images, worse loss
    with open('/home/ly/StyleCLIP-main/pretrained_models/stylegan3-t-ffhq-1024x1024.pkl', 'rb') as f:
        # stylespace_dimensions: [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512,
        #                         512, 512, 512, 256, 256, 256, 128, 128, 128, 64, 64, 64, 32, 32]
        # new_p = pickle.load(f)
        # print(new_p)
        # print("new_p")
        # print(new_p.keys())
        # G_ema.load_state_dict(pickle.load(f)['G_ema'].cuda(), strict=False)  # the model cannot be loaded this way
        g_ema = pickle.load(f)['G_ema'].cuda()  # torch.nn.Module; 300 optimization steps take about 4 minutes on average this way
        z = torch.randn([1, g_ema.z_dim]).cuda()  # latent codes
        c = None  # class labels (not used in this example)

    # g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    # Put the model in evaluation mode.
    g_ema.eval()
    # Move to the selected CUDA device.
    g_ema = g_ema.cuda()
    # device = torch.cuda.current_device()
    # print('cuda:', device)

    # Note: despite the name, this is a random z vector, not the truncation mean.
    mean_latent = torch.randn([1, g_ema.z_dim]).cuda()
    torch.save(mean_latent, '/home/ly/StyleCLIP-main/pretrained_models/latent_code/style3.pt')
    # print('mean_latent: ', mean_latent)

    if args.latent_path:
        latent_code_init = torch.load(args.latent_path).cuda()
    # elif args.mode == "edit":
    #     latent_code_init_not_trunc = torch.randn(1, 512).cuda()
    #     with torch.no_grad():
    #         _, latent_code_init, _ = g_ema([latent_code_init_not_trunc], return_latents=True,
    #                                        truncation=args.truncation, truncation_latent=mean_latent)
    else:
        # latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)  # repeat 18 times along dim 1
        latent_code_init = mean_latent.detach().clone()

    # StyleGAN3 generator signature:
    # def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
    with torch.no_grad():
        print("mean_latent ", mean_latent.shape)
        # img_orig, _ = g_ema([latent_code_init], c, input_is_latent=True, randomize_noise=False)
        img_orig = g_ema(latent_code_init, c)

    if args.work_in_stylespace:
        with torch.no_grad():
            _, _, latent_code_init = g_ema([latent_code_init], input_is_latent=True, return_latents=True)
        latent = [s.detach().clone() for s in latent_code_init]
        # Use a separate loop variable so the class-label variable `c` is not overwritten.
        for idx, s in enumerate(latent):
            if idx in STYLESPACE_INDICES_WITHOUT_TORGB:
                s.requires_grad = True
    else:
        latent = latent_code_init.detach().clone()
        latent.requires_grad = True

    clip_loss = CLIPLoss(args)
    id_loss = IDLoss(args)

    if args.work_in_stylespace:
        optimizer = optim.Adam(latent, lr=args.lr)
    else:
        optimizer = optim.Adam([latent], lr=args.lr)

    pbar = tqdm(range(args.step))

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr

        img_gen = g_ema(latent, c)

        c_loss = clip_loss(img_gen, text_inputs)

        if args.id_lambda > 0:
            # Identity loss
            i_loss = id_loss(img_gen, img_orig)[0]
        else:
            i_loss = 0

        if args.mode == "edit":
            if args.work_in_stylespace:
                l2_loss = sum([((latent_code_init[k] - latent[k]) ** 2).sum() for k in range(len(latent_code_init))])
            else:
                # L2 distance to the initial latent code
                l2_loss = ((latent_code_init - latent) ** 2).sum()
            loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
        else:
            loss = c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(
            (
                f"loss: {loss.item():.4f};"
            )
        )
        if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0:
            with torch.no_grad():
                img_gen = g_ema(latent, c)

            # Make sure the hard-coded intermediate output directory exists.
            os.makedirs("results/stygan3Clip", exist_ok=True)
            torchvision.utils.save_image(img_gen, f"results/stygan3Clip/{str(i).zfill(5)}.jpg",
                                         normalize=True, range=(-1, 1))

    if args.mode == "edit":
        final_result = torch.cat([img_orig, img_gen])
    else:
        final_result = img_gen

    return final_result


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--description", type=str, default="a person with purple hair",
                        help="the text that guides the editing/generation")
    parser.add_argument("--ckpt", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt",
                        help="pretrained StyleGAN2 weights")
    parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution")
    parser.add_argument("--lr_rampup", type=float, default=0.05)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--step", type=int, default=300, help="number of optimization steps")
    parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"],
                        help="choose between editing an image and generating a free one")
    parser.add_argument("--l2_lambda", type=float, default=0.008,
                        help="weight of the latent distance (used for editing only)")
    parser.add_argument("--id_lambda", type=float, default=0.000,
                        help="weight of id loss (used for editing only)")
    parser.add_argument("--latent_path", type=str, default=None,
                        help="starts the optimization from the given latent code if provided. Otherwise, starts from "
                             "the mean latent in a free generation, and from a random one in editing. "
                             "Expects a .pt format")
    parser.add_argument("--truncation", type=float, default=1,
                        help="used only for the initial latent vector, and only when a latent code path is "
                             "not provided")
    parser.add_argument('--work_in_stylespace', default=False, action='store_true')
    parser.add_argument("--save_intermediate_image_every", type=int, default=20,
                        help="if > 0 then saves intermediate results during the optimization")
    parser.add_argument("--results_dir", type=str, default="results")
    parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str,
                        help="Path to facial recognition network used in ID loss")

    args = parser.parse_args()

    result_image = main(args)

    torchvision.utils.save_image(result_image.detach().cpu(),
                                 os.path.join(args.results_dir, "final_result.jpg"),
                                 normalize=True, scale_each=True, range=(-1, 1))
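
# Usage sketch (the script file name below is an assumption; adjust the hard-coded
# /home/ly/StyleCLIP-main paths, the StyleGAN3 .pkl, and the ArcFace weights to your
# own environment before running):
#
#   python stylegan3_clip_optimization.py \
#       --description "a person with purple hair" \
#       --mode edit --step 300 --lr 0.1 --l2_lambda 0.008 \
#       --results_dir results
#
# Intermediate frames are written to results/stygan3Clip/ every
# --save_intermediate_image_every steps; in "edit" mode the final image saved to
# <results_dir>/final_result.jpg contains the original and edited images side by side.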