168 lines
5.5 KiB
Python
168 lines
5.5 KiB
Python
import os
|
|
import time
|
|
import re
|
|
import torch
|
|
import cv2
|
|
import numpy as np
|
|
from backbones import iresnet50,iresnet18,iresnet100
|
|
|
|
def load_image(img_path):
    """Read an image file and return it as a normalised float32 batch.

    The file bytes are read with ``np.fromfile`` and decoded with
    ``cv2.imdecode`` (instead of ``cv2.imread``) -- presumably so that paths
    containing non-ASCII characters can be opened on Windows; confirm before
    changing this.

    Args:
        img_path: Path of the image file to read.

    Returns:
        A ``float32`` array of shape ``(1, C, H, W)`` whose values were
        mapped from ``[0, 255]`` to ``[-1, 1]`` via ``(x - 127.5) / 127.5``.
        Channels are kept in the order produced by ``cv2.IMREAD_COLOR``.

    Raises:
        ValueError: If the file is empty or cannot be decoded as an image.
    """
    raw = np.fromfile(img_path, dtype=np.uint8)
    # cv2.imdecode signals failure by returning None rather than raising, so
    # check explicitly instead of crashing later on ``None.transpose``.
    img = cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None
    if img is None:
        raise ValueError("could not decode image: {!r}".format(img_path))

    img = img.transpose((2, 0, 1))  # HWC -> CHW
    img = img[np.newaxis, :, :, :]  # add the leading batch axis
    img = np.array(img, dtype=np.float32)
    img -= 127.5
    img /= 127.5
    return img
|
|
|
|
def findEuclideanDistance(source_representation, test_representation):
    """Return the Euclidean (L2) distance between two representations.

    The difference is computed with ``-``, so the inputs must support
    element-wise subtraction (NumPy broadcasting rules apply).  All squared
    differences are summed into a single scalar before the square root.
    """
    delta = source_representation - test_representation
    squared_sum = np.sum(np.multiply(delta, delta))
    return np.sqrt(squared_sum)
|
|
|
|
def findCosineDistance(source_representation, test_representation):
    """Return the cosine distance ``1 - cos(angle)`` between two vectors.

    Both inputs are flattened first, so 1-D vectors, ``(1, D)`` row vectors
    and ``(D, 1)`` column vectors are all treated as plain vectors of length
    ``D``.  Without the flattening, ``matmul(transpose(a), b)`` on two
    ``(1, D)`` rows produces a ``(D, D)`` outer product instead of a scalar.

    Args:
        source_representation: First vector (array-like).
        test_representation: Second vector with the same number of elements.

    Returns:
        A scalar: 0 for identical directions, 1 for orthogonal vectors and
        2 for opposite directions.  A zero-norm input yields ``nan`` because
        the division is not guarded.
    """
    source = np.asarray(source_representation).ravel()
    test = np.asarray(test_representation).ravel()

    a = np.matmul(source, test)
    b = np.sum(np.multiply(source, source))
    c = np.sum(np.multiply(test, test))
    return 1 - (a / (np.sqrt(b) * np.sqrt(c)))
|
|
|
|
def l2_normalize(x):
    """Scale ``x`` by the inverse of its overall L2 norm.

    The norm is taken over every element of ``x`` (not per row), so the
    result has the same shape as ``x``.  A zero norm is not guarded against.
    """
    norm = np.sqrt(np.sum(np.multiply(x, x)))
    return x / norm
|
|
|
|
def cosin_metric(x1, x2):
    """Return the cosine similarity ``dot(x1, x2) / (|x1| * |x2|)``.

    Uses ``np.dot`` and ``np.linalg.norm`` directly; the inputs must have
    shapes that ``np.dot`` accepts.  Zero-norm inputs are not guarded.
    """
    numerator = np.dot(x1, x2)
    denominator = np.linalg.norm(x1) * np.linalg.norm(x2)
    return numerator / denominator
|
|
|
|
def load_npy(path):
    """Load a ``.npy`` file holding a single pickled object and return it.

    The file is opened with ``allow_pickle=True`` and the resulting 0-d
    array is unwrapped with ``.item()``.  Unpickling executes arbitrary
    code, so only load files from a trusted source.
    """
    return np.load(path, allow_pickle=True).item()
