484 lines
18 KiB
Python
484 lines
18 KiB
Python
from __future__ import print_function
|
||
import re
|
||
import time
|
||
import cv2
|
||
import torch
|
||
import torch.backends.cudnn as cudnn
|
||
import numpy as np
|
||
from skimage import transform as trans
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
from data import cfg_mnet, cfg_re50
|
||
from layers.functions.prior_box import PriorBox
|
||
from utils.nms.py_cpu_nms import py_cpu_nms
|
||
from models.retinaface import RetinaFace
|
||
from utils.box_utils import decode, decode_landm
|
||
|
||
# Distance cut-off for recognition: a gallery match is accepted only when its
# distance to the query embedding is strictly below this value (used by
# findAll and faiss_find_face).
threshold = 1.05
# Width cap for detect_video: frames wider than this are scaled down to this
# width, keeping the aspect ratio of the first frame.
ppi = 1280
# detect_video runs the detector on one frame, then passes `step` frames
# through with the previous detections before detecting again.
step = 3
|
||
|
||
class ConfRetinaface(object):
    """Plain settings holder for the RetinaFace detector.

    Every constructor argument is stored unchanged under an attribute of the
    same name:

    trained_model        -- path of the checkpoint handed to ``load_model``.
    network              -- backbone key ("mobile0.25" or "resnet50").
    cpu                  -- truthy to run on the CPU, falsy to run on CUDA.
    confidence_threshold -- detections scoring at or below this are dropped.
    top_k                -- number of best-scoring candidates kept before NMS.
    nms_threshold        -- threshold passed to ``py_cpu_nms``.
    keep_top_k           -- number of detections kept after NMS.
    vis_thres            -- minimum score for a detection to be cropped/drawn.
    """

    def __init__(self, trained_model, network, cpu, confidence_threshold, top_k, nms_threshold, keep_top_k, vis_thres):
        # Which model to load and where to run it.
        self.trained_model = trained_model
        self.network = network
        self.cpu = cpu

        # Score filtering applied to raw detections.
        self.confidence_threshold = confidence_threshold
        self.vis_thres = vis_thres

        # Candidate limits and overlap threshold around NMS.
        self.top_k = top_k
        self.nms_threshold = nms_threshold
        self.keep_top_k = keep_top_k
|
||
|
||
|
||
def set_retinaface_conf(cpu_or_cuda='cpu'):
    """Build the default detector settings (MobileNet-0.25 backbone).

    Args:
        cpu_or_cuda: ``'cpu'`` to run on the CPU; any other value selects
            CUDA. Defaults to ``'cpu'`` so the function can also be called
            without arguments, as the module's ``__main__`` block does.

    Returns:
        A ``ConfRetinaface`` instance holding the default checkpoint path,
        backbone name and post-processing thresholds.
    """
    return ConfRetinaface(
        trained_model='./weights/mobilenet0.25_Final.pth',
        network='mobile0.25',
        cpu=(cpu_or_cuda == 'cpu'),
        confidence_threshold=0.02,
        top_k=5000,
        nms_threshold=0.4,
        keep_top_k=750,
        vis_thres=0.6,
    )
|
||
|
||
|
||
def check_keys(model, pretrained_state_dict):
    """Report how well a checkpoint's parameter names match a model's.

    Prints the number of model keys missing from the checkpoint, the number of
    checkpoint keys the model does not have, and the number of shared keys.

    Args:
        model: object whose ``state_dict()`` returns a mapping of parameter
            names.
        pretrained_state_dict: mapping of parameter names loaded from the
            checkpoint.

    Returns:
        True when at least one key is shared.

    Raises:
        AssertionError: if the checkpoint shares no key with the model.
    """
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    print('Missing keys:{}'.format(len(missing_keys)))
    print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
    print('Used keys:{}'.format(len(used_pretrained_keys)))
    # Raised explicitly (rather than via `assert`) so the check still runs
    # when Python is started with -O.
    if not used_pretrained_keys:
        raise AssertionError('load NONE from pretrained checkpoint')
    return True
