Face/retinaface_arcface.py

763 lines
30 KiB
Python
Raw Permalink Normal View History

2024-07-29 11:24:25 +08:00
from __future__ import print_function
import os
import argparse
import re
import faiss
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from data import cfg_mnet, cfg_re50
from face_api import create_database_from_img, load_arcface_model, findAll
from layers.functions.prior_box import PriorBox
from utils.nms.py_cpu_nms import py_cpu_nms
import cv2
from models.retinaface import RetinaFace
from utils.box_utils import decode, decode_landm
import time
from face_api import load_arcface_model, load_npy
from skimage import transform as trans
from backbones import iresnet100, iresnet18
#from create_database import findOne, load_npy,findAll
from PIL import Image, ImageDraw,ImageFont
parser = argparse.ArgumentParser(description='Retinaface')
parser.add_argument('-m', '--trained_model', default='./weights/mobilenet0.25_Final.pth',
                    type=str, help='Trained state_dict file path to open')
# Restricting the choices makes a typo fail at parse time; the entry points map
# this name to a RetinaFace config and would otherwise build the net with cfg=None.
parser.add_argument('--network', default='mobile0.25', choices=['mobile0.25', 'resnet50'],
                    help='Backbone network mobile0.25 or resnet50')
parser.add_argument('--cpu', action="store_true", default=False if torch.cuda.is_available() else True,
                    help='Use cpu inference')
parser.add_argument('--confidence_threshold', default=0.02, type=float, help='confidence_threshold')
parser.add_argument('--top_k', default=5000, type=int, help='top_k')
parser.add_argument('--nms_threshold', default=0.4, type=float, help='nms_threshold')
parser.add_argument('--keep_top_k', default=750, type=int, help='keep_top_k')
# A store_true flag whose default is already True can never be switched off,
# so an explicit negative flag writes False into the same destination.
parser.add_argument('-s', '--save_image', action="store_true", default=True,
                    help='align detected faces (enabled by default)')
parser.add_argument('--no_save_image', dest='save_image', action='store_false',
                    help='disable face alignment / saving')
parser.add_argument('--vis_thres', default=0.6, type=float, help='visualization_threshold')
args = parser.parse_args()
def check_keys(model, pretrained_state_dict):
    """Report how a checkpoint's parameter names line up with ``model``.

    Prints how many model parameters are missing from the checkpoint, how many
    checkpoint entries the model does not use, and how many are shared.

    Args:
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        pretrained_state_dict: mapping of parameter name -> value.

    Returns:
        True when at least one key is shared.

    Raises:
        AssertionError: if no checkpoint key matches the model. Raised
            explicitly rather than via ``assert`` so the check survives
            ``python -O``.
    """
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    print('Missing keys:{}'.format(len(missing_keys)))
    print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
    print('Used keys:{}'.format(len(used_pretrained_keys)))
    if not used_pretrained_keys:
        # Weights are later loaded with strict=False, so without this check a
        # completely mismatched checkpoint would be accepted silently.
        raise AssertionError('load NONE from pretrained checkpoint')
    return True
def remove_prefix(state_dict, prefix):
    """Return a copy of ``state_dict`` with a leading ``prefix`` stripped from its keys.

    Old-style checkpoints (e.g. saved from a ``DataParallel`` wrapper) store
    every parameter name with a shared prefix such as ``'module.'``. Keys that
    do not start with ``prefix`` are copied unchanged; insertion order is kept.
    """
    print('remove prefix \'{}\''.format(prefix))
    stripped = {}
    for name, tensor in state_dict.items():
        if name.startswith(prefix):
            # Only the first occurrence is removed, which is the leading one.
            name = name.split(prefix, 1)[-1]
        stripped[name] = tensor
    return stripped
def load_model(model, pretrained_path, load_to_cpu):
    """Load checkpoint weights from ``pretrained_path`` into ``model`` and return it.

    Storages are kept on the CPU when ``load_to_cpu`` is true, otherwise moved
    to the current CUDA device. A top-level ``"state_dict"`` entry is unwrapped
    and any ``'module.'`` key prefix removed before the key check; loading uses
    ``strict=False``, so unmatched keys are tolerated.
    """
    print('Loading pretrained model from {}'.format(pretrained_path))
    if load_to_cpu:
        def location(storage, loc):
            return storage
    else:
        device = torch.cuda.current_device()

        def location(storage, loc):
            return storage.cuda(device)
    checkpoint = torch.load(pretrained_path, map_location=location)
    # Some checkpoints wrap the weights in a {"state_dict": ...} container.
    raw_state = checkpoint['state_dict'] if "state_dict" in checkpoint.keys() else checkpoint
    weights = remove_prefix(raw_state, 'module.')
    check_keys(model, weights)
    model.load_state_dict(weights, strict=False)
    return model
