Face/retinaface_detect.py

484 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from __future__ import print_function
import re
import time
import cv2
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from skimage import transform as trans
from PIL import Image, ImageDraw, ImageFont
from data import cfg_mnet, cfg_re50
from layers.functions.prior_box import PriorBox
from utils.nms.py_cpu_nms import py_cpu_nms
from models.retinaface import RetinaFace
from utils.box_utils import decode, decode_landm
# Module-wide tuning knobs.
#   threshold: a recognised face is accepted only when its embedding distance
#              is strictly below this value (findAll, faiss_find_face).
#   ppi:       video frames wider than this many pixels are resized to this
#              width, keeping the aspect ratio (detect_video).
#   step:      detect_video re-runs detection/recognition once every
#              ``step + 1`` frames and reuses the results in between.
threshold, ppi, step = 1.05, 1280, 3
class ConfRetinaface(object):
    """Attribute bag holding the RetinaFace inference settings.

    Every constructor argument is stored unchanged as an attribute of the
    same name:

    * trained_model        -- checkpoint path handed to load_model.
    * network              -- "mobile0.25" or "resnet50"; selects the config.
    * cpu                  -- True to run on the CPU, False for CUDA.
    * confidence_threshold -- candidates scoring at or below this are dropped
                              before NMS.
    * top_k                -- maximum number of candidates kept before NMS.
    * nms_threshold        -- overlap threshold passed to py_cpu_nms.
    * keep_top_k           -- maximum number of detections kept after NMS.
    * vis_thres            -- detections scoring below this are not cropped
                              or drawn.
    """

    def __init__(self, trained_model, network, cpu, confidence_threshold, top_k,
                 nms_threshold, keep_top_k, vis_thres):
        settings = (
            ('trained_model', trained_model),
            ('network', network),
            ('cpu', cpu),
            ('confidence_threshold', confidence_threshold),
            ('top_k', top_k),
            ('nms_threshold', nms_threshold),
            ('keep_top_k', keep_top_k),
            ('vis_thres', vis_thres),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)
def set_retinaface_conf(cpu_or_cuda='cpu'):
    """Return the default RetinaFace inference settings.

    Args:
        cpu_or_cuda: 'cpu' runs inference on the CPU; any other value selects
            CUDA. Defaults to 'cpu' so the function can be called without
            arguments.

    Returns:
        A ConfRetinaface with the MobileNet-0.25 weights from
        ./weights/mobilenet0.25_Final.pth and fixed detection thresholds.
    """
    return ConfRetinaface(trained_model='./weights/mobilenet0.25_Final.pth',
                          network='mobile0.25',
                          cpu=cpu_or_cuda == 'cpu',
                          confidence_threshold=0.02,
                          top_k=5000,
                          nms_threshold=0.4,
                          keep_top_k=750,
                          vis_thres=0.6)
def check_keys(model, pretrained_state_dict):
    """Report how checkpoint keys line up with a model's parameters.

    Prints three counts:
    * model keys missing from the checkpoint;
    * checkpoint keys the model does not use;
    * keys present in both.

    Args:
        model: object exposing ``state_dict()`` (e.g. a torch.nn.Module).
        pretrained_state_dict: mapping of parameter name -> value taken from
            the checkpoint.

    Returns:
        True when at least one key is shared.

    Raises:
        AssertionError: if no checkpoint key matches the model. Raised
            explicitly rather than via ``assert`` so the check still runs
            under ``python -O``.
    """
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    print('Missing keys:{}'.format(len(missing_keys)))
    print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
    print('Used keys:{}'.format(len(used_pretrained_keys)))
    if not used_pretrained_keys:
        raise AssertionError('load NONE from pretrained checkpoint')
    return True
def remove_prefix(state_dict, prefix):
    """Strip a leading ``prefix`` from every key of a state dict.

    Old-style checkpoints (e.g. saved from a DataParallel wrapper) store every
    parameter name with a common 'module.' prefix. Keys that do not start with
    ``prefix`` are kept as-is. Values and key order are preserved.
    """
    print('remove prefix \'{}\''.format(prefix))

    def strip(name):
        if not name.startswith(prefix):
            return name
        # The prefix sits at index 0, so everything after its first
        # occurrence is the un-prefixed name.
        return name.partition(prefix)[2]

    return {strip(name): tensor for name, tensor in state_dict.items()}
