import argparse
import subprocess
import time
import cv2
import torch
import numpy as np
from skimage import transform as trans
from PIL import Image, ImageDraw, ImageFont
from data import cfg_mnet, cfg_re50
from face_api import load_arcface_model, load_npy
from layers.functions.prior_box import PriorBox
from retinaface_detect import set_retinaface_conf, load_retinaface_model, findAll
from utils.nms.py_cpu_nms import py_cpu_nms
from utils.box_utils import decode, decode_landm
import faiss

# Frames wider than this are downscaled to this width for display/streaming.
ppi = 1280
# Width of the (further downscaled) copy of each frame fed to the detector.
ppi2 = 1100
# After a detection pass, the same detections are redrawn on the next `step`
# frames before the detector runs again.
step = 3
# Only the first MAX_FACES detections (highest score first) are recognised/drawn.
MAX_FACES = 4

# Target positions of the five landmarks inside the 112x112 aligned face crop
# (presumably the standard ArcFace alignment template).
_ALIGN_TEMPLATE = np.array([
    [38.3814, 51.6963],
    [73.6186, 51.5014],
    [56.1120, 71.7366],
    [41.6361, 92.3655],
    [70.8167, 92.2041]], dtype=np.float32)
_ALIGN_SIZE = 112


def _select_cfg(network):
    """Return the RetinaFace backbone config for ``network``.

    Raises:
        ValueError: if ``network`` is neither "mobile0.25" nor "resnet50".
    """
    if network == "mobile0.25":
        return cfg_mnet
    if network == "resnet50":
        return cfg_re50
    raise ValueError("unsupported network: {!r}".format(network))


def _compute_resize_factors(frame):
    """Return (factor, factor2) height/width ratios derived from the first frame.

    ``factor`` is used to shrink frames to width ``ppi`` and ``factor2`` to
    shrink the display frame to width ``ppi2`` for detection; a value of 0
    means that stage is not resized.
    """
    h, w = frame.shape[:2]
    factor = 0
    if w > ppi:
        factor = h / w
        h, w = int(ppi * factor), ppi
    factor2 = 0
    if w > ppi2:
        factor2 = h / w
    return factor, factor2


def _prepare_frames(frame, factor, factor2):
    """Return (display_frame, detect_frame) for a raw capture frame.

    The detector frame is always derived from the (possibly resized) display
    frame, so its size matches the priors computed from the first frame.
    """
    if factor != 0:
        frame = cv2.resize(frame, (ppi, int(ppi * factor)))
    frame_detect = frame
    if factor2 != 0:
        frame_detect = cv2.resize(frame, (ppi2, int(ppi2 * factor2)))
    return frame, frame_detect


def _detect(net, frame_detect, prior_data, scale, scale1, cfg, args, device):
    """Run the detector on ``frame_detect`` and return post-NMS detections.

    Returns an array with one row per face: 4 box coordinates, the score,
    then 10 landmark coordinates, all in ``frame_detect`` pixel space and
    ordered as produced by the score sort / NMS below.
    """
    img = np.float32(frame_detect)
    img -= (104, 117, 123)
    img = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        loc, conf, landms = net(img)

    boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
    boxes = (boxes * scale).cpu().numpy()
    scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
    landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
    landms = (landms * scale1).cpu().numpy()

    # ignore low scores
    inds = np.where(scores > args.confidence_threshold)[0]
    boxes, landms, scores = boxes[inds], landms[inds], scores[inds]

    # keep top-K before NMS
    order = scores.argsort()[::-1][:args.top_k]
    boxes, landms, scores = boxes[order], landms[order], scores[order]

    # do NMS, then keep top-K after NMS
    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
    keep = py_cpu_nms(dets, args.nms_threshold)
    dets = dets[keep, :][:args.keep_top_k, :]
    landms = landms[keep][:args.keep_top_k, :]
    return np.concatenate((dets, landms), axis=1)