def image_to112x112_retinaface(input_path=r"D:\Download\out\cfp",
                               output_path=r"D:\Download\out\cfp_align"):
    """Align one face per image of a per-identity folder tree into 112x112 crops.

    ``input_path`` holds one sub-folder per identity; the same sub-folder names
    are created under ``output_path`` and each crop is written there with the
    source file name. Per image, among detections scoring at least
    ``args.vis_thres``, the one whose box centre is closest (L1 distance, below
    500) to pixel (125, 125) is aligned to the five-point template below.
    Nothing is written unless ``args.save_image`` is set.

    Args:
        input_path: root folder of per-identity image folders.
        output_path: root folder that receives the aligned crops.
    """
    torch.set_grad_enabled(False)
    cfg = None
    if args.network == "mobile0.25":
        cfg = cfg_mnet
    elif args.network == "resnet50":
        cfg = cfg_re50
    # net and model
    net = RetinaFace(cfg=cfg, phase='test')
    net = load_model(net, args.trained_model, args.cpu)
    net.eval()
    print('Finished loading model!')
    cudnn.benchmark = True
    device = torch.device("cpu" if args.cpu else "cuda")
    net = net.to(device)
    resize = 1
    # Target (x, y) positions of the five landmarks inside the 112x112 crop:
    # two eyes, nose tip, two mouth corners.
    src1 = np.array([
        [38.3814, 51.6963],
        [73.6186, 51.5014],
        [56.1120, 71.7366],
        [41.6361, 92.3655],
        [70.8167, 92.2041]], dtype=np.float32)

    def _detect(frame):
        """Run RetinaFace on a BGR ``frame`` and post-process its raw output.

        Returns ``(dets, landms)``: rows of ``[x1, y1, x2, y2, score]`` and the
        matching 10 landmark coordinates (x, y interleaved), in pixels, after
        the confidence filter, top-k selection and NMS.
        """
        img = np.float32(frame)
        im_height, im_width, _ = img.shape
        scale = torch.Tensor([im_width, im_height, im_width, im_height]).to(device)
        img -= (104, 117, 123)
        img = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
        tic = time.time()
        loc, conf, landms = net(img)  # forward pass
        print('net forward time: {:.4f}'.format(time.time() - tic))
        prior_data = PriorBox(cfg, image_size=(im_height, im_width)).forward().to(device).data
        boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
        boxes = (boxes * scale / resize).cpu().numpy()
        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
        landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
        scale1 = torch.Tensor([im_width, im_height] * 5).to(device)
        landms = (landms * scale1 / resize).cpu().numpy()
        # ignore low scores
        inds = np.where(scores > args.confidence_threshold)[0]
        boxes, landms, scores = boxes[inds], landms[inds], scores[inds]
        # keep top-K before NMS
        order = scores.argsort()[::-1][:args.top_k]
        boxes, landms, scores = boxes[order], landms[order], scores[order]
        # do NMS, then keep top-K after NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
        keep = py_cpu_nms(dets, args.nms_threshold)
        return dets[keep, :][:args.keep_top_k, :], landms[keep][:args.keep_top_k, :]

    count = 0   # images skipped because they are smaller than 112x112
    count2 = 0  # images visited so far
    for person in os.listdir(input_path):
        person_in = os.path.join(input_path, person)
        if not os.path.isdir(person_in):
            continue  # stray files at the top level are not identity folders
        person_out = os.path.join(output_path, person)
        os.makedirs(person_out, exist_ok=True)
        for file_name in os.listdir(person_in):
            count2 += 1
            print(count2)
            path = os.path.join(person_in, file_name)
            frame = cv2.imread(path)
            if frame is None:
                # cv2.imread returns None for unreadable or non-image files.
                print(path)
                continue
            h, w = frame.shape[:2]
            dets, landms = _detect(frame)
            if not args.save_image:
                continue
            dst = []
            score = 500
            for i, det in enumerate(dets):
                if det[4] < args.vis_thres:
                    continue
                center_x = (det[2] + det[0]) / 2
                center_y = (det[3] + det[1]) / 2
                distance = abs(center_x - 125) + abs(center_y - 125)
                if distance < score:
                    score = distance
                    dst = np.reshape(landms[i], (5, 2))
            if len(dst) == 0:
                continue
            if w < 112 or h < 112:
                count += 1
                continue
            tform = trans.SimilarityTransform()
            tform.estimate(dst, src1)
            # Rendering straight into a 112x112 canvas gives the same pixels as
            # warping at full size and cropping the top-left 112x112 corner.
            img112 = cv2.warpAffine(frame, tform.params[0:2, :], (112, 112), borderValue=0.0)
            cv2.imwrite(os.path.join(person_out, file_name), img112)
    print(">112 number"+str(count))