def load_model(model, pretrained_path, load_to_cpu):
    """Load checkpoint weights from ``pretrained_path`` into ``model``.

    Steps:
    1. Map the checkpoint's tensors to the CPU, or to the current CUDA device
       when ``load_to_cpu`` is false.
    2. Unwrap a top-level "state_dict" entry if there is one.
    3. Strip any 'module.' key prefix (remove_prefix).
    4. Report key coverage (check_keys).
    5. Load the weights non-strictly.

    Note: torch.load unpickles the file, so only load trusted checkpoints.

    Returns:
        The same ``model`` object, now holding the loaded weights.
    """
    print('Loading pretrained model from {}'.format(pretrained_path))
    if load_to_cpu:
        def relocate(storage, loc):
            return storage
    else:
        device = torch.cuda.current_device()

        def relocate(storage, loc):
            return storage.cuda(device)
    checkpoint = torch.load(pretrained_path, map_location=relocate)
    if "state_dict" in checkpoint.keys():
        checkpoint = checkpoint['state_dict']
    checkpoint = remove_prefix(checkpoint, 'module.')
    check_keys(model, checkpoint)
    model.load_state_dict(checkpoint, strict=False)
    return model
# 加载retinaface模型
def load_retinaface_model(args):
    """Build a RetinaFace detector in inference mode and load its weights.

    Args:
        args: settings object (see ConfRetinaface). Reads ``network``,
            ``trained_model`` and ``cpu``.

    Returns:
        The network in eval mode, moved to the CPU or to CUDA.

    Raises:
        ValueError: if ``args.network`` is not "mobile0.25" or "resnet50".
    """
    # Inference only: turn autograd off (for the current thread) so later
    # forward passes do not record graphs.
    torch.set_grad_enabled(False)
    configs = {"mobile0.25": cfg_mnet, "resnet50": cfg_re50}
    if args.network not in configs:
        raise ValueError("unsupported network: {!r}".format(args.network))
    cfg = configs[args.network]
    # net and model
    net = RetinaFace(cfg=cfg, phase='test')
    net = load_model(net, args.trained_model, args.cpu)
    net.eval()
    cudnn.benchmark = True
    device = torch.device("cpu" if args.cpu else "cuda")
    net = net.to(device)
    print('Finished loading model!')
    return net
# 计算两个特征向量的欧式距离
def findEuclideanDistance(source_representation, test_representation):
    """Return the Euclidean (L2) distance between two feature vectors.

    Computes sqrt(sum((a - b) ** 2)) over all elements, so inputs of matching
    (or broadcastable) shape reduce to a single scalar.
    """
    delta = source_representation - test_representation
    return np.sqrt(np.sum(np.multiply(delta, delta)))
# L2-normalise a vector (divide it by its Euclidean norm)
def l2_normalize(x):
    """Return ``x`` divided by its Euclidean (L2) norm.

    The norm is taken over all elements of ``x``. A zero vector yields NaNs
    because of the 0/0 division.
    """
    length = np.sqrt(np.sum(np.multiply(x, x)))
    return x / length
# 计算此特征向量与人脸库中的哪个人脸特征向量距离最近
def findmindistance(pred, threshold, k_v):
    """Return the name in ``k_v`` whose embedding is closest to ``pred``.

    Args:
        pred: query embedding.
        threshold: a match is accepted only if its Euclidean distance is
            strictly below this value.
        k_v: face database mapping name -> embedding.

    Returns:
        The nearest name when its distance is below ``threshold``, otherwise
        "unknown" (also for an empty database). Ties keep the first name in
        iteration order.
    """
    # An infinite sentinel, unlike the former magic value 10, never turns a
    # missing match into the empty name "" when threshold exceeds it.
    best_name = "unknown"
    best_distance = float("inf")
    for name, embedding in k_v.items():
        distance = findEuclideanDistance(embedding, pred)
        # Strict '<' keeps the first of equally distant entries; NaN
        # distances never compare smaller and are ignored.
        if distance < best_distance:
            best_distance = distance
            best_name = name
    if best_distance < threshold:
        return best_name
    return "unknown"
# Nearest-neighbour name lookup through a faiss-style index (currently unused by findAll)
def faiss_find_face(pred, index, database_name_list):
    """Name each query embedding using its nearest neighbour in a faiss index.

    Args:
        pred: query embeddings, one row per face, passed to ``index.search``.
        index: object with a faiss-style ``search(queries, k)`` returning
            ``(distances, ids)`` arrays with one row per query.
        database_name_list: names aligned with the index's vector ids.

    Returns:
        One name per query row: the nearest database name when its distance
        is below the module-level ``threshold``, otherwise "unknown".

    NOTE(review): a faiss IndexFlatL2 reports *squared* L2 distances, while
    findmindistance compares plain L2 distances against the same
    ``threshold``. Confirm which index type is used before enabling this path.
    """
    start_time = time.time()
    # k=1: only the single nearest neighbour per query is needed.
    D, I = index.search(pred, 1)
    end_time = time.time()
    print("faiss cost %fs" % (end_time - start_time))
    print(D, I)
    name_list = []
    for distances, ids in zip(D, I):
        if distances[0] < threshold:
            name_list.append(database_name_list[ids[0]])
        else:
            name_list.append("unknown")
    return name_list
