Face/create_database.py

168 lines
5.5 KiB
Python

import os
import time
import re
import torch
import cv2
import numpy as np
from backbones import iresnet50,iresnet18,iresnet100
def load_image(img_path):
    """Read an image file and convert it into a normalized model input array.

    The bytes are read with ``np.fromfile`` and decoded with ``cv2.imdecode``
    instead of ``cv2.imread`` -- presumably so that non-ASCII paths work,
    which ``cv2.imread`` handles poorly on Windows.

    Args:
        img_path: Path to the image file.

    Returns:
        A float32 array of shape ``(1, 3, H, W)`` with pixel values mapped from
        ``[0, 255]`` to ``[-1, 1]``. Channels keep OpenCV's BGR order. The image
        is not resized.

    Raises:
        ValueError: If the file cannot be decoded as an image.
    """
    raw = np.fromfile(img_path, dtype=np.uint8)
    img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    if img is None:
        # imdecode signals failure by returning None rather than raising.
        raise ValueError(f"could not decode image: {img_path}")
    # HWC -> CHW, then add a leading batch axis of size 1.
    img = img.transpose((2, 0, 1))[np.newaxis, :, :, :]
    img = np.array(img, dtype=np.float32)
    img -= 127.5
    img /= 127.5
    return img
def findEuclideanDistance(source_representation, test_representation):
    """Return the Euclidean (L2) distance between two embeddings.

    Squared differences are summed over every element, so any pair of
    broadcast-compatible shapes is accepted, e.g. ``(512,)`` and ``(1, 512)``.
    """
    diff = source_representation - test_representation
    return np.sqrt(np.sum(np.square(diff)))
def findCosineDistance(source_representation, test_representation):
    """Return the cosine distance ``1 - cos(angle)`` between two embeddings.

    Both inputs are flattened first. Without this, a ``(1, D)`` row vector
    (the shape the model produces here) turns ``transpose(a) @ b`` into a
    ``(D, D)`` outer product instead of a scalar dot product.

    Returns:
        A scalar in ``[0, 2]``: 0 means same direction, 2 means opposite.
        The result is NaN if either input is all zeros.
    """
    a = np.ravel(source_representation)
    b = np.ravel(test_representation)
    dot = np.dot(a, b)
    norm_a = np.sqrt(np.sum(np.multiply(a, a)))
    norm_b = np.sqrt(np.sum(np.multiply(b, b)))
    return 1 - (dot / (norm_a * norm_b))
def l2_normalize(x):
    """Scale ``x`` to unit L2 norm, computed over all of its elements.

    An all-zero input is returned as a zero array instead of being divided by
    zero. Division would produce NaNs, and a NaN embedding never compares
    below any distance threshold, so it would fail to match silently.
    """
    x = np.asarray(x)
    norm = np.sqrt(np.sum(np.multiply(x, x)))
    if norm == 0:
        return np.zeros(x.shape, dtype=np.result_type(x.dtype, np.float32))
    return x / norm
def cosin_metric(x1, x2):
    """Return the cosine similarity of two 1-D vectors (1 means same direction)."""
    denominator = np.linalg.norm(x1) * np.linalg.norm(x2)
    numerator = np.dot(x1, x2)
    return numerator / denominator
def load_npy(path):
    """Load the dictionary stored in a ``.npy`` file written by ``np.save``.

    ``np.save`` wraps a dict in a 0-d object array, so ``.item()`` unwraps it.
    """
    # NOTE: allow_pickle=True unpickles the file, which can execute arbitrary
    # code. Only load database files from a trusted source.
    stored = np.load(path, allow_pickle=True)
    return stored.item()
