# Face/realtime_detect.py

import argparse
import subprocess
import time

import cv2
import faiss
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from skimage import transform as trans

from data import cfg_mnet, cfg_re50
from face_api import load_arcface_model, load_npy
from layers.functions.prior_box import PriorBox
from retinaface_detect import set_retinaface_conf, load_retinaface_model, findAll
from utils.nms.py_cpu_nms import py_cpu_nms
from utils.box_utils import decode, decode_landm
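
# Width caps for the display frame (ppi) and the downscaled detection frame
# (ppi2); the detector runs on every (step + 1)-th frame.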
ppi = 1280
ppi2 = 1100
step = 3


def detect_rtsp(rtsp, out_rtsp, net, arcface_model, index, database_name_list, k_v, args):
    tic_total = time.time()
    cfg = None
    if args.network == "mobile0.25":
        cfg = cfg_mnet
    elif args.network == "resnet50":
        cfg = cfg_re50
    device = torch.device("cpu" if args.cpu else "cuda")
    resize = 1
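
    # Read one frame to learn the stream geometry, then derive a capped
    # display size (ppi) and a smaller detection size (ppi2); `arf` maps
    # detection-frame coordinates back to display-frame coordinates.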
    cap = cv2.VideoCapture(rtsp)
    ret, frame = cap.read()
    h, w = frame.shape[:2]
    factor = 0
    if w > ppi:
        factor = h / w
        frame = cv2.resize(frame, (ppi, int(ppi * factor)))
        h, w = frame.shape[:2]
    arf = 1
    detect_h, detect_w = frame.shape[:2]
    frame_detect = frame
    factor2 = 0
    if w > ppi2:
        factor2 = h / w
        frame_detect = cv2.resize(frame, (ppi2, int(ppi2 * factor2)))
        detect_h, detect_w = frame_detect.shape[:2]
        arf = w / detect_w
    print(w, h)
    print(detect_w, detect_h)
    fps = cap.get(cv2.CAP_PROP_FPS)
    # print(fps)
    size = (w, h)
    sizeStr = str(size[0]) + 'x' + str(size[1])
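    # Stream annotated frames to an ffmpeg child process, which re-encodes
    # them with libx265 and publishes to the output RTSP URL.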
    if out_rtsp.startswith("rtsp"):
        command = ['ffmpeg',
                   '-y', '-an',
                   '-f', 'rawvideo',
                   '-vcodec', 'rawvideo',
                   '-pix_fmt', 'bgr24',
                   '-s', sizeStr,
                   '-r', "25",
                   '-i', '-',
                   '-c:v', 'libx265',
                   '-b:v', '3000k',
                   '-pix_fmt', 'yuv420p',
                   '-preset', 'ultrafast',
                   '-f', 'rtsp',
                   out_rtsp]
        pipe = subprocess.Popen(command, shell=False, stdin=subprocess.PIPE)
    else:
        # The loop below writes every frame to `pipe`, so a non-RTSP output
        # target is not supported here.
        raise ValueError("out_rtsp must be an rtsp:// URL")
    # out = cv2.VideoWriter("output.avi", cv2.VideoWriter_fourcc(*'XVID'), fps, size)
    number = step
    dets = []
    name_list = []
    font = ImageFont.truetype("font.ttf", 22)
    priorbox = PriorBox(cfg, image_size=(detect_h, detect_w))
    priors = priorbox.forward()
    priors = priors.to(device)
    prior_data = priors.data
    scale = torch.Tensor([detect_w, detect_h, detect_w, detect_h])
    scale = scale.to(device)
    scale1 = torch.Tensor([detect_w, detect_h, detect_w, detect_h,
                           detect_w, detect_h, detect_w, detect_h,
                           detect_w, detect_h])
    scale1 = scale1.to(device)
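    # Reference landmark positions (eyes, nose tip, mouth corners) of the
    # standard ArcFace 112x112 aligned face crop.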
    src1 = np.array([
        [38.3814, 51.6963],
        [73.6186, 51.5014],
        [56.1120, 71.7366],
        [41.6361, 92.3655],
        [70.8167, 92.2041]], dtype=np.float32)
    tform = trans.SimilarityTransform()
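
    # Main loop: the detector runs only when `number` reaches `step`;
    # intermediate frames reuse the cached detections and names.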
    while ret:
        tic_all = time.time()
        if number == step:
            tic = time.time()
            # Preprocess the detection frame: subtract the BGR channel means
            # RetinaFace was trained with, then make it a 1x3xHxW tensor.
            img = np.float32(frame_detect)
            img -= (104, 117, 123)
            img = img.transpose(2, 0, 1)
            img = torch.from_numpy(img).unsqueeze(0)
            img = img.to(device)
            loc, conf, landms = net(img)  # forward pass
            boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
            boxes = boxes * scale / resize
            boxes = boxes.cpu().numpy()
            scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
            landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
            landms = landms * scale1 / resize
            landms = landms.cpu().numpy()
            # ignore low scores
            inds = np.where(scores > args.confidence_threshold)[0]
            boxes = boxes[inds]
            landms = landms[inds]
            scores = scores[inds]

            # keep top-K before NMS
            order = scores.argsort()[::-1][:args.top_k]
            boxes = boxes[order]
            landms = landms[order]
            scores = scores[order]

            # do NMS
            dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
            keep = py_cpu_nms(dets, args.nms_threshold)
            # keep = nms(dets, args.nms_threshold, force_cpu=args.cpu)
            dets = dets[keep, :]
            landms = landms[keep]

            # keep top-K after NMS
            dets = dets[:args.keep_top_k, :]
            landms = landms[:args.keep_top_k, :]
            dets = np.concatenate((dets, landms), axis=1)
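
            # Align each detected face: estimate a similarity transform from
            # its five landmarks to the template, warp, and crop to 112x112.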
            face_list = []
            name_list = []
            print('net forward time: {:.4f}'.format(time.time() - tic))
            start_time_findall = time.time()
            for i, det in enumerate(dets[:4]):
                if det[4] < args.vis_thres:
                    continue
                # boxes, score = det[:4], det[4]
                dst = np.reshape(landms[i], (5, 2))
                dst = dst * arf
                tform.estimate(dst, src1)
                M = tform.params[0:2, :]
                frame2 = cv2.warpAffine(frame, M, (w, h), borderValue=0.0)
                img112 = frame2[0:112, 0:112, :]
                face_list.append(img112)
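            # Scale the crops to [-1, 1], batch them, and query the FAISS
            # index for the closest database identity per face.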
            if len(face_list) != 0:
                face_list = np.array(face_list)
                face_list = face_list.transpose((0, 3, 1, 2))
                face_list = np.array(face_list, dtype=np.float32)
                face_list -= 127.5
                face_list /= 127.5
                print(face_list.shape)
                print("warpALL time: " + str(time.time() - start_time_findall))
                # start_time = time.time()
                name_list = findAll(face_list, arcface_model, index, database_name_list, k_v, "cpu" if args.cpu else "cuda")
                # print(name_list)
                # print("findOneframe time: " + str(time.time() - start_time_findall))
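            # Draw names and boxes with PIL so that non-ASCII names render
            # correctly with the TrueType font.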
            start_time = time.time()
            if len(dets) != 0:
                img_PIL = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                draw = ImageDraw.Draw(img_PIL)
                name_idx = 0  # name_list only holds entries for dets that passed vis_thres
                for det in dets[:4]:
                    if det[4] < args.vis_thres:
                        continue
                    boxes, score = det[:4], det[4]
                    boxes = boxes * arf
                    name = name_list[name_idx]
                    name_idx += 1
                    if not isinstance(name, str):
                        name = name.decode('utf8')
                    draw.text((int(boxes[0]), int(boxes[1])), name, fill=(255, 0, 0), font=font)
                    draw.rectangle((int(boxes[0]), int(boxes[1]), int(boxes[2]), int(boxes[3])), outline="green", width=3)
                frame = cv2.cvtColor(np.asarray(img_PIL), cv2.COLOR_RGB2BGR)
            pipe.stdin.write(frame.tobytes())
            # out.write(frame)
            print("drawOneframe time: " + str(time.time() - start_time))
            start_time = time.time()
            ret, frame = cap.read()
            frame_detect = frame
            number = 0
            if ret and factor != 0:
                frame = cv2.resize(frame, (ppi, int(ppi * factor)))
            if ret and factor2 != 0:
                frame_detect = cv2.resize(frame, (ppi2, int(ppi2 * factor2)))
            print("readframe time: " + str(time.time() - start_time))
        else:
            number += 1
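            # No detection on this frame: reuse the cached boxes and names
            # from the last detector run.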
            if len(dets) != 0:
                img_PIL = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                draw = ImageDraw.Draw(img_PIL)
                name_idx = 0
                for det in dets[:4]:
                    if det[4] < args.vis_thres:
                        continue
                    boxes, score = det[:4], det[4]
                    boxes = boxes * arf
                    name = name_list[name_idx]
                    name_idx += 1
                    if not isinstance(name, str):
                        name = name.decode('utf8')
                    draw.text((int(boxes[0]), int(boxes[1])), name, fill=(255, 0, 0), font=font)
                    draw.rectangle((int(boxes[0]), int(boxes[1]), int(boxes[2]), int(boxes[3])),
                                   outline="green", width=3)
                frame = cv2.cvtColor(np.asarray(img_PIL), cv2.COLOR_RGB2BGR)
            start_time = time.time()
            pipe.stdin.write(frame.tobytes())
            # out.write(frame)
            print("writeframe time: " + str(time.time() - start_time))
            start_time = time.time()
            ret, frame = cap.read()
            frame_detect = frame
            if ret and factor != 0:
                frame = cv2.resize(frame, (ppi, int(ppi * factor)))
            if ret and factor2 != 0:
                frame_detect = cv2.resize(frame, (ppi2, int(ppi2 * factor2)))
            print("readframe time: " + str(time.time() - start_time))
        print('all time: {:.4f}'.format(time.time() - tic_all))
    cap.release()
    # out.release()
    pipe.terminate()
    print('total time: {:.4f}'.format(time.time() - tic_total))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--rtsp",
        type=str,
        default="",
        dest="rtsp_path"
    )
    args = parser.parse_args()
    cpu_or_cuda = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the face recognition (ArcFace) model.
    arcface_model = load_arcface_model("./model/backbone100.pth", cpu_or_cuda=cpu_or_cuda)
    # Load the face detection (RetinaFace) model.
    retinaface_args = set_retinaface_conf(cpu_or_cuda=cpu_or_cuda)
    retinaface_model = load_retinaface_model(retinaface_args)
    k_v = load_npy("./Database/student.npy")
    # print(list(k_v.keys()))
    database_name_list = list(k_v.keys())
    vector_list = np.array(list(k_v.values()))
    print(vector_list.shape)
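    # Build an IVF index over the 512-d ArcFace embeddings: vectors are
    # clustered into `nlist` cells and `nprobe` cells are scanned per query.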
    nlist = 10
    quantizer = faiss.IndexFlatL2(512)  # coarse quantizer for the IVF index
    index = faiss.IndexIVFFlat(quantizer, 512, nlist, faiss.METRIC_L2)
    index.train(vector_list)
    # index = faiss.IndexFlatL2(512)
    index.add(vector_list)
    index.nprobe = 10  # nprobe == nlist here, so every cell is scanned
    detect_rtsp(args.rtsp_path, 'rtsp://localhost:5001/test2', retinaface_model, arcface_model, index, database_name_list, k_v, retinaface_args)
    # detect_rtsp("rtsp://admin:2020@uestc@192.168.14.32:8557/h264", 'rtsp://localhost:5001/test2', retinaface_model, arcface_model, index, database_name_list, k_v, retinaface_args)
    # detect_rtsp("cut.mp4", 'rtsp://localhost:5001/test2', retinaface_model, arcface_model, index, database_name_list, k_v, retinaface_args)