"""Build and query a face-embedding database with an iResNet backbone.

A database is a dict mapping an identity name to its L2-normalized embedding,
persisted with ``np.save`` / loaded with ``np.load(..., allow_pickle=True)``.
"""
import os
import time
import re

import torch
import cv2
import numpy as np

from backbones import iresnet50, iresnet18, iresnet100


def load_image(img_path):
    """Read an image file and return it as a normalized NCHW float32 batch of one.

    ``cv2.imdecode`` over ``np.fromfile`` is used instead of ``cv2.imread`` so
    that paths containing non-ASCII characters can be read.

    Args:
        img_path: Path to the image file.

    Returns:
        ``float32`` array of shape ``(1, 3, H, W)`` with pixel values mapped
        from ``[0, 255]`` to ``[-1, 1]``.

    Raises:
        ValueError: If the file cannot be decoded as an image.
    """
    img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        # imdecode reports failure by returning None instead of raising.
        raise ValueError("could not decode image: {}".format(img_path))
    # NOTE(review): OpenCV decodes to BGR order and no BGR->RGB swap is done;
    # confirm this matches how the backbone weights were trained.
    img = img.transpose((2, 0, 1))  # HWC -> CHW
    img = img[np.newaxis, :, :, :].astype(np.float32)  # add batch axis
    img -= 127.5
    img /= 127.5
    return img


def findEuclideanDistance(source_representation, test_representation):
    """Return the Euclidean (L2) distance between two embeddings.

    Squared differences are summed over every element of the broadcast
    difference, so a ``(1, D)`` row vector may be compared with a ``(D,)`` one.
    """
    diff = source_representation - test_representation
    return np.sqrt(np.sum(np.multiply(diff, diff)))


def findCosineDistance(source_representation, test_representation):
    """Return the cosine distance ``1 - cos(angle)`` between two embeddings.

    Both inputs are flattened first, so ``(1, D)`` model outputs work as well
    as plain ``(D,)`` vectors.
    """
    # Without ravel, transpose+matmul on (1, D) inputs yields a (D, D) outer
    # product rather than a scalar.
    a = np.ravel(source_representation)
    b = np.ravel(test_representation)
    return 1 - np.dot(a, b) / (np.sqrt(np.dot(a, a)) * np.sqrt(np.dot(b, b)))


def l2_normalize(x):
    """Return ``x`` scaled to unit L2 norm (norm taken over all elements).

    An all-zero input yields NaNs (division by zero).
    """
    return x / np.sqrt(np.sum(np.multiply(x, x)))


def cosin_metric(x1, x2):
    """Return the cosine similarity between vectors ``x1`` and ``x2``."""
    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))


def load_npy(path):
    """Load a dict previously stored with ``np.save`` at ``path``.

    SECURITY: ``allow_pickle=True`` can execute arbitrary code from a crafted
    file; only load databases from trusted sources.
    """
    data = np.load(path, allow_pickle=True)
    return data.item()


def _load_database(database_path):
    """Return the database dict at ``database_path``, or ``{}`` if it does not exist."""
    # NOTE(review): np.save appends ".npy" when the path lacks it, while this
    # existence check uses the path as given -- pass paths ending in ".npy".
    if os.path.exists(database_path):
        return load_npy(database_path)
    return {}


def _first_image_per_identity(path):
    """Yield ``(name, image_path)`` for the first image of each sub-directory of ``path``.

    Plain files directly under ``path`` and empty sub-directories are skipped.
    Image names are sorted so the chosen image does not depend on
    ``os.listdir`` ordering.
    """
    for name in os.listdir(path):
        person_dir = os.path.join(path, name)
        if not os.path.isdir(person_dir):
            continue
        for img_name in sorted(os.listdir(person_dir))[:1]:
            yield name, os.path.join(person_dir, img_name)


def create_database(path, model, database_path):
    """Embed one image per identity folder under ``path`` and save the database.

    Existing entries in ``database_path`` are kept; identities found under
    ``path`` overwrite entries of the same name. Images are processed one at
    a time.
    """
    k_v = _load_database(database_path)
    for name, img_path in _first_image_per_identity(path):
        img = torch.from_numpy(load_image(img_path))
        with torch.no_grad():
            pred = model(img).numpy()
        k_v[name] = l2_normalize(pred)
    np.save(database_path, k_v)