def create_database(path, model, database_path):
    """Build or extend a ``name -> embedding`` database from a people folder.

    ``path`` holds one sub-directory per person, and the sub-directory name is
    used as the key. Only the first image of each person, in sorted filename
    order, is embedded. Non-directory entries in ``path`` and empty person
    folders are skipped. Entries already stored in ``database_path`` are kept,
    except that a re-embedded name overwrites its old entry.

    Args:
        path: Root directory with one sub-directory per person.
        model: Callable that maps the tensor built from ``load_image`` output
            to an embedding tensor (presumably an ``iresnet`` backbone in eval
            mode).
        database_path: ``.npy`` file the dictionary is loaded from and saved
            back to.
    """
    k_v = load_npy(database_path) if os.path.exists(database_path) else {}
    for name in sorted(os.listdir(path)):
        person_dir = os.path.join(path, name)
        # A stray file next to the person folders would make listdir() raise.
        if not os.path.isdir(person_dir):
            continue
        img_names = sorted(os.listdir(person_dir))
        if not img_names:
            continue
        img = torch.from_numpy(load_image(os.path.join(person_dir, img_names[0])))
        with torch.no_grad():
            pred = model(img).numpy()
        k_v[name] = l2_normalize(pred)
    np.save(database_path, k_v)
def create_database_batch(path, model, database_path, batch_size=256):
    """Build or extend a ``name -> embedding`` database, embedding in batches.

    Behaves like ``create_database``: one image per person (the first in
    sorted filename order) from each sub-directory of ``path``. All images are
    loaded up front and passed through ``model`` ``batch_size`` at a time.

    Args:
        path: Root directory with one sub-directory per person.
        model: Callable that maps an ``(N, 3, 112, 112)`` float tensor to an
            ``(N, D)`` embedding tensor.
        database_path: ``.npy`` file the dictionary is loaded from and saved
            back to. Existing entries are kept unless overwritten.
        batch_size: Number of images per forward pass. Must be >= 1.

    Raises:
        ValueError: If ``batch_size`` < 1 or an image is not 112x112.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    k_v = load_npy(database_path) if os.path.exists(database_path) else {}

    order_name = []
    order_path = []
    for name in sorted(os.listdir(path)):
        person_dir = os.path.join(path, name)
        # A stray file next to the person folders would make listdir() raise.
        if not os.path.isdir(person_dir):
            continue
        img_names = sorted(os.listdir(person_dir))
        if img_names:
            order_name.append(name)
            order_path.append(os.path.join(person_dir, img_names[0]))

    # The buffer assumes 112x112 inputs. load_image does not resize, so every
    # image must already be cropped to that size.
    order_img = np.zeros((len(order_path), 3, 112, 112), dtype=np.float32)
    for index, img_path in enumerate(order_path):
        img = load_image(img_path)
        if img.shape[1:] != order_img.shape[1:]:
            raise ValueError(
                f"{img_path}: expected image shape {order_img.shape[1:]}, "
                f"got {img.shape[1:]}")
        order_img[index] = img[0]
    print(order_img.shape)

    order_img = torch.from_numpy(order_img)
    number = len(order_img)
    with torch.no_grad():
        for start in range(0, number, batch_size):
            stop = start + batch_size  # slicing clamps the last batch
            emb = model(order_img[start:stop])
            for name, em in zip(order_name[start:stop], emb.numpy()):
                k_v[name] = l2_normalize(em)
            print("batch" + str(min(stop, number)))
    np.save(database_path, k_v)
def add_one(img, model, name, database_path):
    """Embed a single image and store it under ``name`` in the database file.

    Args:
        img: Model input, e.g. the ``(1, 3, H, W)`` float32 array returned by
            ``load_image``. A torch tensor is also accepted, so the input
            convention matches ``findOne``.
        model: Callable mapping the input tensor to an embedding tensor.
        name: Key to store the embedding under. Overwrites an existing entry.
        database_path: ``.npy`` file the dictionary is loaded from and saved
            back to.
    """
    # as_tensor accepts both NumPy arrays (zero-copy, like from_numpy) and tensors.
    img = torch.as_tensor(img)
    with torch.no_grad():
        pred = model(img).numpy()
    k_v = load_npy(database_path) if os.path.exists(database_path) else {}
    k_v[name] = l2_normalize(pred)
    np.save(database_path, k_v)
def findmindistance(pred, threshold, k_v):
    """Return the database key whose embedding is nearest to ``pred``.

    The comparison uses Euclidean distance. On ties, the first key in
    iteration order wins.

    Args:
        pred: Query embedding (normalized the same way as the stored ones).
        threshold: The best match is accepted only if its distance is
            strictly below this value.
        k_v: Dictionary mapping names to embeddings.

    Returns:
        The matching name, or the int ``-1`` if ``k_v`` is empty or no entry
        is closer than ``threshold``. Callers compare against ``-1``.
    """
    # Start at +inf rather than a finite sentinel: the old sentinel of 10
    # made an empty database return "" whenever threshold exceeded 10.
    distance = float("inf")
    most_like = ""
    for name, emb in k_v.items():
        tmp = findEuclideanDistance(emb, pred)
        if tmp < distance:
            distance = tmp
            most_like = name
    if distance < threshold:
        return most_like
    return -1
def findOne(img, model, k_v, threshold=1.20):
    """Identify the person in a single image.

    Args:
        img: ``(1, 3, H, W)`` model input, as a torch tensor or a NumPy array
            such as the output of ``load_image``.
        model: Callable mapping the input tensor to an embedding tensor.
        k_v: Dictionary mapping names to L2-normalized embeddings.
        threshold: Maximum Euclidean distance between unit embeddings for a
            match. For unit vectors ``d**2 = 2 - 2*cos``, so the default 1.20
            corresponds to a cosine similarity of about 0.28.

    Returns:
        The matched name, or ``"unknown"`` if nothing is close enough.
    """
    img = torch.as_tensor(img)
    with torch.no_grad():
        pred = model(img).numpy()
    name = findmindistance(l2_normalize(pred), threshold=threshold, k_v=k_v)
    return "unknown" if name == -1 else name
def findAll(imglist, model, k_v, threshold=1.20):
    """Identify the person in each image of a batch with one forward pass.

    Args:
        imglist: ``(N, 3, H, W)`` model input, as a torch tensor or NumPy array.
        model: Callable mapping the batch to an ``(N, D)`` embedding tensor.
        k_v: Dictionary mapping names to L2-normalized embeddings.
        threshold: Maximum Euclidean distance for a match (see ``findOne``).

    Returns:
        A list of ``N`` names, with ``"unknown"`` where nothing matched.
    """
    imglist = torch.as_tensor(imglist)
    with torch.no_grad():
        pred = model(imglist).numpy()
    matches = (findmindistance(l2_normalize(pr), threshold=threshold, k_v=k_v)
               for pr in pred)
    return ["unknown" if name == -1 else name for name in matches]
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description="Build a face-embedding database (.npy) from a folder "
                    "containing one sub-directory per person.")
    parser.add_argument("--weights", default="./model/backbone100.pth",
                        help="iresnet100 state_dict to load")
    parser.add_argument("--images", default=r"D:\Download\out\cfp_database",
                        help="root folder with one sub-directory per person")
    parser.add_argument("--database", default="cfp.npy",
                        help="output .npy database (existing entries are kept)")
    args = parser.parse_args()

    model = iresnet100()
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # weights. Newer torch versions accept weights_only=True for state dicts.
    model.load_state_dict(torch.load(args.weights, map_location="cpu"))
    model.eval()
    create_database_batch(args.images, model, args.database)

    # Usage examples kept from the original demo code:
    #   k_v = load_npy("./Database/student.npy")
    #   img = load_image(r"D:\Download\out\alig_students\student.jpg")
    #   name = findOne(torch.from_numpy(img), model, k_v)
    #   # keep only the leading CJK characters of the stored name
    #   print(re.match(r'[\u4e00-\u9fa5]*', name).group(0))
    #   add_one(img, model, "Arminio_Fraga", "centerface_lfw.npy")