|
||
|
||
|
||
def remove_prefix(state_dict, prefix):
    """Return a copy of ``state_dict`` with ``prefix`` stripped from its keys.

    Old style model is stored with all names of parameters sharing common
    prefix 'module.'.

    Args:
        state_dict: mapping of parameter names to values.
        prefix: leading text to drop; keys that do not start with it are kept
            unchanged. Only one leading occurrence is removed. An empty prefix
            leaves every key as it is.

    Returns:
        A new dict with the renamed keys and the original values.
    """
    print('remove prefix \'{}\''.format(prefix))
    cut = len(prefix)
    # Slicing instead of str.split: split('') raises ValueError for an empty
    # prefix, while slicing by 0 is simply a no-op.
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
|
||
|
||
|
||
def load_model(model, pretrained_path, load_to_cpu):
    """Load checkpoint weights from ``pretrained_path`` into ``model``.

    Args:
        model: network exposing ``state_dict()`` and ``load_state_dict()``.
        pretrained_path: checkpoint file readable by ``torch.load``.
        load_to_cpu: truthy to keep the loaded storages where ``torch.load``
            deserialised them; falsy to move them to the current CUDA device.

    Returns:
        The same ``model`` object, after ``load_state_dict(..., strict=False)``.
    """
    print('Loading pretrained model from {}'.format(pretrained_path))

    if load_to_cpu:
        def place(storage, loc):
            return storage
    else:
        device = torch.cuda.current_device()

        def place(storage, loc):
            return storage.cuda(device)

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(pretrained_path, map_location=place)

    # Checkpoints may wrap the weights under a "state_dict" entry.
    if "state_dict" in checkpoint.keys():
        weights = checkpoint['state_dict']
    else:
        weights = checkpoint
    weights = remove_prefix(weights, 'module.')

    check_keys(model, weights)
    # strict=False: keys present on only one side are tolerated.
    model.load_state_dict(weights, strict=False)
    return model
|
||
|
||
|
||
def load_retinaface_model(args):
    """Load the RetinaFace model described by ``args`` for inference.

    Args:
        args: settings object providing ``network`` ("mobile0.25" or
            "resnet50"), ``trained_model`` (checkpoint path) and ``cpu``
            (truthy to run on the CPU, falsy for CUDA).

    Returns:
        The network in eval mode, moved to the selected device.

    Raises:
        ValueError: if ``args.network`` is not one of the supported names.
    """
    # Validate before touching any global state, so a bad name fails fast
    # instead of passing cfg=None into the model constructor.
    if args.network == "mobile0.25":
        cfg = cfg_mnet
    elif args.network == "resnet50":
        cfg = cfg_re50
    else:
        raise ValueError('unsupported network: {!r}'.format(args.network))

    # Inference only: gradient tracking is switched off globally.
    torch.set_grad_enabled(False)

    # net and model
    net = RetinaFace(cfg=cfg, phase='test')
    net = load_model(net, args.trained_model, args.cpu)
    net.eval()
    cudnn.benchmark = True
    device = torch.device("cpu" if args.cpu else "cuda")
    net = net.to(device)
    print('Finished loading model!')
    return net
|
||
|
||
|
||
def findEuclideanDistance(source_representation, test_representation):
    """Return the Euclidean distance between two feature vectors.

    Both arguments must support element-wise subtraction (e.g. NumPy arrays
    of matching shape).
    """
    diff = source_representation - test_representation
    return np.sqrt(np.sum(np.square(diff)))
|
||
|
||
|
||
def l2_normalize(x):
    """Scale ``x`` to unit Euclidean length (divide it by its L2 norm)."""
    norm = np.sqrt(np.sum(np.square(x)))
    return x / norm
|
||
|
||
|
||
def findmindistance(pred, threshold, k_v):
    """Find the gallery entry whose feature vector is closest to ``pred``.

    Args:
        pred: query feature vector.
        threshold: the closest entry is accepted only when its distance is
            strictly below this value.
        k_v: mapping from name to reference feature vector.

    Returns:
        The name of the closest entry, or "unknown" when ``k_v`` is empty or
        the closest distance is not below ``threshold``.
    """
    # Start from infinity rather than a fixed number, so the nearest entry is
    # found whatever the scale of the distances.
    best_name = ""
    best_distance = float("inf")
    for name, reference in k_v.items():
        candidate = findEuclideanDistance(reference, pred)
        # Strict comparison: on a tie the entry seen first is kept.
        if candidate < best_distance:
            best_distance = candidate
            best_name = name
    if best_distance < threshold:
        return best_name
    return "unknown"
|
||
|
||
#
|
||
def faiss_find_face(pred, index, database_name_list):
    """Name each query embedding from its nearest neighbour in a search index.

    Args:
        pred: queries passed unchanged to ``index.search``.
        index: object whose ``search(queries, 1)`` returns ``(D, I)``:
            per-query distances and per-query ids of the single nearest entry.
        database_name_list: names, indexed by the ids returned in ``I``.

    Returns:
        One name per row of ``I``; "unknown" where the nearest distance is not
        strictly below the module-level ``threshold``.
    """
    start_time = time.time()
    D, I = index.search(pred, 1)
    end_time = time.time()
    print("faiss cost %fs" % (end_time - start_time))
    print(D, I)

    # NOTE(review): some index types report squared L2 distances - confirm
    # that `threshold` is expressed on the same scale as D.
    name_list = []
    # Separate loop names: the `index` parameter is no longer rebound here.
    for distances, ids in zip(D, I):
        if distances[0] < threshold:
            name_list.append(database_name_list[ids[0]])
        else:
            name_list.append("unknown")
    return name_list