def create_database_batch(path, model, database_path, batch=256):
    """Batched variant of :func:`create_database`.

    Images are decoded and embedded ``batch`` at a time, so peak memory is
    bounded by one batch rather than the whole dataset. Every image must
    decode to 3x112x112; other sizes fail when copied into the batch buffer.

    Args:
        path: Directory containing one sub-directory per identity.
        model: Callable mapping an ``(N, 3, 112, 112)`` tensor to embeddings.
        database_path: ``.npy`` file to extend (if present) and overwrite.
        batch: Number of images per forward pass (must be >= 1).

    Raises:
        ValueError: If ``batch`` is less than 1.
    """
    if batch < 1:
        raise ValueError("batch must be >= 1")
    k_v = _load_database(database_path)
    entries = list(_first_image_per_identity(path))
    number = len(entries)
    print((number, 3, 112, 112))  # overall input shape, as previously reported
    with torch.no_grad():
        for start in range(0, number, batch):
            chunk = entries[start:start + batch]
            imgs = np.zeros((len(chunk), 3, 112, 112), dtype=np.float32)
            for index, (_, img_path) in enumerate(chunk):
                imgs[index] = load_image(img_path)[0]
            emb = model(torch.from_numpy(imgs))
            for (name, _), em in zip(chunk, emb):
                k_v[name] = l2_normalize(em.numpy())
            print("batch" + str(start + batch))
    np.save(database_path, k_v)


def add_one(img, model, name, database_path):
    """Embed a single preprocessed image (numpy, as from :func:`load_image`) under ``name``.

    The database at ``database_path`` is loaded if present, updated, and saved.
    """
    img = torch.from_numpy(img)
    with torch.no_grad():
        pred = model(img).numpy()
    k_v = _load_database(database_path)
    k_v[name] = l2_normalize(pred)
    np.save(database_path, k_v)


def findmindistance(pred, threshold, k_v):
    """Return the name in ``k_v`` closest to ``pred`` by Euclidean distance.

    Returns ``-1`` when ``k_v`` is empty or the closest distance is not
    strictly below ``threshold``. On ties the first name encountered wins.
    """
    most_like = ""
    distance = float("inf")
    for name, emb in k_v.items():
        tmp = findEuclideanDistance(emb, pred)
        if tmp < distance:
            distance = tmp
            most_like = name
    if distance < threshold:
        return most_like
    return -1


def findOne(img, model, k_v, threshold=1.20):
    """Identify a single image tensor against database ``k_v``.

    Returns the matched name, or ``"unknown"`` if no entry is within ``threshold``.
    """
    with torch.no_grad():
        pred = model(img).numpy()
    name = findmindistance(l2_normalize(pred), threshold=threshold, k_v=k_v)
    if name != -1:
        return name
    return "unknown"


def findAll(imglist, model, k_v, threshold=1.20):
    """Identify every image in the batch tensor ``imglist`` against ``k_v``.

    Returns a list of names (``"unknown"`` for unmatched), in batch order.
    """
    with torch.no_grad():
        pred = model(imglist).numpy()
    name_list = []
    for pr in pred:
        name = findmindistance(l2_normalize(pr), threshold=threshold, k_v=k_v)
        name_list.append(name if name != -1 else "unknown")
    return name_list


if __name__ == '__main__':
    model = iresnet100()
    model.load_state_dict(torch.load("./model/backbone100.pth", map_location="cpu"))
    model.eval()

    # Example: identify one face against an existing database.
    # img = load_image(r"D:\Download\out\facedatabase\man.jpg")
    # img = load_image(r"D:\Download\out\facedatabase\man6.jpg")
    # img = load_image(r"D:\Download\out\alig_students\student.jpg")
    # print(img.shape)
    #
    # k_v = load_npy("./Database/student.npy")
    # start_time = time.time()
    # img = torch.from_numpy(img)
    # name = findOne(img, model, k_v)
    # mo = r'[\u4e00-\u9fa5]*'  # keep only the leading run of CJK characters
    # name = re.match(mo, name)
    # print(name.group(0))
    # end_time = time.time()
    # print("findOne time: " + str(end_time - start_time))

    # create_database_batch(r"D:\Download\out\alig_students", model, "./Database/student.npy")
    create_database_batch(r"D:\Download\out\cfp_database", model, "cfp.npy")
    # add_one(img, model, "Arminio_Fraga", "centerface_lfw.npy")