|
|
|
|
def create_database(path, model, database_path):
    """Build or extend the embedding database one image at a time.

    Every entry of ``path`` is treated as a directory named after a person.
    Only the first file listed in that directory is embedded; the
    L2-normalised model output is stored under the directory name.  Entries
    already present in ``database_path`` are kept unless overwritten.

    Args:
        path: Directory containing one sub-directory per identity.
        model: Callable mapping an image tensor to an embedding tensor.
        database_path: ``.npy`` file holding the pickled ``name -> embedding``
            dict; created if it does not exist.
    """
    identities = os.listdir(path)

    embeddings = {}
    if os.path.exists(database_path):
        # The database is a pickled dict wrapped in a 0-d object array.
        embeddings = np.load(database_path, allow_pickle=True).item()

    for identity in identities:
        identity_dir = os.path.join(path, identity)
        # The slice keeps at most one image per identity.
        for file_name in os.listdir(identity_dir)[:1]:
            image = load_image(os.path.join(identity_dir, file_name))
            tensor = torch.from_numpy(image)
            with torch.no_grad():
                feature = model(tensor).numpy()
            embeddings[identity] = l2_normalize(feature)

    np.save(database_path, embeddings)
|
|
|
|
def create_database_batch(path, model, database_path, batch=256):
    """Build or extend the embedding database using batched inference.

    Every entry of ``path`` is treated as a directory named after a person.
    The first file listed in each directory is loaded, all images are run
    through ``model`` in chunks of ``batch``, and each L2-normalised
    embedding is stored under the directory name.  Entries already present
    in ``database_path`` are kept unless overwritten.

    Args:
        path: Directory containing one sub-directory per identity.
        model: Callable mapping an ``(N, 3, 112, 112)`` float tensor to a
            tensor whose first axis indexes the N inputs.
        database_path: ``.npy`` file holding the pickled ``name -> embedding``
            dict; created if it does not exist.
        batch: Number of images per forward pass (default 256).

    Raises:
        ValueError: If ``batch`` is not positive.
    """
    if batch <= 0:
        raise ValueError("batch must be a positive integer")

    name_list = os.listdir(path)

    k_v = {}
    if os.path.exists(database_path):
        # The database is a pickled dict wrapped in a 0-d object array.
        k_v = np.load(database_path, allow_pickle=True).item()

    order_name = []
    order_path = []
    for name in name_list:
        person_dir = os.path.join(path, name)
        # The slice keeps at most one image per identity.
        for img_name in os.listdir(person_dir)[:1]:
            order_name.append(name)
            order_path.append(os.path.join(person_dir, img_name))

    # All images are held in memory at once; each must decode to 3x112x112,
    # otherwise the assignment below fails.
    order_img = np.zeros((len(order_path), 3, 112, 112), dtype=np.float32)
    for index, img_path in enumerate(order_path):
        order_img[index] = load_image(img_path)
    print(order_img.shape)
    order_img = torch.from_numpy(order_img)

    emb_list = []
    with torch.no_grad():
        for start in range(0, len(order_img), batch):
            stop = start + batch
            # Slicing clamps at the end, so the last short chunk needs no
            # special case.
            emb_list.extend(model(order_img[start:stop]))
            print("batch" + str(stop))

    for name, emb in zip(order_name, emb_list):
        k_v[name] = l2_normalize(emb.numpy())
    np.save(database_path, k_v)
|
|
|
|
def add_one(img, model, name, database_path):
    """Embed one preprocessed image and store it in the database as ``name``.

    Args:
        img: NumPy array accepted by ``torch.from_numpy`` and by ``model``.
        model: Callable mapping the image tensor to an embedding tensor.
        name: Key under which the L2-normalised embedding is stored; an
            existing entry with the same key is overwritten.
        database_path: ``.npy`` file holding the pickled ``name -> embedding``
            dict; created if it does not exist.
    """
    tensor = torch.from_numpy(img)
    with torch.no_grad():
        feature = model(tensor).numpy()

    database = {}
    if os.path.exists(database_path):
        # The database is a pickled dict wrapped in a 0-d object array.
        database = np.load(database_path, allow_pickle=True).item()

    database[name] = l2_normalize(feature)
    np.save(database_path, database)
