133 lines
4.0 KiB
Python
133 lines
4.0 KiB
Python
|
import os
|
||
|
import time
|
||
|
|
||
|
import torch
|
||
|
import cv2
|
||
|
import numpy as np
|
||
|
from backbones import iresnet50,iresnet18,iresnet100
|
||
|
|
||
|
def load_image(img_path):
    """Read an image file and turn it into a normalized float32 batch of one.

    The file is read with ``np.fromfile`` + ``cv2.imdecode`` instead of
    ``cv2.imread`` so that paths containing non-ASCII characters also work
    on Windows.

    Args:
        img_path: Path of the image file.

    Returns:
        ``np.ndarray`` of shape ``(1, 3, H, W)`` and dtype float32, with pixel
        values mapped from [0, 255] to [-1, 1] via ``(v - 127.5) / 127.5``.

    Raises:
        ValueError: If the file is empty or cannot be decoded as an image.
    """
    raw = np.fromfile(img_path, dtype=np.uint8)
    if raw.size == 0:
        # An empty buffer makes some OpenCV versions fail an internal assertion.
        raise ValueError("cannot decode image (empty file): " + str(img_path))
    img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    if img is None:
        # imdecode signals failure by returning None; report it clearly instead
        # of failing later with an AttributeError on .transpose().
        raise ValueError("cannot decode image: " + str(img_path))
    # NOTE(review): cv2 decodes to BGR channel order and no resize happens here;
    # the caller appears to assume 112x112 inputs and a gallery built with this
    # same loader — confirm before changing either.
    img = img.transpose((2, 0, 1))  # HWC -> CHW
    img = img[np.newaxis, :, :, :].astype(np.float32)  # add batch axis, copy as float
    img -= 127.5
    img /= 127.5
    return img
|
||
|
|
||
|
def findEuclideanDistance(source_representation, test_representation):
    """Return the Euclidean (L2) distance between two embeddings.

    The squared differences of all elements are summed, so the inputs are
    treated as flat vectors whatever their (broadcast-compatible) shape.
    """
    diff = source_representation - test_representation
    squared_sum = np.sum(np.square(diff))
    return np.sqrt(squared_sum)
|
||
|
|
||
|
def l2_normalize(x):
    """Scale ``x`` to unit L2 norm, treating all of its elements as one vector.

    Nothing guards the division, so an all-zero input yields NaN values.
    """
    norm = np.sqrt(np.sum(np.square(x)))
    return x / norm
|
||
|
|
||
|
def load_npy(path):
    """Load the single object stored in a ``.npy`` file saved with ``np.save``.

    Intended for a dict saved as a 0-d object array (in this script, presumably
    the name -> embedding gallery — confirm against the file's producer).
    Returns the unwrapped object via ``ndarray.item()``.
    """
    # SECURITY: allow_pickle=True unpickles arbitrary objects, which can execute
    # code while loading — only use this on trusted files.
    return np.load(path, allow_pickle=True).item()
|
||
|
|
||
|
def findmindistance(pred, threshold, k_v):
    """Find the stored embedding nearest to ``pred`` (Euclidean distance).

    Args:
        pred: Query embedding.
        threshold: Exclusive upper bound on the distance for a match.
        k_v: Mapping of name -> stored embedding.

    Returns:
        The name whose embedding is nearest to ``pred`` if that distance is
        strictly below ``threshold``; otherwise ``-1`` (also for an empty
        ``k_v``).
    """
    # Start from +inf rather than a magic constant so the search is correct for
    # any threshold (a fixed start of 10 returned "" when threshold > 10 and no
    # entry was closer than 10).
    best_distance = float("inf")
    best_name = None
    for name, embedding in k_v.items():
        distance = findEuclideanDistance(embedding, pred)
        # Strict '<' keeps the first entry on ties; NaN distances never win.
        if distance < best_distance:
            best_distance = distance
            best_name = name
    if best_distance < threshold:
        return best_name
    return -1
