#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 19 21:03:58 2021

@author: wuzongze

Helpers for editing a generator's per-layer style codes: load the generator,
sample style codes, shift selected layers/channels by a list of strengths and
render the results into an HTML grid.
"""

import sys
import copy
import os
from time import perf_counter

import click
import imageio
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from PIL import Image

import dnnlib
import legacy
import pickle
from visualizer import HtmlPageVisualizer

from torch_utils import misc
import types
from training.networks import SynthesisNetwork,SynthesisBlock,SynthesisLayer,ToRGBLayer


def change_style_code(codes, layer, channel, step):
    """Add ``step`` to one channel of one layer's style code, in place.

    ``codes[layer]`` is indexed as ``[:, channel]``, so every row of that
    layer is shifted.  The same (mutated) ``codes`` object is returned.
    """
    codes[layer][:, channel] += step
    return codes


def Vis(bname,suffix,out,rownames=None,colnames=None,save_path=None,viz_size=256):
    """Write a grid of images to ``<save_path><bname>_<suffix>.html``.

    Args:
        bname: First part of the output file name.
        suffix: Second part of the output file name.
        out: Array indexed as ``out[row, col, :, :, :]``; its first two axes
            give the number of grid rows and image columns.
        rownames: Text for the leading 'Name' column; defaults to the row
            indices as strings.
        colnames: Header for each image column; defaults to 'Step 01', ...
        save_path: Prefix the file name is appended to by plain string
            concatenation (so it should end with a path separator).
            Defaults to './html/'.
        viz_size: Forwarded to ``HtmlPageVisualizer``.
    """
    if save_path is None:
        save_path='./html/'

    num_images=out.shape[0]
    step=out.shape[1]

    if colnames is None:
        colnames=[f'Step {i:02d}' for i in range(1, step + 1)]
    if rownames is None:
        rownames=[str(i) for i in range(num_images)]

    # One extra column on the left holds the row name.
    visualizer = HtmlPageVisualizer(
        num_rows=num_images, num_cols=step + 1, viz_size=viz_size)
    visualizer.set_headers(['Name'] +colnames)

    for i in range(num_images):
        visualizer.set_cell(i, 0, text=rownames[i])
        for k in range(step):
            visualizer.set_cell(i, 1+k, image=out[i,k,:,:,:])

    html_path=save_path+bname+'_'+suffix+'.html'
    # Make sure the target directory exists before handing the path to the
    # visualizer, so a fresh checkout without './html/' does not fail here.
    out_dir=os.path.dirname(html_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    visualizer.save(html_path)


def LoadModel(network_pkl,device):
    """Load the 'G_ema' generator and rebind its synthesis ``forward`` methods.

    The ``forward`` of the synthesis network, of every block and of every
    conv/toRGB layer is rebound to the implementation imported from
    ``training.networks``; ``W2S`` is attached to the synthesis network and
    each layer gets a ``name`` of the form ``'<kind>_resolution_<res>'``.

    Args:
        network_pkl: Path or URL accepted by ``dnnlib.util.open_url``.
        device: Device the generator is moved to.

    Returns:
        The patched generator with gradients disabled.
    """
    # NOTE(review): load_network_pkl presumably unpickles the file; only
    # load network files from a trusted source.
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

    G.synthesis.forward=types.MethodType(SynthesisNetwork.forward,G.synthesis)
    G.synthesis.W2S=types.MethodType(SynthesisNetwork.W2S,G.synthesis)

    for res in G.synthesis.block_resolutions:
        block = getattr(G.synthesis, f'b{res}')
        block.forward=types.MethodType(SynthesisBlock.forward,block)

        # The resolution-4 block is not patched for conv0.
        if res!=4:
            layer=block.conv0
            layer.forward=types.MethodType(SynthesisLayer.forward,layer)
            layer.name='conv0_resolution_'+str(res)

        layer=block.conv1
        layer.forward=types.MethodType(SynthesisLayer.forward,layer)
        layer.name='conv1_resolution_'+str(res)

        layer=block.torgb
        layer.forward=types.MethodType(ToRGBLayer.forward,layer)
        layer.name='toRGB_resolution_'+str(res)

    return G


def S2List(encoded_styles):
    """Convert a name -> tensor mapping into a list of NumPy arrays.

    The list follows the mapping's iteration order.
    """
    return [encoded_styles[name].cpu().numpy() for name in encoded_styles.keys()]


class Manipulator():
    """Edits per-layer style codes of a generator and renders the results.

    The caller is expected to set ``device`` and ``G`` and then call
    ``SetGParameters`` before using the generation/editing methods (see the
    ``__main__`` section of this file for the sequence).
    """

    def __init__(self,dataset_name='ffhq'):
        """Set default editing parameters.

        ``dataset_name`` is accepted but not used by this constructor.
        """
        self.alpha=[0] #manipulation strength
        self.num_images=10
        self.img_index=0  #which image to start
        self.manipulate_layers=None #which layer to manipulate, list

        self.truncation_psi=0.7
        self.truncation_cutoff=8

        # Module-level helpers exposed as attributes for convenience.
        self.LoadModel=LoadModel
        self.Vis=Vis
        self.S2List=S2List

        # Channel count per style layer: each entry is repeated three times.
        fmaps=[512, 512, 512, 512, 512, 256, 128, 64, 32]
        self.fmaps=np.repeat(fmaps,3)

    def GetSName(self):
        """Return the style-layer names in synthesis order.

        Resolution 4 contributes conv1 and toRGB; every other resolution
        contributes conv0, conv1 and toRGB.
        """
        s_names=[]
        for res in self.G.synthesis.block_resolutions:
            if res!=4:
                s_names.append(f'conv0_resolution_{res}')
            s_names.append(f'conv1_resolution_{res}')
            s_names.append(f'toRGB_resolution_{res}')
        return s_names

    def SL2D(self,tmp_code):
        """Turn a list of NumPy style codes into a name -> tensor dict.

        ``tmp_code[i]`` is paired with ``self.s_names[i]`` and moved to
        ``self.device``.
        """
        return {
            name: torch.from_numpy(tmp_code[i]).to(self.device)
            for i, name in enumerate(self.s_names)
        }

    def GenerateS(self,num_img=100):
        """Sample ``num_img`` latents with a fixed seed and store their style codes.

        The result is stored in ``self.dlatents`` as a list of NumPy arrays
        (one per style layer, via ``S2List``).
        """
        seed=5
        with torch.no_grad():
            z = torch.from_numpy(np.random.RandomState(seed).randn(num_img, self.G.z_dim)).to(self.device)
            ws = self.G.mapping(z=z,c=None,truncation_psi=self.truncation_psi,truncation_cutoff=self.truncation_cutoff)
            encoded_styles=self.G.synthesis.W2S(ws)

        self.dlatents=S2List(encoded_styles)

    def GenerateImg(self,codes):
        """Render one image per (image, step) pair of style codes.

        Args:
            codes: List with one array per style layer, indexed as
                ``codes[layer][image, step]``.

        Returns:
            uint8 array of shape
            ``(num_images, step, img_size, img_size, 3)``.
        """
        num_images,step=codes[0].shape[:2]
        out=np.zeros((num_images,step,self.img_size,self.img_size,3),dtype='uint8')

        for i in range(num_images):
            for k in range(step):
                # Batch of one: add a leading axis to each layer's code.
                tmp_code=[codes[m][i,k][None,:] for m in range(len(self.s_names))]
                encoded_styles=self.SL2D(tmp_code)

                with torch.no_grad():
                    img = self.G.synthesis(None, encoded_styles=encoded_styles,noise_mode='const')
                    # Map from [-1, 1] to [0, 255] and move channels last.
                    img = (img + 1) * (255/2)
                    img = img.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()

                if img.shape[1]==img.shape[0]:
                    out[i,k,:,:,:]=img
                else:
                    # Non-square output: place the image in the horizontal
                    # centre of the (zero-filled) square canvas.
                    width=img.shape[1]
                    left=int((img.shape[0]-width)/2)
                    out[i,k,:,left:left+width,:]=img
        return out

    def ShowImg(self,num_img=10):
        """Render the first ``num_img`` stored style codes without edits."""
        # Insert a singleton "step" axis so the codes match GenerateImg's layout.
        codes=[layer[:num_img,None,:] for layer in self.dlatents]
        return self.GenerateImg(codes)

    def SetGParameters(self):
        """Cache generator-derived settings on the instance.

        Sets ``num_layers``, ``s_names``, ``img_size`` and ``mindexs``.
        """
        self.num_layers=self.G.synthesis.num_ws
        self.s_names=self.GetSName()
        # img_size was previously assigned twice; this is the value that
        # took effect (the last block resolution).
        self.img_size=self.G.synthesis.block_resolutions[-1]
        # Layer indices with every toRGB position (1, 4, 7, ...) left out.
        self.mindexs=[0, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 21,23,24]

    def MSCode(self,dlatent_tmp,boundary_tmp):
        """Shift style codes along per-layer directions for every strength.

        Args:
            dlatent_tmp: List of arrays, one per layer, whose first axis has
                length ``self.num_images``.
            boundary_tmp: List of per-layer directions; only the entries
                selected by ``self.manipulate_layers`` are read.

        Returns:
            List of arrays shaped like the inputs with a new axis of length
            ``len(self.alpha)`` inserted at position 1.

        Raises:
            ValueError: If ``self.manipulate_layers`` is not None, an int or
                a list.
        """
        step=len(self.alpha)
        flat=[tmp.reshape((self.num_images,-1)) for tmp in dlatent_tmp]
        # np.tile copies, so the in-place addition below never touches the
        # caller's arrays.
        tiled=[np.tile(tmp[:,None],(1,step,1)) for tmp in flat] # (10, 7, 512)

        # Strengths broadcast along axis 1 (the step axis).
        l=np.array(self.alpha)
        l=l.reshape(
            [step if axis == 1 else 1 for axis in range(tiled[0].ndim)])

        if self.manipulate_layers is None:
            layers=np.arange(len(boundary_tmp))
        elif isinstance(self.manipulate_layers,(int,np.integer)):
            layers=[self.manipulate_layers]
        elif isinstance(self.manipulate_layers,list):
            layers=self.manipulate_layers
        else:
            raise ValueError('manipulate_layers is wrong')

        for i in layers:
            tiled[i]+=l*boundary_tmp[i]

        codes=[]
        for original,edited in zip(dlatent_tmp,tiled):
            shape=list(original.shape)
            shape.insert(1,step)
            codes.append(edited.reshape(shape))
        return codes

    def EditOne(self,bname,dlatent_tmp=None):
        """Edit along direction number ``bname`` of ``self.boundary``.

        ``self.boundary`` is not assigned anywhere in this file and must be
        set by the caller.  Layers whose boundary has no entry ``bname`` get
        an empty direction.

        Args:
            bname: Index into each layer's boundary.
            dlatent_tmp: Optional style codes; defaults to ``num_images``
                stored codes starting at ``img_index``.

        Returns:
            ``(codes, out)``: the edited codes and the rendered images.
        """
        if dlatent_tmp is None:
            dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents]

        boundary_tmp=[]
        for layer_boundary in self.boundary:
            if len(layer_boundary)<=bname:
                boundary_tmp.append([])
            else:
                boundary_tmp.append(layer_boundary[bname])

        codes=self.MSCode(dlatent_tmp,boundary_tmp)
        out=self.GenerateImg(codes)
        return codes,out

    def EditOneC(self,cindex,dlatent_tmp=None):
        """Edit a single channel of the single layer in ``manipulate_layers``.

        The direction is zero everywhere except at ``cindex``, where it is
        that channel's standard deviation from ``self.code_std`` (set by
        ``GetCodeMS``).

        Args:
            cindex: Channel index inside the selected layer.
            dlatent_tmp: Optional style codes; defaults to ``num_images``
                stored codes starting at ``img_index``.

        Returns:
            ``(codes, out)``: the edited codes and the rendered images.
        """
        if dlatent_tmp is None:
            dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents]

        boundary_tmp=[[] for i in range(len(self.dlatents))]

        #'only manipulate 1 layer and one channel'
        assert len(self.manipulate_layers)==1

        ml=self.manipulate_layers[0]
        num_channels=dlatent_tmp[ml].shape[1] #ada
        direction=np.zeros(num_channels)
        direction[cindex]=self.code_std[ml][cindex]  #1
        boundary_tmp[ml]=direction

        codes=self.MSCode(dlatent_tmp,boundary_tmp)
        out=self.GenerateImg(codes)
        return codes,out

    def GetFindex(self,lindex,cindex,ignore_RGB=False):
        """Return a flat channel index for (layer ``lindex``, channel ``cindex``).

        NOTE(review): the body of this method was corrupted in the source
        (everything between ``np.array(self.mindexs)`` and the tail of the
        following method was lost).  The implementation below is a
        reconstruction and must be checked against the original file.
        """
        if ignore_RGB:
            # Count the non-toRGB layers that come before lindex.
            tmp=np.sum(np.array(self.mindexs)<lindex)
        else:
            tmp=lindex
        findex=np.sum(self.fmaps[:tmp])+cindex
        return findex

    def GetLCIndex(self,findex):
        """Map flat channel indices to ``[layer, channel]`` pairs.

        NOTE(review): the header and first statements of this method were
        lost in the source; only the part from ``lindex=len(tmp)`` onward is
        original.  The beginning is a reconstruction -- confirm against the
        original file.

        Args:
            findex: Iterable of flat indices into the concatenation of all
                layers described by ``self.fmaps``.

        Returns:
            Array with one ``[lindex, cindex]`` row per input index.
        """
        l_p=[]
        cfmaps=np.cumsum(self.fmaps)
        for tmp_index in findex:
            # Positive remainders: one for every layer that ends before
            # tmp_index.
            tmp=tmp_index-cfmaps
            tmp=tmp[tmp>0]
            lindex=len(tmp)
            if lindex==0:
                cindex=tmp_index
            else:
                cindex=tmp[-1]

            # An index that lands exactly on a layer boundary belongs to
            # channel 0 of the next layer.
            if cindex ==self.fmaps[lindex]:
                cindex=0
                lindex+=1
            l_p.append([lindex,cindex])
        l_p=np.array(l_p)
        return l_p

    def GetLCIndex2(self,findex): #input findex without ToRGB
        """Like ``GetLCIndex`` for flat indices that skip the toRGB layers.

        ``self.fmaps`` is temporarily restricted to the non-toRGB layers,
        and the resulting layer numbers are mapped back to full-layer
        indices.
        """
        fmaps_o=copy.copy(self.fmaps)
        mindexs=[0, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 21,23,24]
        self.fmaps=fmaps_o[mindexs]
        try:
            l_p=self.GetLCIndex(findex)
        finally:
            # Always restore the full table, even if the lookup fails.
            self.fmaps=fmaps_o

        l_p[:,0]=np.array(mindexs)[l_p[:,0]]
        return l_p

    def GetCodeMS(self):
        """Store per-layer channel mean and std of ``self.dlatents``.

        Results go to ``self.code_mean`` and ``self.code_std`` (lists with
        one array per layer, reduced over axis 0).
        """
        self.code_mean=[layer.mean(axis=0) for layer in self.dlatents]
        self.code_std=[layer.std(axis=0) for layer in self.dlatents]


#%%
if __name__ == "__main__":
    network_pkl='/cs/labs/danix/wuzongze/Gan_Manipulation/stylegan2/model/stylegan2-ffhq-config-f.pkl'
    device = torch.device('cuda')
    M=Manipulator()
    M.device=device
    G=M.LoadModel(network_pkl,device)
    M.G=G
    M.SetGParameters()
    num_img=100_000
    M.GenerateS(num_img=num_img)
    M.GetCodeMS()
    np.set_printoptions(suppress=True)

    #%%
    M.alpha=[24,16,8,0,-8,-16,-24]
    M.step=len(M.alpha)
    M.img_index=0
    M.num_images=10
    lindex,bname=6,501
    M.manipulate_layers=[lindex]
    codes,out=M.EditOneC(bname) #dlatent_tmp
    tmp=str(M.manipulate_layers)+'_'+str(bname)
    M.Vis(tmp,'c',out)