|
||
|
||
def findAll(imglist, model, index, database_name_list, k_v, cpu_or_cuda):
    """Recognise every face of a batch against the gallery ``k_v``.

    Args:
        imglist: NumPy array of face crops, converted with
            ``torch.from_numpy`` and fed to ``model`` as one batch.
        model: embedding network called as ``model(batch)``.
        index: not used by this function (the index-based lookup is disabled).
        database_name_list: not used by this function.
        k_v: mapping from gallery name to reference feature vector.
        cpu_or_cuda: "cuda" to run on CUDA; any other value runs on the CPU.

    Returns:
        One entry per face: "unknown", or the leading part of the matched
        gallery name made of CJK ideographs, ASCII letters and underscores.
    """
    started = time.time()
    target = torch.device("cuda" if cpu_or_cuda == "cuda" else "cpu")
    batch = torch.from_numpy(imglist).to(target)

    # Gallery keys are cut at the first character outside this class.
    leading_label = re.compile(r'[\u4e00-\u9fa5_a-zA-Z]*')

    name_list = []
    with torch.no_grad():
        pred = model(batch).cpu().numpy()
        print("predOne time: " + str(time.time() - started))

        started = time.time()
        for pr in pred:
            name = findmindistance(l2_normalize(pr), threshold=threshold, k_v=k_v)
            print(name)
            if name == "unknown":
                name_list.append("unknown")
                continue
            name_list.append(leading_label.match(name).group(0))
        print("findOne time: " + str(time.time() - started))
    return name_list
|
||
|
||
|
||
def detect_one(path, net, args):
    """Detect the faces in one image file and return aligned 112x112 crops.

    Args:
        path: image file, read with ``np.fromfile`` + ``cv2.imdecode``.
        net: detector called as ``net(img)`` and returning
            ``(loc, conf, landms)``.
        args: settings object providing ``network``, ``cpu``,
            ``confidence_threshold``, ``top_k``, ``nms_threshold``,
            ``keep_top_k`` and ``vis_thres``.

    Returns:
        ``(face_list, box_and_point)``. When at least one detection reaches
        ``vis_thres``: ``face_list`` is a float32 array of shape
        (N, 3, 112, 112) scaled to [-1, 1], and ``box_and_point`` is an array
        with one row per face holding 4 box coordinates, the score and 10
        landmark coordinates. Otherwise both are empty lists.

    Raises:
        ValueError: if ``args.network`` is not a supported backbone name.
        IOError: if the file content cannot be decoded as an image.
    """
    if args.network == "mobile0.25":
        cfg = cfg_mnet
    elif args.network == "resnet50":
        cfg = cfg_re50
    else:
        raise ValueError('unsupported network: {!r}'.format(args.network))

    device = torch.device("cpu" if args.cpu else "cuda")
    resize = 1

    # testing begin
    frame = cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_COLOR)
    if frame is None:
        raise IOError('cannot decode image: {}'.format(path))
    h, w = frame.shape[:2]
    if w > 1000:
        # Wide images are shrunk to a width of 600, keeping the aspect ratio.
        factor = float(h) / w
        frame = cv2.resize(frame, (600, int(600 * factor)))

    img = np.float32(frame)
    im_height, im_width, _ = img.shape
    img -= (104, 117, 123)  # per-channel offsets subtracted before inference
    img = img.transpose(2, 0, 1)
    img = torch.from_numpy(img).unsqueeze(0).to(device)

    # Multipliers that turn decoded boxes / landmarks into pixel coordinates.
    scale = torch.Tensor([im_width, im_height, im_width, im_height]).to(device)
    scale1 = torch.Tensor([im_width, im_height] * 5).to(device)

    loc, conf, landms = net(img)  # forward pass
    priors = PriorBox(cfg, image_size=(im_height, im_width)).forward().to(device)
    prior_data = priors.data
    boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
    boxes = (boxes * scale / resize).cpu().numpy()
    scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
    landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
    landms = (landms * scale1 / resize).cpu().numpy()

    # ignore low scores
    inds = np.where(scores > args.confidence_threshold)[0]
    boxes = boxes[inds]
    landms = landms[inds]
    scores = scores[inds]

    # keep top-K before NMS
    order = scores.argsort()[::-1][:args.top_k]
    boxes = boxes[order]
    landms = landms[order]
    scores = scores[order]

    # do NMS
    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
    keep = py_cpu_nms(dets, args.nms_threshold)
    dets = dets[keep, :]
    landms = landms[keep]

    # keep top-K after NMS
    dets = dets[:args.keep_top_k, :]
    landms = landms[:args.keep_top_k, :]

    dets = np.concatenate((dets, landms), axis=1)
    print(len(dets))

    # Reference positions of the five landmarks inside the 112x112 crop.
    src1 = np.array([
        [38.3814, 51.6963],
        [73.6186, 51.5014],
        [56.1120, 71.7366],
        [41.6361, 92.3655],
        [70.8167, 92.2041]], dtype=np.float32)
    tform = trans.SimilarityTransform()

    face_list = []
    box_and_point = []
    for i, det in enumerate(dets):
        if det[4] < args.vis_thres:
            continue
        box_and_point.append(det)
        dst = np.reshape(landms[i], (5, 2))
        tform.estimate(dst, src1)
        M = tform.params[0:2, :]
        # Warp straight into a 112x112 canvas: same pixels as warping the
        # whole frame and cropping its top-left corner, but the crop is
        # 112x112 even when the image itself is smaller than that.
        face_list.append(cv2.warpAffine(frame, M, (112, 112), borderValue=0.0))

    if len(face_list) > 0:
        # (N, H, W, C) uint8 -> (N, C, H, W) float32 in [-1, 1].
        face_list = np.array(face_list).transpose((0, 3, 1, 2))
        face_list = np.array(face_list, dtype=np.float32)
        face_list -= 127.5
        face_list /= 127.5
        box_and_point = np.array(box_and_point)
    return face_list, box_and_point