|
||
|
|
||
|
def findOne(img, model, k_v, threshold=1.20):
    """Identify the person in a single preprocessed image.

    Args:
        img: Input tensor for ``model`` (presumably a one-image batch such as
            ``torch.from_numpy(load_image(path))`` — TODO confirm).
        model: Embedding network, called as ``model(img)``.
        k_v: Mapping of name -> stored embedding.
        threshold: Maximum Euclidean distance accepted as a match
            (default 1.20, the value previously hard-coded).

    Returns:
        The matched name, or ``"unknown"`` if no stored embedding is close enough.
    """
    with torch.no_grad():
        # .cpu() is a no-op for CPU tensors and also lets CUDA outputs convert.
        pred = model(img).cpu().numpy()
    # l2_normalize treats the whole array as one vector, so this is only correct
    # for a single-image batch; use findAll for batches.
    name = findmindistance(l2_normalize(pred), threshold=threshold, k_v=k_v)
    return "unknown" if name == -1 else name
|
||
|
def findAll(imglist, model, k_v, threshold=1.20):
    """Identify the person in every image of a batch.

    Args:
        imglist: Batched input tensor for ``model``; each row of the model's
            output is treated as one embedding.
        model: Embedding network, called as ``model(imglist)``.
        k_v: Mapping of name -> stored embedding.
        threshold: Maximum Euclidean distance accepted as a match
            (default 1.20, the value previously hard-coded).

    Returns:
        List with one entry per output row: the matched name, or ``"unknown"``.
    """
    with torch.no_grad():
        # .cpu() is a no-op for CPU tensors and also lets CUDA outputs convert.
        preds = model(imglist).cpu().numpy()
    names = []
    for pred in preds:
        # Normalize each embedding on its own before searching the gallery.
        name = findmindistance(l2_normalize(pred), threshold=threshold, k_v=k_v)
        names.append("unknown" if name == -1 else name)
    return names
|
||
|
|
||
|
if __name__ == '__main__':
    # Evaluate the recognizer on a folder-per-identity test set: every image in
    # <test_path>/<name>/ is expected to be recognized as <name>.
    model = iresnet100()
    model.load_state_dict(torch.load("./model/backbone100.pth", map_location="cpu"))
    model.eval()

    # Raw string: "\D", "\o", "\c" are invalid escape sequences in a normal
    # literal (SyntaxWarning since Python 3.12).
    test_path = r"D:\Download\out\cfp_test"

    order_name = []  # ground-truth identity of each sample
    order_path = []  # image path of each sample, same order as order_name
    for name in os.listdir(test_path):
        person_dir = os.path.join(test_path, name)
        if not os.path.isdir(person_dir):
            continue  # ignore stray files next to the identity folders
        for img in os.listdir(person_dir):
            order_name.append(name)
            order_path.append(os.path.join(person_dir, img))

    number = len(order_path)
    if number == 0:
        # Avoid ZeroDivisionError in the rate computations below.
        raise SystemExit("no test images found under " + test_path)

    # NOTE(review): assumes every image is already 112x112 — confirm.
    order_img = np.zeros((number, 3, 112, 112), dtype=np.float32)
    for index, img_path in enumerate(order_path):
        order_img[index] = load_image(img_path)
    print(order_img.shape)

    k_v = load_npy("cfp.npy")

    start_time = time.time()
    order_img = torch.from_numpy(order_img)
    batch = 256
    pred_name = []
    for now in range(0, number, batch):
        # Slicing past the end is safe, so the final partial batch needs no
        # special case.
        pred_name.extend(findAll(order_img[now:now + batch], model, k_v))
        print("batch" + str(now + batch))
    end_time = time.time()
    print("findAll time: " + str(end_time - start_time))

    right = 0  # predicted the true identity
    filed = 0  # rejected as "unknown"
    error = 0  # any mismatch, including "unknown" rejections
    for true_name, predicted, path in zip(order_name, pred_name, order_path):
        if predicted == true_name:
            right += 1
        else:
            error += 1
            print(true_name + " " + predicted + " " + path)
        if predicted == "unknown":
            filed += 1

    print("total:" + str(number))
    print("right:" + str(right) + " rate:" + str(right / number))
    print("filed:" + str(filed) + " rate:" + str(filed / number))
    # error - filed = samples matched to the wrong identity.
    print("error:" + str(error - filed) + " rate:" + str((error - filed) / number))
|