# Source listing metadata: 36 lines, 2.4 KiB, Python.
import argparse
import os
from argparse import Namespace

import torchvision

from optimization.run_optimization import main

# Command-line entry point: parse the optimization options, hand them to
# optimization.run_optimization.main, and save the image tensor it returns
# as <results_dir>/final_result.png.
parser = argparse.ArgumentParser()
parser.add_argument("--description", type=str, default="a person with purple hair",
                    help="the text that guides the editing/generation")
parser.add_argument("--ckpt", type=str, default="./pretrained_models/stylegan2-ffhq-config-f.pt",
                    help="pretrained StyleGAN2 weights")
parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution")
# Learning-rate settings; how lr_rampup is applied is decided inside
# optimization.run_optimization (not visible here).
parser.add_argument("--lr_rampup", type=float, default=0.05)
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--step", type=int, default=300, help="number of optimization steps")
parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"],
                    help="choose between editing an image and generating a free one")
parser.add_argument("--l2_lambda", type=float, default=0.008,
                    help="weight of the latent distance (used for editing only)")
# NOTE(review): this default is an absolute, machine-specific path; on any
# other machine --latent_path will likely need to be passed explicitly.
parser.add_argument("--latent_path", type=str,
                    default="/home/ly/StyleCLIP-main/pretrained_models/latent_code/style3.pt",
                    help="starts the optimization from the given latent code if provided. Otherwise, starts from "
                         "the mean latent in a free generation, and from a random one in editing. "
                         "Expects a .pt format")
parser.add_argument("--truncation", type=float, default=0.7,
                    help="used only for the initial latent vector, and only when a latent code path is "
                         "not provided")
parser.add_argument("--save_intermediate_image_every", type=int, default=20,
                    help="if > 0 then saves intermediate results during the optimization")
parser.add_argument("--results_dir", type=str, default="results",
                    help="output directory; final_result.png is written here (the value is also "
                         "forwarded to run_optimization)")
parser.add_argument('--work_in_stylespace', default=False, action='store_true',
                    help="trains a mapper in S instead of W+")
parser.add_argument('--ir_se50_weights', default='pretrained_models/model_ir_se50.pth', type=str,
                    help="Path to facial recognition network used in ID loss")
parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor')

# parse_args() already returns a Namespace, so it can be passed straight on.
args = parser.parse_args()

result_image = main(args)

# Write the final image into the configured results directory instead of a
# hard-coded "results/" path, creating the directory if needed.
os.makedirs(args.results_dir, exist_ok=True)
# Map generator output from [-1, 1] to [0, 1] by hand. This reproduces the old
# save_image(normalize=True, scale_each=True, range=(-1, 1)) call: clamp to the
# range, then shift and scale. It avoids the `range` keyword, which torchvision
# deprecated in favour of `value_range` and later dropped.
final_image = (result_image.detach().cpu().clamp(-1, 1) + 1) / 2
torchvision.utils.save_image(final_image, os.path.join(args.results_dir, "final_result.png"))
|