def _align_faces(frame, dets, vis_thres, arf, tform):
    """Warp each confident detection to an aligned 112x112 crop of ``frame``.

    Returns (batch, det_indices). ``batch`` is a float32 array of shape
    (N, 3, 112, 112) scaled by (x - 127.5) / 127.5, or None when no
    detection reaches ``vis_thres``; ``det_indices[k]`` is the row of
    ``dets`` that crop k came from.
    """
    crops = []
    det_indices = []
    for i, det in enumerate(dets[:MAX_FACES]):
        if det[4] < vis_thres:
            continue
        # Landmarks are in detector space; scale them to display-frame pixels.
        dst = np.reshape(det[5:15], (5, 2)) * arf
        tform.estimate(dst, _ALIGN_TEMPLATE)
        M = tform.params[0:2, :]
        # Warping straight to the crop size yields the same pixels as warping
        # the whole frame and keeping the top-left 112x112 corner.
        crops.append(cv2.warpAffine(frame, M, (_ALIGN_SIZE, _ALIGN_SIZE), borderValue=0.0))
        det_indices.append(i)
    if not crops:
        return None, det_indices
    batch = np.array(crops).transpose((0, 3, 1, 2)).astype(np.float32)
    batch -= 127.5
    batch /= 127.5
    return batch, det_indices


def _draw_faces(frame, dets, names, vis_thres, arf, font):
    """Draw a box and the recognised name for each confident detection.

    ``names`` maps a row index of ``dets`` to the name returned by findAll.
    Returns a new BGR frame; ``frame`` itself is not modified.
    """
    if len(dets) == 0:
        return frame
    img_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(img_pil)
    for i, det in enumerate(dets[:MAX_FACES]):
        if det[4] < vis_thres:
            continue
        x1, y1, x2, y2 = (int(v) for v in det[:4] * arf)
        name = names.get(i, "")
        if isinstance(name, (bytes, bytearray)):
            name = name.decode('utf8')
        draw.text((x1, y1), str(name), fill=(255, 0, 0), font=font)
        draw.rectangle((x1, y1, x2, y2), outline="green", width=3)
    return cv2.cvtColor(np.asarray(img_pil), cv2.COLOR_RGB2BGR)


def _open_output(out_rtsp, size_str):
    """Start an ffmpeg process that encodes raw BGR frames written to its stdin.

    An ``rtsp...`` target is pushed with the RTSP muxer; any other target is
    treated as an output file whose container ffmpeg picks from its extension.
    """
    command = ['ffmpeg',
               '-y', '-an',
               '-f', 'rawvideo',
               '-vcodec', 'rawvideo',
               '-pix_fmt', 'bgr24',
               '-s', size_str,
               '-r', "25",
               '-i', '-',
               '-c:v', 'libx265',
               '-b:v', '3000k',
               '-pix_fmt', 'yuv420p',
               '-preset', 'ultrafast']
    if out_rtsp.startswith("rtsp"):
        command += ['-f', 'rtsp', out_rtsp]
    else:
        command.append(out_rtsp)
    return subprocess.Popen(command, shell=False, stdin=subprocess.PIPE)


def _close_output(pipe, timeout=5):
    """Close ffmpeg's stdin so it can flush, then stop it if it hangs."""
    try:
        pipe.stdin.close()
    except BrokenPipeError:
        pass  # ffmpeg already exited; nothing left to flush
    try:
        pipe.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        pipe.terminate()


