# liyin_code/optimization/run_optimization.py
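"""Text-guided StyleGAN2 latent optimization (StyleCLIP-style).

Starting from an initial latent code (loaded from --latent_path, sampled at
random in "edit" mode, or the mean latent in "free_generation" mode), the
script minimizes a CLIP loss, optionally combined with an L2 latent-distance
loss and an identity loss, and saves the result to --results_dir.

Example (assumes the pretrained checkpoints referenced by the defaults exist):
    python run_optimization.py --description "a person with purple hair"
"""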

import argparse
import math
import os
import torch
import torchvision
from torch import optim
from tqdm import tqdm
from criteria.clip_loss import CLIPLoss
from criteria.id_loss import IDLoss
from mapper.training.train_utils import STYLESPACE_DIMENSIONS
from models.stylegan2.model import Generator
import clip
from utils import ensure_checkpoint_exists
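
# Indices of the per-layer StyleSpace vectors excluding the tRGB layers
# (every third index starting at 1); only these are optimized in
# --work_in_stylespace mode.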
STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in list(range(1, len(STYLESPACE_DIMENSIONS), 3))]
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
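    """Cosine learning-rate schedule with a linear warm-up.

    `t` is the normalized progress in [0, 1]: the rate ramps up linearly over
    the first `rampup` fraction of the run and decays along a half-cosine to
    zero over the final `rampdown` fraction.
    """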
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)
    return initial_lr * lr_ramp

def main(args):
    ensure_checkpoint_exists(args.ckpt)

    # Tokenize the text description for the pretrained CLIP model
    text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
    # print('text_inputs:', text_inputs)
    '''
    Example tokenizations (each prompt is padded with zeros to 77 token ids):

    --description "a person with purple hair"
    tensor([[49406,  320, 2533,  593, 5496, 2225, 49407, 0, ..., 0]],
           device='cuda:0', dtype=torch.int32)

    --description "a person with red hair"
    tensor([[49406,  320, 2533,  593,  736, 2225, 49407, 0, ..., 0]],
           device='cuda:0', dtype=torch.int32)
    '''
    os.makedirs(args.results_dir, exist_ok=True)

    g_ema = Generator(args.stylegan_size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    # Put the generator in evaluation mode
    g_ema.eval()
    # Move the generator onto the GPU (change the CUDA device here if needed)
    g_ema = g_ema.cuda()
    # device = torch.cuda.current_device()
    # print('cuda:', device)

    mean_latent = g_ema.mean_latent(4096)
    # print('mean_latent:', mean_latent.shape)  # [1, 512]
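
    # Choose the starting latent code: a code loaded from --latent_path is
    # treated as a z vector and mapped through the generator to W+; "edit"
    # mode samples a random z and truncates it toward the mean latent;
    # "free_generation" tiles the mean latent across all 18 style layers.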
    if args.latent_path:
        latent_code_init = torch.load(args.latent_path).cuda()
        with torch.no_grad():
            _, latent_code_init, _ = g_ema([latent_code_init], return_latents=True,
                                           truncation=args.truncation, truncation_latent=mean_latent)
    elif args.mode == "edit":
        latent_code_init_not_trunc = torch.randn(1, 512).cuda()
        with torch.no_grad():
            _, latent_code_init, _ = g_ema([latent_code_init_not_trunc], return_latents=True,
                                           truncation=args.truncation, truncation_latent=mean_latent)
    else:
        latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)
        print(latent_code_init)  # repeated 18 times along dim 1 -> torch.Size([1, 18, 512])
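
    # Render the unedited image once; it serves as the reference for the ID
    # loss and for the final side-by-side comparison in edit mode.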
    with torch.no_grad():
        img_orig, _ = g_ema([latent_code_init], input_is_latent=True, randomize_noise=False)
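
    # Set up the variables to optimize: in StyleSpace mode, every per-layer
    # style vector except the tRGB ones; otherwise the whole W+ code as a
    # single tensor.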
    if args.work_in_stylespace:
        with torch.no_grad():
            _, _, latent_code_init = g_ema([latent_code_init], input_is_latent=True, return_latents=True)
        latent = [s.detach().clone() for s in latent_code_init]
        for c, s in enumerate(latent):
            if c in STYLESPACE_INDICES_WITHOUT_TORGB:
                s.requires_grad = True
    else:
        latent = latent_code_init.detach().clone()
        latent.requires_grad = True
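
    # CLIPLoss scores how well the rendered image matches the text prompt;
    # IDLoss compares ir_se50 face-recognition embeddings and is applied
    # only when --id_lambda > 0.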
    clip_loss = CLIPLoss(args)
    id_loss = IDLoss(args)

    if args.work_in_stylespace:
        optimizer = optim.Adam(latent, lr=args.lr)
    else:
        optimizer = optim.Adam([latent], lr=args.lr)
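
    # Main optimization loop: anneal the learning rate, re-synthesize with
    # the current latent, and descend on the combined loss.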
    pbar = tqdm(range(args.step))
    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr, rampup=args.lr_rampup)
        optimizer.param_groups[0]["lr"] = lr

        img_gen, _ = g_ema([latent], input_is_latent=True, randomize_noise=False, input_is_stylespace=args.work_in_stylespace)

        c_loss = clip_loss(img_gen, text_inputs)
        if args.id_lambda > 0:
            i_loss = id_loss(img_gen, img_orig)[0]
        else:
            i_loss = 0
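
        # In edit mode the objective is
        #   loss = c_loss + l2_lambda * ||latent_init - latent||^2 + id_lambda * i_loss,
        # which keeps the result close to the starting code and identity;
        # free generation uses the CLIP loss alone.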
        if args.mode == "edit":
            if args.work_in_stylespace:
                l2_loss = sum([((latent_code_init[c] - latent[c]) ** 2).sum() for c in range(len(latent_code_init))])
            else:
                l2_loss = ((latent_code_init - latent) ** 2).sum()
            loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
        else:
            loss = c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(f"loss: {loss.item():.4f};")
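
        # Periodically dump intermediate renders so progress can be inspected.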
        if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0:
            with torch.no_grad():
                img_gen, _ = g_ema([latent], input_is_latent=True, randomize_noise=False, input_is_stylespace=args.work_in_stylespace)
            torchvision.utils.save_image(img_gen, os.path.join(args.results_dir, f"{str(i).zfill(5)}.jpg"), normalize=True, range=(-1, 1))

    if args.mode == "edit":
        final_result = torch.cat([img_orig, img_gen])
    else:
        final_result = img_gen

    return final_result

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--description", type=str, default="a person with purple hair", help="the text that guides the editing/generation")
    parser.add_argument("--ckpt", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt", help="pretrained StyleGAN2 weights")
    parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution")
    parser.add_argument("--lr_rampup", type=float, default=0.05)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--step", type=int, default=300, help="number of optimization steps")
    parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"], help="choose between editing an image and generating a free one")
    parser.add_argument("--l2_lambda", type=float, default=0.008, help="weight of the latent distance (used for editing only)")
    parser.add_argument("--id_lambda", type=float, default=0.000, help="weight of id loss (used for editing only)")
    parser.add_argument("--latent_path", type=str, default=None, help="starts the optimization from the given latent code if provided. Otherwise, starts from "
                        "the mean latent in a free generation, and from a random one in editing. "
                        "Expects a .pt format")
    parser.add_argument("--truncation", type=float, default=0.7, help="used only for the initial latent vector, and only when a latent code path is "
                        "not provided")
    parser.add_argument('--work_in_stylespace', default=False, action='store_true')
    parser.add_argument("--save_intermediate_image_every", type=int, default=20, help="if > 0 then saves intermediate results during the optimization")
    parser.add_argument("--results_dir", type=str, default="results")
    parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str,
                        help="Path to facial recognition network used in ID loss")

    args = parser.parse_args()

    result_image = main(args)
    torchvision.utils.save_image(result_image.detach().cpu(), os.path.join(args.results_dir, "final_result.jpg"), normalize=True, scale_each=True, range=(-1, 1))