def sfz_to112x112_retinaface(arcface_model, cpu_or_cuda,
                             input_path=r"D:\Download\out\alig_students_all",
                             limit=2500, flush_every=1000,
                             database_path="./Database/sfz_test.npy"):
    """Build the face database from a flat folder of ID photos.

    For each of the first ``limit`` entries of ``input_path``, the
    highest-scoring face (score >= ``args.vis_thres``) is aligned to a 112x112
    crop. Crops are normalised to [-1, 1] in NCHW layout and passed to
    ``create_database_from_img`` with the file name minus its last 6
    characters. This happens every ``flush_every`` listing entries and once
    more at the end.

    NOTE(review): the same ``database_path`` is used for every flush; confirm
    that ``create_database_from_img`` appends rather than overwrites.

    Args:
        arcface_model: embedding model forwarded to ``create_database_from_img``.
        cpu_or_cuda: device string forwarded to ``create_database_from_img``.
        input_path: folder containing the photos (no sub-folders expected).
        limit: maximum number of listing entries to visit.
        flush_every: listing-entry interval at which pending crops are stored.
        database_path: output file handed to ``create_database_from_img``.
    """
    torch.set_grad_enabled(False)
    cfg = None
    if args.network == "mobile0.25":
        cfg = cfg_mnet
    elif args.network == "resnet50":
        cfg = cfg_re50
    # net and model
    net = RetinaFace(cfg=cfg, phase='test')
    net = load_model(net, args.trained_model, args.cpu)
    net.eval()
    print('Finished loading model!')
    cudnn.benchmark = True
    device = torch.device("cpu" if args.cpu else "cuda")
    net = net.to(device)
    resize = 1
    # Target (x, y) positions of the five landmarks inside the 112x112 crop:
    # two eyes, nose tip, two mouth corners.
    src1 = np.array([
        [38.3814, 51.6963],
        [73.6186, 51.5014],
        [56.1120, 71.7366],
        [41.6361, 92.3655],
        [70.8167, 92.2041]], dtype=np.float32)

    def _detect(frame):
        """Run RetinaFace on a BGR ``frame``; return ``(dets, landms)`` after filtering and NMS.

        ``dets`` rows are ``[x1, y1, x2, y2, score]``; ``landms`` rows hold the
        10 matching landmark coordinates (x, y interleaved), in pixels.
        """
        img = np.float32(frame)
        im_height, im_width, _ = img.shape
        scale = torch.Tensor([im_width, im_height, im_width, im_height]).to(device)
        img -= (104, 117, 123)
        img = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
        loc, conf, landms = net(img)  # forward pass
        prior_data = PriorBox(cfg, image_size=(im_height, im_width)).forward().to(device).data
        boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
        boxes = (boxes * scale / resize).cpu().numpy()
        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
        landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
        scale1 = torch.Tensor([im_width, im_height] * 5).to(device)
        landms = (landms * scale1 / resize).cpu().numpy()
        # ignore low scores
        inds = np.where(scores > args.confidence_threshold)[0]
        boxes, landms, scores = boxes[inds], landms[inds], scores[inds]
        # keep top-K before NMS
        order = scores.argsort()[::-1][:args.top_k]
        boxes, landms, scores = boxes[order], landms[order], scores[order]
        # do NMS, then keep top-K after NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
        keep = py_cpu_nms(dets, args.nms_threshold)
        return dets[keep, :][:args.keep_top_k, :], landms[keep][:args.keep_top_k, :]

    def _flush(images, names):
        """Normalise HWC uint8 crops to NCHW float32 in [-1, 1] and store their embeddings."""
        batch = np.array(images).transpose((0, 3, 1, 2))
        batch = np.array(batch, dtype=np.float32)
        batch -= 127.5
        batch /= 127.5
        create_database_from_img(names, batch, arcface_model, database_path, cpu_or_cuda)

    folder1 = os.listdir(input_path)
    print(len(folder1))
    count = 0   # photos skipped because they are smaller than 112x112
    count2 = 0  # listing entries visited so far
    order_img = []
    order_name = []
    tic = time.time()
    for img_name in folder1[:limit]:
        img_name_path = os.path.join(input_path, img_name)
        count2 += 1
        if count2 % flush_every == 0:
            print('net forward time: {:.4f}'.format(time.time() - tic))
            print(count2)
            if order_img:
                _flush(order_img, order_name)
                order_img = []
                order_name = []
                tic = time.time()
        if not os.path.isfile(img_name_path):
            print(img_name)
            continue
        # Decoding from a byte buffer presumably lets non-ASCII file names
        # through on Windows, where cv2.imread can fail on them.
        buf = np.fromfile(img_name_path, dtype=np.uint8)
        frame = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
        if frame is None:
            print(img_name)
            continue
        # IMREAD_COLOR always yields a 3-channel image.
        h, w = frame.shape[:2]
        if w > 1000:
            # Shrink very wide photos to width 600, keeping the aspect ratio.
            frame = cv2.resize(frame, (600, int(600 * (h / w))))
            h, w = frame.shape[:2]
        dets, landms = _detect(frame)
        if not args.save_image:
            continue
        confident = np.where(dets[:, 4] >= args.vis_thres)[0]
        if confident.size == 0:
            continue
        # Use the most confident face above the threshold.
        best = confident[np.argmax(dets[confident, 4])]
        dst = np.reshape(landms[best], (5, 2))
        if w < 112 or h < 112:
            count += 1
            print(img_name_path)
            continue
        tform = trans.SimilarityTransform()
        tform.estimate(dst, src1)
        # Same pixels as warping at full size and cropping the top-left 112x112.
        order_img.append(cv2.warpAffine(frame, tform.params[0:2, :], (112, 112), borderValue=0.0))
        order_name.append(img_name[:-6])
    print(">112 number"+str(count))
    if order_img:
        _flush(order_img, order_name)
