# Code-viewer metadata: Python source, 283 lines, 11 KiB.
import argparse
|
|
import subprocess
|
|
import time
|
|
import cv2
|
|
import torch
|
|
import numpy as np
|
|
from skimage import transform as trans
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
from data import cfg_mnet, cfg_re50
|
|
from face_api import load_arcface_model, load_npy
|
|
from layers.functions.prior_box import PriorBox
|
|
from retinaface_detect import set_retinaface_conf, load_retinaface_model, findAll
|
|
from utils.nms.py_cpu_nms import py_cpu_nms
|
|
from utils.box_utils import decode, decode_landm
|
|
import faiss
|
|
|
|
# Maximum width (pixels) of the frames that are annotated and streamed out;
# wider input is shrunk to this width with its aspect ratio preserved.
ppi = 1280
# Maximum width (pixels) of the copy of each frame handed to the face detector.
ppi2 = 1100
# Number of frames passed through between two detector runs: detection
# happens on one frame, then the next `step` frames reuse its results.
step = 3
|
|
|
|
# At most this many detections per frame are recognised and drawn.
_MAX_FACES = 4

# Target positions of the five facial landmarks inside the 112x112 aligned
# crop that is handed to the recognition model.
_ALIGNED_LANDMARKS = np.array([
    [38.3814, 51.6963],
    [73.6186, 51.5014],
    [56.1120, 71.7366],
    [41.6361, 92.3655],
    [70.8167, 92.2041]], dtype=np.float32)


def _scaled_size(frame, max_width):
    """Return the (width, height) to shrink *frame* to, or None if it already fits."""
    height, width = frame.shape[:2]
    if width <= max_width:
        return None
    return (max_width, int(max_width * (height / width)))


def _read_frame(cap, out_size, detect_size):
    """Read one frame and return ``(ok, output_frame, detector_frame)``."""
    ok, frame = cap.read()
    if not ok:
        return False, None, None
    if out_size is not None:
        frame = cv2.resize(frame, out_size)
    # The detector copy is always derived from the output-sized frame, the
    # same way the sizes were established on the first frame.
    frame_detect = frame
    if detect_size is not None:
        frame_detect = cv2.resize(frame, detect_size)
    return True, frame, frame_detect