|
||
|
||
|
||
def detect_video(video_path, output_path, net, arcface_model, k_v, args):
    """Detect and recognise faces in a video and write an annotated copy.

    The detector runs on one frame, then the next ``step`` frames reuse the
    detections and names of that pass. Frames on which detection ran get plain
    box outlines; the reused frames get the recognised name plus a green box.
    At most the four best-scoring detections are processed per pass.

    Args:
        video_path: input video opened with ``cv2.VideoCapture``.
        output_path: destination file, written with the XVID codec.
        net: detector called as ``net(img)`` and returning
            ``(loc, conf, landms)``.
        arcface_model: embedding network forwarded to ``findAll``.
        k_v: mapping from gallery name to reference feature vector.
        args: settings object providing ``network``, ``cpu``,
            ``confidence_threshold``, ``top_k``, ``nms_threshold``,
            ``keep_top_k`` and ``vis_thres``.

    Raises:
        ValueError: if ``args.network`` is not a supported backbone name.
        IOError: if no frame can be read from ``video_path``.
    """
    tic_total = time.time()
    if args.network == "mobile0.25":
        cfg = cfg_mnet
    elif args.network == "resnet50":
        cfg = cfg_re50
    else:
        raise ValueError('unsupported network: {!r}'.format(args.network))

    device_name = "cpu" if args.cpu else "cuda"
    device = torch.device(device_name)
    resize = 1
    max_faces = 4  # only the best-scoring detections are recognised / drawn

    # testing begin
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        ret, frame = cap.read()
        if not ret or frame is None:
            raise IOError('cannot read video: {}'.format(video_path))
        h, w = frame.shape[:2]

        # Frames wider than `ppi` are all scaled to the size derived from the
        # first frame, so every frame matches the writer and the priors.
        resize_to = None
        if w > ppi:
            resize_to = (ppi, int(ppi * (float(h) / w)))
            frame = cv2.resize(frame, resize_to)
            h, w = frame.shape[:2]

        fps = cap.get(cv2.CAP_PROP_FPS)
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'XVID'), fps, (w, h))
        number = step  # forces a detection pass on the first frame
        dets = []
        name_list = []
        font = ImageFont.truetype("font.ttf", 22)

        # Priors and scale factors depend only on the frame size: build once.
        priors = PriorBox(cfg, image_size=(h, w)).forward().to(device)
        prior_data = priors.data
        scale = torch.Tensor([w, h, w, h]).to(device)
        scale1 = torch.Tensor([w, h] * 5).to(device)

        # Reference positions of the five landmarks inside the 112x112 crop.
        src1 = np.array([
            [38.3814, 51.6963],
            [73.6186, 51.5014],
            [56.1120, 71.7366],
            [41.6361, 92.3655],
            [70.8167, 92.2041]], dtype=np.float32)
        tform = trans.SimilarityTransform()

        while ret:
            tic_all = time.time()
            if number == step:
                tic = time.time()
                img = np.float32(frame)
                img -= (104, 117, 123)
                img = img.transpose(2, 0, 1)
                img = torch.from_numpy(img).unsqueeze(0).to(device)

                loc, conf, landms = net(img)  # forward pass

                boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
                boxes = (boxes * scale / resize).cpu().numpy()
                scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
                landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
                landms = (landms * scale1 / resize).cpu().numpy()

                # ignore low scores
                inds = np.where(scores > args.confidence_threshold)[0]
                boxes = boxes[inds]
                landms = landms[inds]
                scores = scores[inds]

                # keep top-K before NMS
                order = scores.argsort()[::-1][:args.top_k]
                boxes = boxes[order]
                landms = landms[order]
                scores = scores[order]

                # do NMS
                dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
                keep = py_cpu_nms(dets, args.nms_threshold)
                dets = dets[keep, :]
                landms = landms[keep]

                # keep top-K after NMS
                dets = dets[:args.keep_top_k, :]
                landms = landms[:args.keep_top_k, :]

                dets = np.concatenate((dets, landms), axis=1)
                face_list = []
                name_list = []
                print('net forward time: {:.4f}'.format(time.time() - tic))

                start_time = time.time()
                for i, det in enumerate(dets[:max_faces]):
                    if det[4] < args.vis_thres:
                        continue
                    dst = np.reshape(landms[i], (5, 2))
                    tform.estimate(dst, src1)
                    M = tform.params[0:2, :]
                    # Warp straight into the 112x112 crop instead of warping
                    # the full frame and slicing its top-left corner.
                    face_list.append(cv2.warpAffine(frame, M, (112, 112), borderValue=0.0))

                if len(face_list) != 0:
                    # (N, H, W, C) uint8 -> (N, C, H, W) float32 in [-1, 1].
                    face_list = np.array(face_list).transpose((0, 3, 1, 2))
                    face_list = np.array(face_list, dtype=np.float32)
                    face_list -= 127.5
                    face_list /= 127.5
                    print(face_list.shape)
                    # findAll takes (imglist, model, index, database_name_list,
                    # k_v, cpu_or_cuda); it does not use the two index
                    # arguments, so None is passed for them.
                    name_list = findAll(face_list, arcface_model, None, None, k_v, device_name)
                print("findOneframe time: " + str(time.time() - start_time))

                start_time = time.time()
                for det in dets[:max_faces]:
                    if det[4] < args.vis_thres:
                        continue
                    box = det[:4]
                    cv2.rectangle(frame, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (2, 255, 0), 1)
                out.write(frame)
                print("drawOneframe time: " + str(time.time() - start_time))

                ret, frame = cap.read()
                number = 0
                if ret and resize_to is not None:
                    frame = cv2.resize(frame, resize_to)
            else:
                number += 1
                if len(dets) != 0:
                    # PIL is used for the text so the TrueType font applies.
                    img_PIL = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    draw = ImageDraw.Draw(img_PIL)
                    for i, det in enumerate(dets[:max_faces]):
                        if det[4] < args.vis_thres:
                            continue
                        box = det[:4]
                        name = name_list[i]
                        # np.unicode no longer exists in current NumPy; a
                        # bytes check covers the same "needs decoding" case.
                        if isinstance(name, bytes):
                            name = name.decode('utf8')
                        draw.text((int(box[0]), int(box[1])), name, fill=(255, 0, 0), font=font)
                        draw.rectangle((int(box[0]), int(box[1]), int(box[2]), int(box[3])), outline="green",
                                       width=3)
                    frame = cv2.cvtColor(np.asarray(img_PIL), cv2.COLOR_RGB2BGR)
                out.write(frame)

                start_time = time.time()
                ret, frame = cap.read()
                if ret and resize_to is not None:
                    frame = cv2.resize(frame, resize_to)
                print("readframe time: " + str(time.time() - start_time))
            print('all time: {:.4f}'.format(time.time() - tic_all))
    finally:
        # Release the capture and the writer even when a frame fails midway.
        cap.release()
        if out is not None:
            out.release()
    print('total time: {:.4f}'.format(time.time() - tic_total))
|
||
|
||
|
||
if __name__ == "__main__":
    # Smoke check: build the CPU configuration and show the resolved flag.
    # The device name is passed explicitly.
    args = set_retinaface_conf('cpu')
    print(args.cpu)
|