# Face/accuracy.py

import os
import time
import torch
import cv2
import numpy as np
from backbones import iresnet50, iresnet18, iresnet100
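
# Evaluates face-identification accuracy: every image under test_path is embedded
# with an iresnet backbone and matched against a gallery of known embeddings
# loaded from "cfp.npy" ({person_name: embedding}, see load_npy below). A match is
# accepted only if its Euclidean distance falls below a fixed threshold; queries
# that no gallery entry matches closely enough are labelled "unknown".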


def load_image(img_path):
    """Read an aligned 112x112 face image and return a 1x3x112x112 float32 array scaled to [-1, 1]."""
    # cv2.imdecode + np.fromfile instead of cv2.imread so that non-ASCII paths also work on Windows.
    img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    img = img.transpose((2, 0, 1))      # HWC -> CHW
    img = img[np.newaxis, :, :, :]      # add batch dimension
    img = np.array(img, dtype=np.float32)
    img -= 127.5                        # scale pixel values to [-1, 1]
    img /= 127.5
    return img


def findEuclideanDistance(source_representation, test_representation):
    """Return the Euclidean (L2) distance between two embedding vectors."""
    euclidean_distance = source_representation - test_representation
    euclidean_distance = np.sum(np.multiply(euclidean_distance, euclidean_distance))
    euclidean_distance = np.sqrt(euclidean_distance)
    return euclidean_distance


def l2_normalize(x):
    """Scale an embedding to unit L2 norm."""
    return x / np.sqrt(np.sum(np.multiply(x, x)))


def load_npy(path):
    """Load the gallery dict ({name: embedding}) that was pickled with np.save."""
    data = np.load(path, allow_pickle=True)
    data = data.item()
    return data


def findmindistance(pred, threshold, k_v):
    """Linear scan over the gallery: return the closest name if its distance is below threshold, else -1."""
    distance = 10  # sentinel larger than any expected distance between embeddings
    most_like = ""
    for name in k_v.keys():
        tmp = findEuclideanDistance(k_v[name], pred)
        if distance > tmp:
            distance = tmp
            most_like = name
    if distance < threshold:
        return most_like
    else:
        return -1
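
# Note on the threshold: if both the query and the gallery embeddings are
# L2-normalized (an assumption about how cfp.npy was built; the query side is
# normalized via l2_normalize above), squared Euclidean distance and cosine
# similarity are related by d^2 = 2 - 2*cos. The fixed threshold of 1.20 used
# below therefore corresponds to accepting matches with cosine similarity above
# (2 - 1.20**2) / 2 = 0.28.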


def findOne(img, model, k_v):
    """Identify a single face image: embed it with the model, then match against the gallery."""
    with torch.no_grad():
        start_time = time.time()
        pred = model(img)
        end_time = time.time()
        # print("predOne time: " + str(end_time - start_time))
        pred = pred.numpy()
        name = findmindistance(l2_normalize(pred), threshold=1.20, k_v=k_v)
        if name != -1:
            return name
        else:
            return "unknown"


def findAll(imglist, model, k_v):
    """Identify a batch of face images; return one name (or "unknown") per image."""
    with torch.no_grad():
        name_list = []
        pred = model(imglist)
        pred = pred.numpy()
        for pr in pred:
            name = findmindistance(l2_normalize(pr), threshold=1.20, k_v=k_v)
            if name != -1:
                name_list.append(name)
            else:
                name_list.append("unknown")
        return name_list
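

# --- Illustrative sketch, not part of the original script ---
# The matching functions above consume a gallery dict ({person_name: embedding})
# loaded from "cfp.npy". How that file was produced is not shown here; one way to
# build such a dict with the same model and preprocessing could look like the
# helper below. The function name, the one-folder-per-person layout, and the
# choice of using the first image per person are assumptions for illustration.
def build_gallery(gallery_path, model, out_path="gallery.npy"):
    k_v = {}
    with torch.no_grad():
        for name in os.listdir(gallery_path):
            person_dir = os.path.join(gallery_path, name)
            if not os.path.isdir(person_dir):
                continue
            img_files = os.listdir(person_dir)
            if not img_files:
                continue
            # Embed the first aligned 112x112 image of each person and store it
            # L2-normalized, consistent with how queries are normalized above.
            img = torch.from_numpy(load_image(os.path.join(person_dir, img_files[0])))
            emb = model(img).numpy()
            k_v[name] = l2_normalize(emb)[0]
    np.save(out_path, k_v)  # np.load(..., allow_pickle=True).item() recovers the dict
    return k_v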


if __name__ == '__main__':
    # Load the iresnet100 backbone and its pretrained weights on CPU.
    model = iresnet100()
    model.load_state_dict(torch.load("./model/backbone100.pth", map_location="cpu"))
    model.eval()
    pred_name = []
    order_name = []   # ground-truth name for each test image
    order_path = []   # path of each test image
    unknown = []      # "unknown" placeholder per image, used to count rejections
    test_path = r"D:\Download\out\cfp_test"  # raw string so the backslashes are not treated as escapes
    name_list = os.listdir(test_path)
    for name in name_list:
        img_list = os.listdir(os.path.join(test_path, name))
        for img in img_list:
            order_name.append(name)
            order_path.append(os.path.join(test_path, name, img))
    # Stack all test images into one (N, 3, 112, 112) float32 array.
    order_img = np.zeros((len(order_path), 3, 112, 112), dtype=np.float32)
    for index, img_path in enumerate(order_path):
        order_img[index] = load_image(img_path)
    print(order_img.shape)
    # for name in order_path:
    #     print(name)
    k_v = load_npy("cfp.npy")  # gallery dict: {person_name: embedding}
    start_time = time.time()
    order_img = torch.from_numpy(order_img)
    batch = 256
    now = 0
    number = len(order_img)
    # number = 1400
    for i in range(number):
        unknown.append("unknown")
    # Run identification in batches of `batch` images.
    while now < number:
        if now + batch < number:
            name = findAll(order_img[now:now + batch], model, k_v)
        else:
            name = findAll(order_img[now:number], model, k_v)
        now = now + batch
        for na in name:
            pred_name.append(na)
        print("batch" + str(now))
    end_time = time.time()
    print("findAll time: " + str(end_time - start_time))
    # print(len(pred_name))
    # right: predictions that match the ground-truth name.
    right = 0
    for i, name in enumerate(pred_name):
        if pred_name[i] == order_name[i]:
            right += 1
    # failed: predictions rejected by the threshold, i.e. returned as "unknown".
    failed = 0
    for i, name in enumerate(pred_name):
        if pred_name[i] == unknown[i]:
            failed += 1
    # error: wrong predictions overall; error - failed is the misidentified count.
    error = 0
    for i, name in enumerate(pred_name):
        if pred_name[i] != order_name[i]:
            error += 1
            print(order_name[i] + " " + pred_name[i] + " " + order_path[i])
    print("total:" + str(number))
    print("right:" + str(right) + " rate:" + str(right / number))
    print("failed:" + str(failed) + " rate:" + str(failed / number))
    print("error:" + str(error - failed) + " rate:" + str((error - failed) / number))