def count_accuracy(arcface_model, cpu_or_cuda, index, database_name_list,
                   input_path=r"../face/czrkzp2", limit=15000):
    """Align probe photos and report recognition accuracy via ``count_acc``.

    The first ``limit`` entries of ``input_path`` are scanned, and entries whose
    20th character is "1" are skipped. For every remaining readable photo, the
    highest-scoring face (score >= ``args.vis_thres``) is aligned to a 112x112
    crop. All crops are normalised to [-1, 1] in NCHW layout and scored by
    ``count_acc`` against ``index`` / ``database_name_list``.

    NOTE(review): the 20th-character rule presumably separates gallery photos
    (already enrolled) from probe photos; confirm the naming scheme.

    Args:
        arcface_model: embedding model forwarded to ``count_acc``.
        cpu_or_cuda: device string forwarded to ``count_acc``.
        index: search index forwarded to ``count_acc``.
        database_name_list: gallery names forwarded to ``count_acc``.
        input_path: folder containing the probe photos.
        limit: maximum number of listing entries to visit.
    """
    torch.set_grad_enabled(False)
    cfg = None
    if args.network == "mobile0.25":
        cfg = cfg_mnet
    elif args.network == "resnet50":
        cfg = cfg_re50
    # net and model
    net = RetinaFace(cfg=cfg, phase='test')
    net = load_model(net, args.trained_model, args.cpu)
    net.eval()
    print('Finished loading model!')
    cudnn.benchmark = True
    device = torch.device("cpu" if args.cpu else "cuda")
    net = net.to(device)
    resize = 1
    # Target (x, y) positions of the five landmarks inside the 112x112 crop:
    # two eyes, nose tip, two mouth corners.
    src1 = np.array([
        [38.3814, 51.6963],
        [73.6186, 51.5014],
        [56.1120, 71.7366],
        [41.6361, 92.3655],
        [70.8167, 92.2041]], dtype=np.float32)

    def _detect(frame):
        """Run RetinaFace on a BGR ``frame``; return ``(dets, landms)`` after filtering and NMS.

        ``dets`` rows are ``[x1, y1, x2, y2, score]``; ``landms`` rows hold the
        10 matching landmark coordinates (x, y interleaved), in pixels.
        """
        img = np.float32(frame)
        im_height, im_width, _ = img.shape
        scale = torch.Tensor([im_width, im_height, im_width, im_height]).to(device)
        img -= (104, 117, 123)
        img = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
        loc, conf, landms = net(img)  # forward pass
        prior_data = PriorBox(cfg, image_size=(im_height, im_width)).forward().to(device).data
        boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
        boxes = (boxes * scale / resize).cpu().numpy()
        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
        landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
        scale1 = torch.Tensor([im_width, im_height] * 5).to(device)
        landms = (landms * scale1 / resize).cpu().numpy()
        # ignore low scores
        inds = np.where(scores > args.confidence_threshold)[0]
        boxes, landms, scores = boxes[inds], landms[inds], scores[inds]
        # keep top-K before NMS
        order = scores.argsort()[::-1][:args.top_k]
        boxes, landms, scores = boxes[order], landms[order], scores[order]
        # do NMS, then keep top-K after NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
        keep = py_cpu_nms(dets, args.nms_threshold)
        return dets[keep, :][:args.keep_top_k, :], landms[keep][:args.keep_top_k, :]

    folder1 = os.listdir(input_path)
    print(len(folder1))
    count = 0   # photos skipped because they are smaller than 112x112
    count2 = 0  # listing entries visited so far
    order_img = []
    order_name = []
    tic = time.time()
    for img_name in folder1[:limit]:
        img_name_path = os.path.join(input_path, img_name)
        count2 += 1
        if count2 % 5000 == 0:
            print('net forward time: {:.4f}'.format(time.time() - tic))
            print(count2)
        # A slice (not img_name[19]) keeps names shorter than 20 characters
        # from raising IndexError; such names are treated as probes.
        if img_name[19:20] == "1":
            continue
        frame = cv2.imread(img_name_path)
        if frame is None:
            # cv2.imread returns None for unreadable or non-image entries.
            print(img_name)
            continue
        # cv2.imread's default flag yields a 3-channel image.
        h, w = frame.shape[:2]
        if w > 1000:
            # Shrink very wide photos to width 600, keeping the aspect ratio.
            frame = cv2.resize(frame, (600, int(600 * (h / w))))
            h, w = frame.shape[:2]
        dets, landms = _detect(frame)
        if not args.save_image:
            continue
        confident = np.where(dets[:, 4] >= args.vis_thres)[0]
        if confident.size == 0:
            continue
        # Use the most confident face above the threshold.
        best = confident[np.argmax(dets[confident, 4])]
        dst = np.reshape(landms[best], (5, 2))
        if w < 112 or h < 112:
            count += 1
            print(img_name_path)
            continue
        tform = trans.SimilarityTransform()
        tform.estimate(dst, src1)
        # Same pixels as warping at full size and cropping the top-left 112x112.
        order_img.append(cv2.warpAffine(frame, tform.params[0:2, :], (112, 112), borderValue=0.0))
        order_name.append(img_name)
    print(">112 number"+str(count))
    if order_img:
        # HWC uint8 crops -> NCHW float32 scaled to [-1, 1].
        batch = np.array(order_img).transpose((0, 3, 1, 2))
        batch = np.array(batch, dtype=np.float32)
        batch -= 127.5
        batch /= 127.5
        count_acc(order_name, batch, arcface_model, index, database_name_list, cpu_or_cuda)