def detect_rtsp(rtsp, out_rtsp, net, arcface_model, index, database_name_list, k_v, args):
    """Detect, recognise and annotate faces in a video source and re-stream it.

    Every ``step + 1`` frames the detector is run and the confident faces are
    aligned and passed to ``findAll``; in between, the last detections and
    names are redrawn. Annotated frames are piped to ffmpeg.

    Args:
        rtsp: input URL or path, passed to ``cv2.VideoCapture``.
        out_rtsp: ffmpeg output target (an ``rtsp...`` URL or a file path).
        net: detector; ``net(img)`` must return ``(loc, conf, landms)``.
        arcface_model, index, database_name_list, k_v: forwarded to ``findAll``.
        args: config namespace; reads ``network``, ``cpu``,
            ``confidence_threshold``, ``top_k``, ``nms_threshold``,
            ``keep_top_k`` and ``vis_thres``.

    Raises:
        ValueError: if ``args.network`` is not a supported backbone.
        RuntimeError: if no frame can be read from ``rtsp``.
    """
    tic_total = time.time()
    cfg = _select_cfg(args.network)
    device_name = "cpu" if args.cpu else "cuda"
    device = torch.device(device_name)

    cap = cv2.VideoCapture(rtsp)
    pipe = None
    try:
        ret, frame = cap.read()
        if not ret or frame is None:
            raise RuntimeError("could not read a frame from {!r}".format(rtsp))

        # Resize ratios are fixed from the first frame and reused for all frames.
        factor, factor2 = _compute_resize_factors(frame)
        frame, frame_detect = _prepare_frames(frame, factor, factor2)
        h, w = frame.shape[:2]
        detect_h, detect_w = frame_detect.shape[:2]
        arf = w / detect_w  # detector pixel -> display pixel
        print(w, h)
        print(detect_w, detect_h)

        pipe = _open_output(out_rtsp, str(w) + 'x' + str(h))
        font = ImageFont.truetype("font.ttf", 22)

        priors = PriorBox(cfg, image_size=(detect_h, detect_w)).forward()
        prior_data = priors.to(device).data
        scale = torch.Tensor([detect_w, detect_h] * 2).to(device)
        scale1 = torch.Tensor([detect_w, detect_h] * 5).to(device)
        tform = trans.SimilarityTransform()

        number = step  # forces a detection pass on the first frame
        dets = []
        names = {}
        while ret:
            tic_all = time.time()
            if number == step:
                tic = time.time()
                dets = _detect(net, frame_detect, prior_data, scale, scale1, cfg, args, device)
                print('net forward time: {:.4f}'.format(time.time() - tic))

                start_time_findall = time.time()
                faces, det_indices = _align_faces(frame, dets, args.vis_thres, arf, tform)
                names = {}
                if faces is not None:
                    print(faces.shape)
                    print("warpALL time: " + str(time.time() - start_time_findall))
                    name_list = findAll(faces, arcface_model, index, database_name_list, k_v, device_name)
                    # Pair names with the detection rows they came from.
                    names = dict(zip(det_indices, name_list))
                number = 0
                timing_label = "drawOneframe time: "
            else:
                number += 1
                timing_label = "writeframe time: "

            start_time = time.time()
            frame = _draw_faces(frame, dets, names, args.vis_thres, arf, font)
            pipe.stdin.write(frame.tobytes())
            print(timing_label + str(time.time() - start_time))

            start_time = time.time()
            ret, frame = cap.read()
            if ret:
                frame, frame_detect = _prepare_frames(frame, factor, factor2)
            print("readframe time: " + str(time.time() - start_time))
            print('all time: {:.4f}'.format(time.time() - tic_all))
    finally:
        cap.release()
        if pipe is not None:
            _close_output(pipe)
    print('total time: {:.4f}'.format(time.time() - tic_total))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--rtsp", type=str, default="", dest="rtsp_path",
        help="input stream URL or video file"
    )
    parser.add_argument(
        "--out", type=str, default='rtsp://localhost:5001/test2', dest="out_rtsp",
        help="output RTSP URL, or a file path to record to"
    )
    args = parser.parse_args()

    cpu_or_cuda = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the face-recognition model on the same device as the detector.
    arcface_model = load_arcface_model("./model/backbone100.pth", cpu_or_cuda=cpu_or_cuda)
    # Load the face-detection model.
    retinaface_args = set_retinaface_conf(cpu_or_cuda=cpu_or_cuda)
    retinaface_model = load_retinaface_model(retinaface_args)

    k_v = load_npy("./Database/student.npy")
    database_name_list = list(k_v.keys())
    # faiss only accepts C-contiguous float32 matrices.
    vector_list = np.ascontiguousarray(list(k_v.values()), dtype=np.float32)
    print(vector_list.shape)

    # IVF training needs at least as many training vectors as clusters.
    nlist = max(1, min(10, len(vector_list)))
    quantizer = faiss.IndexFlatL2(512)  # coarse quantizer for the IVF index
    index = faiss.IndexIVFFlat(quantizer, 512, nlist, faiss.METRIC_L2)
    index.train(vector_list)
    index.add(vector_list)
    index.nprobe = 10

    detect_rtsp(args.rtsp_path, args.out_rtsp, retinaface_model, arcface_model,
                index, database_name_list, k_v, retinaface_args)