# 从人脸库中找到传入的人脸列表中的所有人脸
# Leading run of CJK ideographs, ASCII letters and underscores in a database
# name. findAll keeps only this run, dropping everything from the first
# other character on (presumably digit suffixes that tell apart several
# photos of one person - confirm against how k_v is built).
_NAME_PREFIX_RE = re.compile(r'[\u4e00-\u9fa5_a-zA-Z]*')


def findAll(imglist, model, index, database_name_list, k_v, cpu_or_cuda):
    """Embed a batch of aligned face crops and name each one against ``k_v``.

    Args:
        imglist: float32 numpy batch of face crops, such as the ``face_list``
            produced by detect_one. An empty batch returns ``[]``.
        model: embedding network. Called on the batch; its output is converted
            to numpy with one embedding row per face.
        index, database_name_list: unused; kept for the disabled faiss lookup
            path (see faiss_find_face).
        k_v: face database mapping name -> embedding, used by findmindistance.
        cpu_or_cuda: "cuda" runs the model on CUDA, anything else on the CPU.

    Returns:
        One entry per face: "unknown", or the matched name cut down to its
        leading letters/ideographs/underscores.
    """
    if len(imglist) == 0:
        return []
    start_time = time.time()
    batch = torch.from_numpy(imglist)
    batch = batch.to(torch.device("cuda" if cpu_or_cuda == "cuda" else "cpu"))
    with torch.no_grad():
        pred = model(batch)
    pred = pred.cpu().numpy()
    print("predOne time: " + str(time.time() - start_time))

    start_time = time.time()
    name_list = []
    for embedding in pred:
        name = findmindistance(l2_normalize(embedding), threshold=threshold, k_v=k_v)
        print(name)
        if name != "unknown":
            # The pattern can match the empty string, so match() never fails.
            name = _NAME_PREFIX_RE.match(name).group(0)
        name_list.append(name)
    print("findOne time: " + str(time.time() - start_time))
    return name_list