def count_acc(order_name, order_img, model, index, database_name_list, cpu_or_cuda,
              batch=256, find_fn=None):
    """Identify aligned faces in batches and print open-set recognition statistics.

    Args:
        order_name: ground-truth file names, one per row of ``order_img``; the
            expected identity is the name with its last 6 characters removed.
        order_img: array of normalised face crops (``len()`` and ``.shape``
            are used; slices are passed to ``find_fn``).
        model: recognition model forwarded to ``find_fn``.
        index: search index forwarded to ``find_fn``.
        database_name_list: gallery names forwarded to ``find_fn``.
        cpu_or_cuda: device string forwarded to ``find_fn``.
        batch: number of faces sent to ``find_fn`` per call.
        find_fn: callable with the ``findAll`` signature returning one
            predicted name per face; defaults to ``findAll``.

    Returns:
        ``(right, filed, misidentified)``: correct names, ``"unknown"``
        predictions, and predictions naming a wrong identity. The printed
        "right" figure counts ``right + filed``, i.e. a rejection is treated as
        acceptable, and the printed "error" figure is ``misidentified``.
    """
    if find_fn is None:
        find_fn = findAll
    print(order_img.shape)
    number = len(order_img)
    start_time = time.time()
    pred_name = []
    for now in range(0, number, batch):
        # Slicing clamps at ``number``, so the final partial batch needs no special case.
        pred_name.extend(find_fn(order_img[now:now + batch], model, index, database_name_list, cpu_or_cuda))
        print("batch" + str(now + batch))
    end_time = time.time()
    print("findAll time: " + str(end_time - start_time))
    if len(pred_name) != number:
        print("warning: {} predictions for {} faces".format(len(pred_name), number))
    right = 0   # prediction equals the expected identity
    filed = 0   # prediction is the "unknown" rejection label
    error = 0   # prediction differs from the expected identity (includes rejections)
    for pred, gt_file in zip(pred_name, order_name):
        if pred == gt_file[:-6]:
            right += 1
        else:
            error += 1
        if pred == "unknown":
            filed += 1
    print("----------------")
    print("total:" + str(number))
    if number == 0:
        return 0, 0, 0
    print("right:" + str(right+filed) + " rate:" + str((filed+right) / number))
    print("error:" + str(error - filed) + " rate:" + str((error - filed) / number))
    return right, filed, error - filed
