# liyin_code/models/stylegan3/run_optimization3.py

import argparse
import math
import os
import pickle
import torchvision
from torch import optim
from tqdm import tqdm
import torch
import clip


class CLIPLoss(torch.nn.Module):
    def __init__(self, opts):
        super(CLIPLoss, self).__init__()
        self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
        self.upsample = torch.nn.Upsample(scale_factor=7)
        self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)

    def forward(self, image, text):
        image = self.avg_pool(self.upsample(image))
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity
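
# Note on CLIPLoss: for a 1024x1024 generator output, Upsample(scale_factor=7) followed by
# AvgPool2d(kernel_size=1024 // 32) resizes the image to the 224x224 input that CLIP's
# ViT-B/32 expects (1024 * 7 / 32 = 224). CLIP scales image-text cosine similarities by a
# learned logit scale (about 100), so 1 - logits / 100 behaves like "1 - cosine similarity".
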
from torch import nn
import sys
sys.path.append('/home/ly/StyleCLIP-main/models/facial_recognition')
from model_irse import Backbone


class IDLoss(nn.Module):
    def __init__(self, opts):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
        self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        self.facenet.cuda()
        self.opts = opts

    def extract_feats(self, x):
        if x.shape[2] != 256:
            x = self.pool(x)
        x = x[:, :, 35:223, 32:220]  # Crop interesting region
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, y_hat, y):
        n_samples = y.shape[0]
        y_feats = self.extract_feats(y)  # Otherwise use the feature from there
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        count = 0
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            loss += 1 - diff_target
            count += 1
        return loss / count, sim_improvement / count
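
# Note on IDLoss: inputs are resized to 256x256, cropped to the central face region
# (rows 35:223, cols 32:220), pooled to the 112x112 resolution ArcFace expects, and embedded
# with the IR-SE50 backbone. Assuming the Backbone returns L2-normalized embeddings (as in the
# pSp facial_recognition code imported above), the dot product is a cosine similarity, so the
# loss is the average of (1 - cosine similarity) between the edited and original faces.
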
sys.path.append('/home/ly/StyleCLIP-main/mapper/training')
from train_utils import STYLESPACE_DIMENSIONS
from model_3 import Generator
from model_3 import SynthesisNetwork
from model_3 import SynthesisLayer
sys.path.append('/home/ly/StyleCLIP-main')
from utils import ensure_checkpoint_exists
STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in list(range(1, len(STYLESPACE_DIMENSIONS), 3))]
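
# Indices 1, 4, 7, ... are skipped above; in the StyleCLIP mapper code these positions
# correspond to the tRGB style inputs, which are left frozen when optimizing in StyleSpace.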


def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)
    return initial_lr * lr_ramp
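
# The schedule above ramps the learning rate up linearly over the first `rampup` fraction of
# the run and decays it with a half-cosine over the last `rampdown` fraction. With the
# defaults and initial_lr=0.1:
#   get_lr(0.000, 0.1) -> 0.00   # start of warm-up
#   get_lr(0.025, 0.1) -> 0.05   # halfway through warm-up
#   get_lr(0.500, 0.1) -> 0.10   # plateau at the full learning rate
#   get_lr(1.000, 0.1) -> 0.00   # fully decayed at the end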


def main(args):
    ensure_checkpoint_exists(args.ckpt)
    # Tokenize the description and move it to the GPU so it can be fed to the pretrained CLIP model
    text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
    # print('text_inputs: ', text_inputs)
    # CLIP tokenizes the prompt with its own rules and vocabulary;
    # the dumps below show the resulting token ids for two example prompts.
    '''
    --description "a person with purple hair"
    tensor([[49406, 320, 2533, 593, 5496, 2225, 49407, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0]], device='cuda:0', dtype=torch.int32)
    --description "a person with red hair"
    tensor([[49406, 320, 2533, 593, 736, 2225, 49407, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0]], device='cuda:0', dtype=torch.int32)
    '''
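    # As the dumps above show, clip.tokenize pads every prompt to CLIP's fixed context length
    # of 77 tokens; 49406 and 49407 are the start-of-text and end-of-text markers.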
    os.makedirs(args.results_dir, exist_ok=True)
    # Changed to StyleGAN3-style generator inputs
    # with open('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl', 'rb') as f:
    #     G = pickle.load(f)['G_ema'].cuda()  # torch.nn.Module
    # z = torch.randn([1, G.z_dim]).cuda()  # latent codes
    # c = None  # class labels (not used in this example)
    # img = G(z, c)  # NCHW, float32, dynamic range [-1, +1], no truncation
    # g_ema = Generator(512, 0, 512, args.stylegan_size, 3)  # 512, 0, 512, 1024, 3
    # with open('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl', 'rb') as f:
    # Checkpoint notes:
    #   stylegan3-r-ffhqu-1024x1024.pkl: generated images look poor -- do not use
    #   stylegan3-t-ffhq-1024x1024.pkl: average image quality, better loss values
    #   stylegan3-r-ffhq-1024x1024.pkl: a compromise between the two
    #   stylegan3-t-ffhqu-1024x1024.pkl: decent images, worse loss values
    # stylespace_dimensions: [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 256, 256, 256, 128, 128, 128, 64, 64, 64, 32, 32]
    with open('/home/ly/StyleCLIP-main/pretrained_models/stylegan3-t-ffhq-1024x1024.pkl', 'rb') as f:
        # new_p = pickle.load(f)
        # print(new_p)
        # print("new_p")
        # print(new_p.keys())
        # G_ema.load_state_dict(pickle.load(f)['G_ema'].cuda(), strict=False)  # the model does not load this way
        g_ema = pickle.load(f)['G_ema'].cuda()  # torch.nn.Module; loaded this way, 300 optimization steps take about 4 minutes per image
    z = torch.randn([1, g_ema.z_dim]).cuda()  # latent codes
    c = None  # class labels (not used in this example)
    # g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    # Put the generator in evaluation mode
    g_ema.eval()
    # Move to the CUDA device (change the device id here if needed)
    g_ema = g_ema.cuda()
    # device = torch.cuda.current_device()
    # print('cuda:', device)
    mean_latent = torch.randn([1, g_ema.z_dim]).cuda()
    torch.save(mean_latent, '/home/ly/StyleCLIP-main/pretrained_models/latent_code/style3.pt')
    # print('mean_latent: ', mean_latent)
    if args.latent_path:
        latent_code_init = torch.load(args.latent_path).cuda()
    # elif args.mode == "edit":
    #     latent_code_init_not_trunc = torch.randn(1, 512).cuda()
    #     with torch.no_grad():
    #         _, latent_code_init, _ = g_ema([latent_code_init_not_trunc], return_latents=True,
    #                                        truncation=args.truncation, truncation_latent=mean_latent)
    else:
        # latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)  # repeat 18 times along dim 1
        latent_code_init = mean_latent.detach().clone()

    # Generator forward signature:
    # def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
    with torch.no_grad():
        print("mean_latent ", mean_latent.shape)
        # img_orig, _ = g_ema([latent_code_init], c, input_is_latent=True, randomize_noise=False)
        img_orig = g_ema(latent_code_init, c)

    if args.work_in_stylespace:
        with torch.no_grad():
            _, _, latent_code_init = g_ema([latent_code_init], input_is_latent=True, return_latents=True)
        latent = [s.detach().clone() for s in latent_code_init]
        # use `idx` here so the class-label variable `c` defined above is not shadowed
        for idx, s in enumerate(latent):
            if idx in STYLESPACE_INDICES_WITHOUT_TORGB:
                s.requires_grad = True
    else:
        latent = latent_code_init.detach().clone()
        latent.requires_grad = True
    clip_loss = CLIPLoss(args)
    id_loss = IDLoss(args)

    if args.work_in_stylespace:
        optimizer = optim.Adam(latent, lr=args.lr)
    else:
        optimizer = optim.Adam([latent], lr=args.lr)

    pbar = tqdm(range(args.step))
    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr
        img_gen = g_ema(latent, c)
        c_loss = clip_loss(img_gen, text_inputs)
        if args.id_lambda > 0:
            # identity loss
            i_loss = id_loss(img_gen, img_orig)[0]
        else:
            i_loss = 0
        if args.mode == "edit":
            if args.work_in_stylespace:
                l2_loss = sum([((latent_code_init[k] - latent[k]) ** 2).sum() for k in range(len(latent_code_init))])
            else:
                # L2 distance from the initial latent code
                l2_loss = ((latent_code_init - latent) ** 2).sum()
            loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
        else:
            loss = c_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pbar.set_description(f"loss: {loss.item():.4f};")
        if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0:
            with torch.no_grad():
                img_gen = g_ema(latent, c)
            # make sure the hard-coded intermediate output directory exists before saving
            os.makedirs("results/stygan3Clip", exist_ok=True)
            torchvision.utils.save_image(img_gen, f"results/stygan3Clip/{str(i).zfill(5)}.jpg", normalize=True, range=(-1, 1))

    if args.mode == "edit":
        final_result = torch.cat([img_orig, img_gen])
    else:
        final_result = img_gen
    return final_result


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--description", type=str, default="a person with purple hair", help="the text that guides the editing/generation")
    parser.add_argument("--ckpt", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt", help="pretrained StyleGAN2 weights")
    parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution")
    parser.add_argument("--lr_rampup", type=float, default=0.05)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--step", type=int, default=300, help="number of optimization steps")
    parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"], help="choose between editing an image and generating a new one")
    parser.add_argument("--l2_lambda", type=float, default=0.008, help="weight of the latent distance (used for editing only)")
    parser.add_argument("--id_lambda", type=float, default=0.000, help="weight of the ID loss (used for editing only)")
    parser.add_argument("--latent_path", type=str, default=None, help="starts the optimization from the given latent code if provided. Otherwise, starts from "
                        "the mean latent in free generation, and from a random one in editing. Expects a .pt file")
    parser.add_argument("--truncation", type=float, default=1, help="used only for the initial latent vector, and only when a latent code path is "
                        "not provided")
    parser.add_argument('--work_in_stylespace', default=False, action='store_true')
    parser.add_argument("--save_intermediate_image_every", type=int, default=20, help="if > 0, saves intermediate results during the optimization")
    parser.add_argument("--results_dir", type=str, default="results")
    parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str,
                        help="path to the facial recognition network used in the ID loss")

    args = parser.parse_args()

    result_image = main(args)
    torchvision.utils.save_image(result_image.detach().cpu(), os.path.join(args.results_dir, "final_result.jpg"), normalize=True, scale_each=True, range=(-1, 1))
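
# Example invocation (a sketch; it assumes the hard-coded pretrained model paths above, e.g.
# /home/ly/StyleCLIP-main/pretrained_models/stylegan3-t-ffhq-1024x1024.pkl, exist locally):
#
#   python run_optimization3.py --description "a person with purple hair" \
#       --step 300 --lr 0.1 --mode edit --l2_lambda 0.008 --results_dir results
#
# Intermediate images are written to results/stygan3Clip/ every --save_intermediate_image_every
# steps, and the final edit is saved as <results_dir>/final_result.jpg.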