# Detect all faces in one image file; return (N x 3 x 112 x 112 aligned crops, per-face detection rows)
def detect_one(path, net, args):
    """Detect every face in an image file and return aligned 112x112 crops.

    Images wider than 1000 px are first resized to width 600, keeping the
    aspect ratio. Each detection scoring at least ``args.vis_thres`` is
    aligned to a fixed 5-point landmark template with a similarity transform.

    Args:
        path: image file path, read with np.fromfile + cv2.imdecode (colour).
        net: RetinaFace detector (see load_retinaface_model).
        args: settings object (see ConfRetinaface).

    Returns:
        ``(face_list, box_and_point)``:

        * face_list -- float32 array of shape (N, 3, 112, 112) scaled to
          [-1, 1], or the empty list ``[]`` when no face passes
          ``args.vis_thres``.
        * box_and_point -- array of per-face rows ``x1, y1, x2, y2, score``
          followed by 5 landmark ``(x, y)`` pairs, in pixel coordinates of the
          (possibly resized) image. Empty when no face passes.

    Raises:
        ValueError: if ``args.network`` is unsupported or the file cannot be
            decoded as an image.
    """
    configs = {"mobile0.25": cfg_mnet, "resnet50": cfg_re50}
    if args.network not in configs:
        raise ValueError("unsupported network: {!r}".format(args.network))
    cfg = configs[args.network]
    device = torch.device("cpu" if args.cpu else "cuda")

    frame = cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_COLOR)
    if frame is None:
        raise ValueError("could not decode image: {}".format(path))
    h, w = frame.shape[:2]
    factor = h / w
    if w > 1000:
        frame = cv2.resize(frame, (600, int(600 * factor)))
        h, w = frame.shape[:2]

    # Subtract the per-channel means (104, 117, 123) used by the RetinaFace
    # preprocessing, then convert HWC -> 1xCxHxW.
    img = np.float32(frame)
    img -= (104, 117, 123)
    img = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
    loc, conf, landms = net(img)  # forward pass

    prior_data = PriorBox(cfg, image_size=(h, w)).forward().to(device).data
    # Decoded boxes/landmarks are relative; scale them to pixel coordinates.
    scale = torch.Tensor([w, h, w, h]).to(device)
    scale1 = torch.Tensor([w, h] * 5).to(device)
    boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
    boxes = (boxes * scale).cpu().numpy()
    scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
    landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
    landms = (landms * scale1).cpu().numpy()

    # Drop low scores, keep the top_k best, run NMS, then keep keep_top_k.
    inds = np.where(scores > args.confidence_threshold)[0]
    boxes, landms, scores = boxes[inds], landms[inds], scores[inds]
    order = scores.argsort()[::-1][:args.top_k]
    boxes, landms, scores = boxes[order], landms[order], scores[order]
    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
    keep = py_cpu_nms(dets, args.nms_threshold)
    dets = dets[keep, :][:args.keep_top_k, :]
    landms = landms[keep][:args.keep_top_k, :]
    dets = np.concatenate((dets, landms), axis=1)

    # Target positions of the 5 landmarks inside a 112x112 crop.
    template = np.array([
        [38.3814, 51.6963],
        [73.6186, 51.5014],
        [56.1120, 71.7366],
        [41.6361, 92.3655],
        [70.8167, 92.2041]], dtype=np.float32)
    tform = trans.SimilarityTransform()
    face_list = []
    box_and_point = []
    print(len(dets))
    for det, points in zip(dets, landms):
        if det[4] < args.vis_thres:
            continue
        box_and_point.append(det)
        tform.estimate(np.reshape(points, (5, 2)), template)
        M = tform.params[0:2, :]
        # Warping straight into a 112x112 canvas gives the same pixels as
        # warping the whole frame and cropping its top-left 112x112 corner.
        # It is also cheaper, and still 112x112 for frames smaller than that.
        face_list.append(cv2.warpAffine(frame, M, (112, 112), borderValue=0.0))

    if len(face_list) > 0:
        face_list = np.array(face_list).transpose((0, 3, 1, 2))
        face_list = np.array(face_list, dtype=np.float32)
        # Map uint8 [0, 255] to [-1, 1].
        face_list -= 127.5
        face_list /= 127.5
    box_and_point = np.array(box_and_point)
    return face_list, box_and_point
# 检测视频中的人脸并人脸识别
# detect_video labels at most this many faces per analysed frame: the first
# rows of the score-sorted NMS output.
_MAX_FACES_PER_FRAME = 4


def _recognize_video_frame(frame, net, arcface_model, k_v, args, cfg, device,
                           prior_data, scale, scale1, tform, template):
    """Detect faces in one BGR frame and name them with findAll.

    Returns:
        A list of ``(box, name)`` pairs, where ``box`` is ``[x1, y1, x2, y2]``
        in frame pixels. There is one pair for each of the first
        ``_MAX_FACES_PER_FRAME`` detections scoring at least
        ``args.vis_thres``. Empty when none qualifies.
    """
    tic = time.time()
    # Mean subtraction and HWC -> 1xCxHxW, as expected by RetinaFace.
    img = np.float32(frame)
    img -= (104, 117, 123)
    img = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
    loc, conf, landms = net(img)  # forward pass
    boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
    boxes = (boxes * scale).cpu().numpy()
    scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
    landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
    landms = (landms * scale1).cpu().numpy()

    # Drop low scores, keep the top_k best, run NMS, then keep keep_top_k.
    inds = np.where(scores > args.confidence_threshold)[0]
    boxes, landms, scores = boxes[inds], landms[inds], scores[inds]
    order = scores.argsort()[::-1][:args.top_k]
    boxes, landms, scores = boxes[order], landms[order], scores[order]
    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
    keep = py_cpu_nms(dets, args.nms_threshold)
    dets = dets[keep, :][:args.keep_top_k, :]
    landms = landms[keep][:args.keep_top_k, :]
    print('net forward time: {:.4f}'.format(time.time() - tic))

    kept_boxes = []
    crops = []
    for det, points in zip(dets[:_MAX_FACES_PER_FRAME], landms):
        if det[4] < args.vis_thres:
            continue
        tform.estimate(np.reshape(points, (5, 2)), template)
        # Warping straight into a 112x112 canvas gives the same pixels as
        # warping the whole frame and cropping its top-left 112x112 corner.
        crops.append(cv2.warpAffine(frame, tform.params[0:2, :], (112, 112),
                                    borderValue=0.0))
        kept_boxes.append(det[:4])
    if not crops:
        return []

    batch = np.array(np.array(crops).transpose((0, 3, 1, 2)), dtype=np.float32)
    # Map uint8 [0, 255] to [-1, 1].
    batch -= 127.5
    batch /= 127.5
    names = findAll(batch, arcface_model, index=None, database_name_list=None,
                    k_v=k_v, cpu_or_cuda="cpu" if args.cpu else "cuda")
    # Pair each name with its own box; no index bookkeeping across skipped
    # detections is needed.
    return list(zip(kept_boxes, names))