# if __name__ == '__main__':
# torch.set_grad_enabled(False)
# cfg = None
# if args.network == "mobile0.25":
# cfg = cfg_mnet
# elif args.network == "resnet50":
# cfg = cfg_re50
# # net and model
# net = RetinaFace(cfg=cfg, phase = 'test')
# net = load_model(net, args.trained_model, args.cpu)
# net.eval()
# print('Finished loading model!')
# #print(net)
# cudnn.benchmark = True
# device = torch.device("cpu" if args.cpu else "cuda")
# net = net.to(device)
#
# resize = 1
#
# # testing begin
# cap = cv2.VideoCapture("rtsp://47.108.74.82:8557/h264")
# ret, frame = cap.read()
# h, w = frame.shape[:2]
# fps = cap.get(cv2.CAP_PROP_FPS)
# size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
# int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
# #out = cv2.VideoWriter('out.mp4', cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), fps, size)
# out = cv2.VideoWriter('ttttttt.avi', cv2.VideoWriter_fourcc(*'XVID'), fps, size)
# number = 0
#
# model = iresnet100()
# model.load_state_dict(torch.load("./model/backbone100.pth", map_location="cpu"))
# model.eval()
# k_v = load_npy("./Database/student.npy")
#
# while ret:
# tic = time.time()
# img = np.float32(frame)
# im_height, im_width, _ = img.shape
# scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
# img -= (104, 117, 123)
# img = img.transpose(2, 0, 1)
# img = torch.from_numpy(img).unsqueeze(0)
# img = img.to(device)
# scale = scale.to(device)
#
# loc, conf, landms = net(img) # forward pass
#
#
# priorbox = PriorBox(cfg, image_size=(im_height, im_width))
# priors = priorbox.forward()
# priors = priors.to(device)
# prior_data = priors.data
# boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
# boxes = boxes * scale / resize
# boxes = boxes.cpu().numpy()
# scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
# landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
# scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
# img.shape[3], img.shape[2], img.shape[3], img.shape[2],
# img.shape[3], img.shape[2]])
# scale1 = scale1.to(device)
# landms = landms * scale1 / resize
# landms = landms.cpu().numpy()
#
# # ignore low scores
# inds = np.where(scores > args.confidence_threshold)[0]
# boxes = boxes[inds]
# landms = landms[inds]
# scores = scores[inds]
#
# # keep top-K before NMS
# order = scores.argsort()[::-1][:args.top_k]
# boxes = boxes[order]
# landms = landms[order]
# scores = scores[order]
#
# # do NMS
# dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
# keep = py_cpu_nms(dets, args.nms_threshold)
# # keep = nms(dets, args.nms_threshold,force_cpu=args.cpu)
# dets = dets[keep, :]
# landms = landms[keep]
#
# # keep top-K faster NMS
# dets = dets[:args.keep_top_k, :]
# landms = landms[:args.keep_top_k, :]
#
# dets = np.concatenate((dets, landms), axis=1)
# face_list = []
# name_list = []
# #print(dets[:4])
# print('net forward time: {:.4f}'.format(time.time() - tic))
# start_time = time.time()
# for i, det in enumerate(dets):
# if det[4] < args.vis_thres:
# continue
# boxes, score = det[:4], det[4]
# dst = np.reshape(landms[i],(5,2))
# #print(dst.shape)
# src1 = np.array([
# [38.3814, 51.6963],
# [73.6186, 51.5014],
# [56.1120, 71.7366],
# [41.6361, 92.3655],
# [70.8167, 92.2041]], dtype=np.float32)
# #print(src1.shape)
# tform = trans.SimilarityTransform()
# tform.estimate(dst, src1)
# M = tform.params[0:2, :]
# frame2 = cv2.warpAffine(frame, M, (w, h), borderValue=0.0)
# img112 = frame2[0:112, 0:112, :]
# # cv2.imwrite("./img/man"+str(count)+".jpg", img112)
# # count += 1
# face_list.append(img112)
#
# if len(face_list) != 0:
# face_list = np.array(face_list)
# face_list = face_list.transpose((0, 3, 1, 2))
# face_list = np.array(face_list, dtype=np.float32)
# face_list -= 127.5
# face_list /= 127.5
# print(face_list.shape)
# face_list = torch.from_numpy(face_list)
#
# name_list = findAll(face_list, model, k_v)
# end_time = time.time()
# print("findOneframe time: " + str(end_time - start_time))
# start_time = time.time()
# img_PIL = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
# draw = ImageDraw.Draw(img_PIL)
# font = ImageFont.truetype("font.ttf", 22)
# for i, det in enumerate(dets):
# if det[4] < args.vis_thres:
# continue
# boxes, score = det[:4], det[4]
# #print(name_list)
# name = name_list[i]
# mo = r'[\u4e00-\u9fa5]*'
# name = re.match(mo, name).group(0)
# if not isinstance(name, np.unicode):
# name = name.decode('utf8')
# draw.text((int(boxes[0]), int(boxes[1])), name, fill=(255, 0, 0), font=font)
# draw.rectangle((int(boxes[0]), int(boxes[1]), int(boxes[2]), int(boxes[3])), outline="green", width=3)
# frame = cv2.cvtColor(np.asarray(img_PIL), cv2.COLOR_RGB2BGR)
# cv2.imshow('out', frame)
# out.write(frame)
# end_time = time.time()
# print("drawOneframe time: " + str(end_time - start_time))
# # Press Q on keyboard to stop recording
# if cv2.waitKey(1) & 0xFF == ord('q'):
# break
# ret, frame = cap.read()
# cap.release()
# out.release()
# cv2.destroyAllWindows()
if __name__ == '__main__':
    cpu_or_cuda = "cuda" if torch.cuda.is_available() else "cpu"
    arcface_model = load_arcface_model("./model/backbone100.pth", cpu_or_cuda=cpu_or_cuda)
    k_v = load_npy("./Database/sfz_test.npy")
    database_name_list = list(k_v.keys())
    # faiss works on 2-D float32 matrices; flatten each stored vector and make
    # the array contiguous float32 regardless of how it was saved.
    vector_list = np.array(list(k_v.values()))
    vector_list = np.ascontiguousarray(vector_list.reshape(len(vector_list), -1), dtype=np.float32)
    print(vector_list.shape)
    # print(database_name_list)
    dim = vector_list.shape[1]  # embedding size (previously hard-coded as 512)
    # IVF k-means training fails when there are fewer vectors than clusters.
    nlist = min(500, len(vector_list))
    quantizer = faiss.IndexFlatL2(dim)  # the other index
    index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_L2)
    index.train(vector_list)
    # index = faiss.IndexFlatL2(512)
    index.add(vector_list)
    index.nprobe = min(50, nlist)  # cannot probe more lists than exist
    count_accuracy(arcface_model, cpu_or_cuda, index, database_name_list)
    # sfz_to112x112_retinaface(arcface_model,cpu_or_cuda)