def _detect_faces(net, frame_detect, cfg, prior_data, box_scale, landm_scale, args):
    """Run the detector on one frame and return an (N, 15) float32 array.

    Columns 0-3 are the box corners, column 4 the score and columns 5-14 the
    five landmark (x, y) pairs, all in detector-frame pixels.
    """
    img = np.float32(frame_detect)
    img -= (104, 117, 123)  # per-channel offsets subtracted before inference
    img = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0)
    img = img.to(prior_data.device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        loc, conf, landms = net(img)

    boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
    boxes = (boxes * box_scale).cpu().numpy()
    scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
    landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
    landms = (landms * landm_scale).cpu().numpy()

    # Ignore low scores.
    inds = np.where(scores > args.confidence_threshold)[0]
    boxes, landms, scores = boxes[inds], landms[inds], scores[inds]

    # Keep the top-K candidates, best first, before NMS.
    order = scores.argsort()[::-1][:args.top_k]
    boxes, landms, scores = boxes[order], landms[order], scores[order]

    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
    keep = py_cpu_nms(dets, args.nms_threshold)
    dets = dets[keep, :][:args.keep_top_k, :]
    landms = landms[keep][:args.keep_top_k, :]
    return np.concatenate((dets, landms), axis=1)


def _align_faces(frame, dets, arf, tform, vis_thres):
    """Cut an aligned 112x112 crop for every confident detection.

    Returns a float32 batch of shape (faces, 3, 112, 112) scaled to [-1, 1],
    or None when no detection reaches *vis_thres*.
    """
    crops = []
    for det in dets[:_MAX_FACES]:
        if det[4] < vis_thres:
            continue
        # Landmarks are in detector-frame pixels; `arf` maps them onto `frame`.
        landmarks = np.reshape(det[5:15], (5, 2)) * arf
        tform.estimate(landmarks, _ALIGNED_LANDMARKS)
        matrix = tform.params[0:2, :]
        # Warping straight into a 112x112 canvas yields the same pixels as
        # warping the whole frame and cropping its top-left corner.
        crops.append(cv2.warpAffine(frame, matrix, (112, 112), borderValue=0.0))
    if not crops:
        return None
    batch = np.array(crops).transpose((0, 3, 1, 2))
    batch = np.array(batch, dtype=np.float32)
    batch -= 127.5
    batch /= 127.5
    return batch


def _draw_labels(frame, dets, name_list, arf, font, vis_thres):
    """Return *frame* with a box and a name drawn for every confident detection."""
    visible = [det for det in dets[:_MAX_FACES] if det[4] >= vis_thres]
    if not visible:
        return frame
    # PIL is used for the text so that a TrueType font can be rendered.
    canvas = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(canvas)
    for face_idx, det in enumerate(visible):
        x1, y1, x2, y2 = (int(v) for v in det[:4] * arf)
        # Names are produced only for the confident detections, so they are
        # indexed by face position rather than by detection position.
        if face_idx < len(name_list):
            name = name_list[face_idx]
            if not isinstance(name, str):
                name = name.decode('utf8')
            draw.text((x1, y1), name, fill=(255, 0, 0), font=font)
        draw.rectangle((x1, y1, x2, y2), outline="green", width=3)
    return cv2.cvtColor(np.asarray(canvas), cv2.COLOR_RGB2BGR)


def _shutdown_encoder(pipe):
    """Close the encoder's stdin and make sure the process is gone."""
    try:
        pipe.stdin.close()
    except OSError:
        pass  # the encoder already exited; nothing is left to flush
    pipe.terminate()
    try:
        pipe.wait(timeout=5)
    except subprocess.TimeoutExpired:
        pipe.kill()
        pipe.wait()


def detect_rtsp(rtsp, out_rtsp, net, arcface_model, index, database_name_list, k_v, args):
    """Detect and label faces on a video stream and re-publish it over RTSP.

    Frames are read from *rtsp*, shrunk to at most ``ppi`` pixels wide and
    piped as raw BGR into an ``ffmpeg`` process that publishes them to
    *out_rtsp*. The detector runs on one frame and its boxes and names are
    redrawn on the following ``step`` frames.

    Args:
        rtsp: Source passed to ``cv2.VideoCapture``.
        out_rtsp: Destination URL; must start with ``"rtsp"``.
        net: Detector called as ``net(img)`` returning ``(loc, conf, landms)``.
        arcface_model, index, database_name_list, k_v: Forwarded unchanged to
            ``findAll`` (presumably the recognition model and the gallery;
            confirm against ``retinaface_detect``).
        args: Detector settings; the attributes read here are ``network``,
            ``cpu``, ``confidence_threshold``, ``top_k``, ``nms_threshold``,
            ``keep_top_k`` and ``vis_thres``.

    Raises:
        ValueError: If ``args.network`` is not a supported backbone or
            *out_rtsp* is not an rtsp URL.
        RuntimeError: If no frame can be read from *rtsp*.
    """
    tic_total = time.time()

    if args.network == "mobile0.25":
        cfg = cfg_mnet
    elif args.network == "resnet50":
        cfg = cfg_re50
    else:
        raise ValueError("unsupported network: {!r}".format(args.network))
    if not out_rtsp.startswith("rtsp"):
        raise ValueError("out_rtsp must be an rtsp URL, got {!r}".format(out_rtsp))

    device_name = "cpu" if args.cpu else "cuda"
    device = torch.device(device_name)

    cap = cv2.VideoCapture(rtsp)
    pipe = None
    try:
        ret, frame = cap.read()
        if not ret:
            raise RuntimeError("cannot read a frame from {!r}".format(rtsp))

        # Sizes are fixed from the first frame and reused for every later one.
        out_size = _scaled_size(frame, ppi)
        if out_size is not None:
            frame = cv2.resize(frame, out_size)
        h, w = frame.shape[:2]

        detect_size = _scaled_size(frame, ppi2)
        frame_detect = frame
        if detect_size is not None:
            frame_detect = cv2.resize(frame, detect_size)
        detect_h, detect_w = frame_detect.shape[:2]
        # Ratio that maps detector-frame coordinates back onto the output frame.
        arf = w / detect_w
        print(w, h)
        print(detect_w, detect_h)

        font = ImageFont.truetype("font.ttf", 22)
        priors = PriorBox(cfg, image_size=(detect_h, detect_w)).forward()
        prior_data = priors.to(device).data
        box_scale = torch.Tensor([detect_w, detect_h] * 2).to(device)
        landm_scale = torch.Tensor([detect_w, detect_h] * 5).to(device)
        tform = trans.SimilarityTransform()

        command = ['ffmpeg',
                   '-y', '-an',
                   '-f', 'rawvideo',
                   '-vcodec', 'rawvideo',
                   '-pix_fmt', 'bgr24',
                   '-s', '{}x{}'.format(w, h),
                   '-r', "25",
                   '-i', '-',
                   '-c:v', 'libx265',
                   '-b:v', '3000k',
                   '-pix_fmt', 'yuv420p',
                   '-preset', 'ultrafast',
                   '-f', 'rtsp',
                   out_rtsp]
        pipe = subprocess.Popen(command, shell=False, stdin=subprocess.PIPE)

        frames_since_detect = step  # forces detection on the very first frame
        dets = np.empty((0, 15), dtype=np.float32)
        name_list = []

        while ret:
            tic_all = time.time()
            detecting = frames_since_detect == step
            if detecting:
                tic = time.time()
                dets = _detect_faces(net, frame_detect, cfg, prior_data,
                                     box_scale, landm_scale, args)
                print('net forward time: {:.4f}'.format(time.time() - tic))

                tic_faces = time.time()
                name_list = []
                faces = _align_faces(frame, dets, arf, tform, args.vis_thres)
                if faces is not None:
                    print(faces.shape)
                    print("warpALL time: " + str(time.time() - tic_faces))
                    name_list = findAll(faces, arcface_model, index,
                                        database_name_list, k_v, device_name)
                frames_since_detect = 0
            else:
                frames_since_detect += 1

            tic_draw = time.time()
            frame = _draw_labels(frame, dets, name_list, arf, font, args.vis_thres)
            tic_write = time.time()
            try:
                pipe.stdin.write(frame.tobytes())
            except BrokenPipeError:
                print("ffmpeg closed its input; stopping")
                break
            if detecting:
                print("drawOneframe time: " + str(time.time() - tic_draw))
            else:
                print("writeframe time: " + str(time.time() - tic_write))

            tic_read = time.time()
            ret, frame, frame_detect = _read_frame(cap, out_size, detect_size)
            print("readframe time: " + str(time.time() - tic_read))
            print('all time: {:.4f}'.format(time.time() - tic_all))
    finally:
        cap.release()
        if pipe is not None:
            _shutdown_encoder(pipe)
    print('total time: {:.4f}'.format(time.time() - tic_total))
|
|
|
|
def main():
    """Load the models and the face gallery, then process one RTSP stream.

    The gallery in ``./Database/student.npy`` is loaded with ``load_npy`` as
    a mapping whose keys become the name list and whose values are indexed
    with faiss as 512-dimensional vectors. The annotated stream is published
    to ``rtsp://localhost:5001/test2``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--rtsp",
        type=str,
        default="",
        dest="rtsp_path",
        help="source stream URL (or video file) to read frames from",
    )
    args = parser.parse_args()
    if not args.rtsp_path:
        parser.error("--rtsp must be given")

    cpu_or_cuda = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the face-recognition model on the device that is actually
    # available instead of assuming CUDA.
    arcface_model = load_arcface_model("./model/backbone100.pth", cpu_or_cuda=cpu_or_cuda)
    # Load the face-detection model.
    retinaface_args = set_retinaface_conf(cpu_or_cuda=cpu_or_cuda)
    retinaface_model = load_retinaface_model(retinaface_args)

    k_v = load_npy("./Database/student.npy")
    database_name_list = list(k_v.keys())
    # faiss indexes operate on float32 vectors.
    vector_list = np.array(list(k_v.values()), dtype=np.float32)
    print(vector_list.shape)

    nlist = 10
    quantizer = faiss.IndexFlatL2(512)  # coarse quantizer for the IVF index
    index = faiss.IndexIVFFlat(quantizer, 512, nlist, faiss.METRIC_L2)
    index.train(vector_list)
    index.add(vector_list)
    # Probing all `nlist` cells makes every query search the whole gallery.
    index.nprobe = 10

    detect_rtsp(args.rtsp_path, 'rtsp://localhost:5001/test2', retinaface_model,
                arcface_model, index, database_name_list, k_v, retinaface_args)


if __name__ == "__main__":
    main()