|
|
|
|
def findmindistance(pred, threshold, k_v):
    """Return the database key closest to ``pred``, or ``-1`` if none match.

    Args:
        pred: Query embedding.
        threshold: Exclusive upper bound on the Euclidean distance for a
            match to be accepted.
        k_v: Mapping of name -> reference embedding.

    Returns:
        The name with the smallest Euclidean distance to ``pred`` when that
        distance is below ``threshold``; otherwise the integer ``-1`` (also
        returned for an empty database).  Ties keep the first name seen.
    """
    # Start from infinity instead of a magic upper bound so that candidates
    # at any distance are considered and an empty database cannot "match".
    best_name = None
    best_distance = float("inf")
    for name, reference in k_v.items():
        current = findEuclideanDistance(reference, pred)
        if current < best_distance:
            best_distance = current
            best_name = name

    if best_name is not None and best_distance < threshold:
        return best_name
    return -1
|
|
|
|
def findOne(img, model, k_v, threshold=1.20):
    """Identify the single face in ``img`` against the database ``k_v``.

    Args:
        img: Input accepted by ``model`` (callers in this file pass a tensor
            built from ``load_image`` output).
        model: Callable returning an embedding tensor for ``img``.
        k_v: Mapping of name -> reference embedding.
        threshold: Maximum Euclidean distance between L2-normalised
            embeddings for a match to be accepted (default 1.20).

    Returns:
        The matching name, or the string ``"unknown"`` when no database
        entry is closer than ``threshold``.
    """
    with torch.no_grad():
        pred = model(img).numpy()

    name = findmindistance(l2_normalize(pred), threshold=threshold, k_v=k_v)
    return name if name != -1 else "unknown"
|
|
def findAll(imglist, model, k_v, threshold=1.20):
    """Identify every face in a batch against the database ``k_v``.

    Args:
        imglist: Batched input accepted by ``model``.
        model: Callable returning one embedding per batch element, stacked
            along the first axis.
        k_v: Mapping of name -> reference embedding.
        threshold: Maximum Euclidean distance between L2-normalised
            embeddings for a match to be accepted (default 1.20).

    Returns:
        A list with one entry per batch element: the matching name, or the
        string ``"unknown"`` when no database entry is close enough.
    """
    with torch.no_grad():
        pred = model(imglist).numpy()

    name_list = []
    for pr in pred:
        # Each row is normalised on its own before the lookup.
        name = findmindistance(l2_normalize(pr), threshold=threshold, k_v=k_v)
        name_list.append(name if name != -1 else "unknown")
    return name_list
|
|
|
|
if __name__ == '__main__':
    # Build the backbone, restore its weights on the CPU and switch it to
    # inference mode.
    model = iresnet100()
    weights = torch.load("./model/backbone100.pth", map_location="cpu")
    model.load_state_dict(weights)
    model.eval()

    # --- Disabled example: identify one image against an existing database.
    # img = load_image(r"D:\Download\out\facedatabase\man.jpg")
    # img = load_image(r"D:\Download\out\facedatabase\man6.jpg")
    # img = load_image(r"D:\Download\out\alig_students\student.jpg")
    # print(img.shape)
    #
    # k_v = load_npy("./Database/student.npy")
    # start_time = time.time()
    # img = torch.from_numpy(img)
    # name = findOne(img, model, k_v)
    # mo = r'[\u4e00-\u9fa5]*'
    # name = re.match(mo, name)
    # print(name.group(0))
    # end_time = time.time()
    # print("findOne time: " + str(end_time - start_time))

    # --- Disabled example: other database builds.
    # create_database_batch(r"D:\Download\out\alig_students", model, "./Database/student.npy")
    # add_one(img, model, "Arminio_Fraga", "centerface_lfw.npy")

    # Active job: embed the first image of every identity in the CFP folder.
    create_database_batch(r"D:\Download\out\cfp_database", model, "cfp.npy")