import os
import time

import torch
import cv2
import numpy as np

from backbones import iresnet50, iresnet18, iresnet100

# Maximum Euclidean distance (between L2-normalized embeddings) for a query to
# be matched to a gallery identity; anything farther is reported as "unknown".
MATCH_THRESHOLD = 1.20


def load_image(img_path):
    """Read an image file and return it as a float32 array of shape (1, 3, H, W).

    Pixel values are mapped from [0, 255] to [-1, 1] via (x - 127.5) / 127.5.
    cv2.imdecode over np.fromfile is used instead of cv2.imread (presumably so
    that paths with non-ASCII characters can be read). No resize or colour
    conversion is done: channels stay in the order cv2 decodes them (BGR), and
    the image is assumed to already be aligned to the model's input size --
    TODO confirm against how the gallery embeddings were produced.

    Raises:
        ValueError: if the file cannot be decoded as an image.
    """
    img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"could not decode image: {img_path}")
    img = img.transpose((2, 0, 1))  # HWC -> CHW
    img = img[np.newaxis, :, :, :].astype(np.float32)  # add batch axis
    img -= 127.5
    img /= 127.5
    return img


def findEuclideanDistance(source_representation, test_representation):
    """Return the Euclidean distance between two embeddings.

    The inputs are broadcast against each other and the squared differences are
    summed over every element, so the result is always a scalar.
    """
    diff = source_representation - test_representation
    return np.sqrt(np.sum(np.multiply(diff, diff)))


def l2_normalize(x):
    """Scale ``x`` to unit L2 norm, taken over all of its elements.

    Note: an all-zero input yields NaNs (division by zero).
    """
    return x / np.sqrt(np.sum(np.multiply(x, x)))


def load_npy(path):
    """Load a Python object (used here as a name -> embedding dict) from a .npy file.

    SECURITY: allow_pickle=True unpickles the file, which can execute arbitrary
    code if the file is crafted; only load gallery files from trusted sources.
    """
    return np.load(path, allow_pickle=True).item()


def findmindistance(pred, threshold, k_v):
    """Return the key of ``k_v`` whose embedding is nearest to ``pred``.

    Args:
        pred: query embedding.
        threshold: the nearest match is accepted only if its distance is
            strictly below this value.
        k_v: mapping of identity name -> gallery embedding.

    Returns:
        The nearest identity name, or -1 if ``k_v`` is empty or the nearest
        distance is not below ``threshold``. Ties keep the first key seen.
    """
    best_name = ""
    best_distance = float("inf")
    for name, embedding in k_v.items():
        dist = findEuclideanDistance(embedding, pred)
        if dist < best_distance:
            best_distance = dist
            best_name = name
    return best_name if best_distance < threshold else -1


def _identify(embedding, k_v, threshold):
    """Normalize one embedding and map it to a gallery name or "unknown"."""
    name = findmindistance(l2_normalize(embedding), threshold=threshold, k_v=k_v)
    return name if name != -1 else "unknown"


def _model_device(model):
    """Return the device holding the model's first parameter (CUDA if it has none)."""
    param = next(model.parameters(), None)
    return param.device if param is not None else torch.device("cuda")


def findOne(img, model, k_v, threshold=MATCH_THRESHOLD):
    """Identify a single image.

    Args:
        img: numpy array or tensor fed to the model as-is (presumably the
            (1, 3, H, W) output of ``load_image``). It is moved to the model's
            device before inference.
        model: embedding network.
        k_v: mapping of identity name -> gallery embedding.
        threshold: maximum accepted distance.

    Returns:
        The matched identity name, or "unknown".
    """
    with torch.no_grad():
        pred = model(torch.as_tensor(img).to(_model_device(model)))
    # .cpu() is required before .numpy() when the model runs on a GPU.
    return _identify(pred.cpu().numpy(), k_v, threshold)


def findAll(imglist, model, k_v, threshold=MATCH_THRESHOLD, device=None):
    """Identify every image in a batch.

    Args:
        imglist: batch tensor (or array) of images, first axis = batch.
        model: embedding network.
        k_v: mapping of identity name -> gallery embedding.
        threshold: maximum accepted distance.
        device: device to run on; defaults to the model's own device.

    Returns:
        A list with one identity name (or "unknown") per image, in input order.
    """
    if device is None:
        device = _model_device(model)
    with torch.no_grad():
        pred = model(torch.as_tensor(imglist).to(device))
    embeddings = pred.cpu().numpy()
    return [_identify(embedding, k_v, threshold) for embedding in embeddings]


def _collect_dataset(test_path):
    """Walk ``test_path/<identity>/<image>`` and return (labels, image paths).

    Stray non-directory entries at the top level are skipped; listings are
    sorted so the evaluation order is reproducible.
    """
    order_name = []
    order_path = []
    for name in sorted(os.listdir(test_path)):
        class_dir = os.path.join(test_path, name)
        if not os.path.isdir(class_dir):
            continue
        for img in sorted(os.listdir(class_dir)):
            order_name.append(name)
            order_path.append(os.path.join(class_dir, img))
    return order_name, order_path


def _load_images(paths, size=112):
    """Load all images into one float32 array of shape (N, 3, size, size)."""
    images = np.zeros((len(paths), 3, size, size), dtype=np.float32)
    for index, img_path in enumerate(paths):
        images[index] = load_image(img_path)[0]
    return images


def _report(pred_name, order_name, order_path):
    """Print every mispredicted sample, then accuracy / unknown / error rates."""
    number = len(pred_name)
    right = 0
    filed = 0  # predictions of "unknown" (i.e. failed to identify)
    misidentified = 0  # wrong, but not "unknown"
    for truth, pred, path in zip(order_name, pred_name, order_path):
        if pred == "unknown":
            filed += 1
        if pred == truth:
            right += 1
        else:
            print(truth + " " + pred + " " + path)
            if pred != "unknown":
                misidentified += 1
    print("total:" + str(number))
    if number == 0:
        return  # nothing evaluated; avoid division by zero
    print("right:" + str(right) + " rate:" + str(right / number))
    print("filed:" + str(filed) + " rate:" + str(filed / number))
    print("error:" + str(misidentified) + " rate:" + str(misidentified / number))


def main():
    """Evaluate the iresnet100 backbone on ./retinaface_test against a stored gallery."""
    model = iresnet100()
    model.load_state_dict(torch.load("./model/backbone100.pth"))
    model.to(torch.device("cuda"))
    model.eval()

    order_name, order_path = _collect_dataset("./retinaface_test")
    order_img = _load_images(order_path)
    print(order_img.shape)

    k_v = load_npy("retinaface_lfw_myalign.npy")

    start_time = time.time()
    order_img = torch.from_numpy(order_img)
    batch = 256
    number = len(order_img)
    pred_name = []
    for start in range(0, number, batch):
        end = min(start + batch, number)
        pred_name.extend(findAll(order_img[start:end], model, k_v))
        print("batch" + str(end))
    end_time = time.time()
    print("findAll time: " + str(end_time - start_time))

    _report(pred_name, order_name, order_path)


if __name__ == '__main__':
    main()