def _draw_labeled_faces(frame, faces, font):
    """Return a copy of BGR ``frame`` with a green box and red name per face.

    Drawing goes through PIL so that non-ASCII (e.g. CJK) names render with
    ``font``.
    """
    img_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(img_pil)
    for box, name in faces:
        # np.unicode no longer exists in recent NumPy; only bytes need decoding.
        if isinstance(name, bytes):
            name = name.decode('utf8')
        x1, y1, x2, y2 = (int(v) for v in box)
        draw.text((x1, y1), name, fill=(255, 0, 0), font=font)
        draw.rectangle((x1, y1, x2, y2), outline="green", width=3)
    return cv2.cvtColor(np.asarray(img_pil), cv2.COLOR_RGB2BGR)


def detect_video(video_path, output_path, net, arcface_model, k_v, args):
    """Detect and recognise faces in a video and write an annotated copy.

    Processing details:

    * Frames wider than ``ppi`` pixels are resized to width ``ppi``, keeping
      the aspect ratio.
    * Detection and recognition run on the first frame and then once every
      ``step + 1`` frames; frames in between reuse the latest results.
    * Every written frame gets a green box and a red name for each recognised
      face (at most ``_MAX_FACES_PER_FRAME``).

    Args:
        video_path: input video, opened with cv2.VideoCapture.
        output_path: output file, written with the XVID codec at the input FPS.
        net: RetinaFace detector (see load_retinaface_model).
        arcface_model: embedding model passed to findAll.
        k_v: face database mapping name -> embedding, passed to findAll.
        args: settings object (see ConfRetinaface).

    Raises:
        ValueError: if ``args.network`` is unsupported or no frame can be read.
    """
    tic_total = time.time()
    configs = {"mobile0.25": cfg_mnet, "resnet50": cfg_re50}
    if args.network not in configs:
        raise ValueError("unsupported network: {!r}".format(args.network))
    cfg = configs[args.network]
    device = torch.device("cpu" if args.cpu else "cuda")
    # Target positions of the 5 landmarks inside a 112x112 crop.
    template = np.array([
        [38.3814, 51.6963],
        [73.6186, 51.5014],
        [56.1120, 71.7366],
        [41.6361, 92.3655],
        [70.8167, 92.2041]], dtype=np.float32)
    tform = trans.SimilarityTransform()

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        ret, frame = cap.read()
        if not ret:
            raise ValueError("could not read a frame from {}".format(video_path))
        h, w = frame.shape[:2]
        factor = 0
        if w > ppi:
            factor = h / w
            frame = cv2.resize(frame, (ppi, int(ppi * factor)))
            h, w = frame.shape[:2]
        fps = cap.get(cv2.CAP_PROP_FPS)
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'XVID'), fps, (w, h))
        font = ImageFont.truetype("font.ttf", 22)
        # All frames share one size, so priors and scale factors are built once.
        prior_data = PriorBox(cfg, image_size=(h, w)).forward().to(device).data
        scale = torch.Tensor([w, h, w, h]).to(device)
        scale1 = torch.Tensor([w, h] * 5).to(device)

        number = step  # start at step so the very first frame is analysed
        faces = []
        while ret:
            tic_all = time.time()
            if number == step:
                faces = _recognize_video_frame(frame, net, arcface_model, k_v, args,
                                               cfg, device, prior_data, scale, scale1,
                                               tform, template)
                number = 0
            else:
                number += 1
            if faces:
                frame = _draw_labeled_faces(frame, faces, font)
            out.write(frame)
            ret, frame = cap.read()
            if ret and factor != 0:
                frame = cv2.resize(frame, (ppi, int(ppi * factor)))
            print('all time: {:.4f}'.format(time.time() - tic_all))
    finally:
        cap.release()
        if out is not None:
            out.release()
    print('total time: {:.4f}'.format(time.time() - tic_total))
if __name__ == "__main__":
    # Build the default settings explicitly for the CPU and show the flag.
    args = set_retinaface_conf('cpu')
    print(args.cpu)