152 lines
5.7 KiB
Python
152 lines
5.7 KiB
Python
#解决多人头问题,但是cpu处理速度跟不上,只能运行一小段时间
|
||
|
||
|
||
import os
|
||
import cv2
|
||
import numpy as np
|
||
from collections import deque
|
||
from ultralytics import YOLO
|
||
|
||
# YOLOv8模型路径
|
||
model_path = r'detect\train\weights\best.pt'
|
||
model = YOLO(model_path)
|
||
|
||
# 动态检测参数
|
||
tracking_window_size = 250 # 10秒对应的帧数
|
||
center_history = deque(maxlen=tracking_window_size)
|
||
tracking_initialized = False
|
||
tracker_list = [] # 跟踪器列表
|
||
|
||
# 定义头部中心点的容忍范围(像素)
|
||
tolerance_radius = 20 # 你可以根据实际需要调整
|
||
detection_interval = 10 # 目标检测的帧间隔
|
||
|
||
def infer_and_draw_video(video_path, output_folder):
|
||
global tracking_initialized
|
||
global center_history
|
||
global tracker_list
|
||
|
||
# 打开RTSP视频流
|
||
cap = cv2.VideoCapture(video_path)
|
||
if not cap.isOpened():
|
||
print("错误:无法打开视频流。")
|
||
return
|
||
|
||
# 获取视频属性
|
||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||
|
||
# 创建视频写入对象
|
||
output_video_path = os.path.join(output_folder, 'output_video.mp4')
|
||
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # 使用mp4v编解码器
|
||
out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
|
||
|
||
frame_counter = 0 # 帧计数器
|
||
last_detection_boxes = [] # 上一帧检测到的框
|
||
|
||
while True:
|
||
ret, frame = cap.read()
|
||
if not ret:
|
||
print("视频流读取结束或出错。")
|
||
break
|
||
|
||
frame_counter += 1
|
||
|
||
if frame_counter % detection_interval == 0: # 每一定帧数进行目标检测
|
||
# 使用YOLOv8进行目标检测
|
||
results = model(frame)
|
||
detected_boxes = []
|
||
if results:
|
||
for result in results:
|
||
if result.boxes is not None and len(result.boxes.xyxy) > 0:
|
||
boxes = result.boxes.xyxy.cpu().numpy()
|
||
confidences = result.boxes.conf.cpu().numpy()
|
||
|
||
for i, box in enumerate(boxes):
|
||
x1, y1, x2, y2 = map(int, box[:4])
|
||
conf = confidences[i] if len(confidences) > i else 0.0
|
||
detected_boxes.append((x1, y1, x2, y2, conf))
|
||
|
||
filtered_boxes = filter_and_merge_boxes(detected_boxes)
|
||
|
||
if filtered_boxes:
|
||
# 对每个检测到的目标初始化一个跟踪器
|
||
tracker_list = [cv2.TrackerCSRT_create() for _ in filtered_boxes]
|
||
for i, box in enumerate(filtered_boxes):
|
||
x1, y1, x2, y2, _ = box
|
||
tracking_bbox = (x1, y1, x2 - x1, y2 - y1)
|
||
tracker_list[i].init(frame, tracking_bbox)
|
||
last_detection_boxes = filtered_boxes
|
||
tracking_initialized = True
|
||
else:
|
||
# 如果没有检测到框,则重置跟踪器
|
||
tracking_initialized = False
|
||
elif tracking_initialized:
|
||
# 更新所有跟踪器
|
||
for tracker in tracker_list:
|
||
success, bbox = tracker.update(frame)
|
||
if success:
|
||
x, y, w, h = map(int, bbox)
|
||
center = (x + w // 2, y + h // 2)
|
||
center_history.append(center)
|
||
cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
|
||
cv2.circle(frame, center, 5, (255, 0, 0), -1)
|
||
|
||
# 检查中心点是否稳定在容忍范围内
|
||
if len(center_history) == tracking_window_size:
|
||
initial_center = center_history[0]
|
||
stable = all(np.linalg.norm(np.array(center) - np.array(initial_center)) <= tolerance_radius for center in center_history)
|
||
if stable:
|
||
cv2.putText(frame, "SLEEP", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
|
||
else:
|
||
# 跟踪失败,重置跟踪器
|
||
tracking_initialized = False
|
||
tracker_list = []
|
||
|
||
# 如果跟踪失败但有检测到的目标框,显示检测框
|
||
if not tracking_initialized and last_detection_boxes:
|
||
for box in last_detection_boxes:
|
||
x1, y1, x2, y2, _ = box
|
||
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||
center = (x1 + (x2 - x1) // 2, y1 + (y2 - y1) // 2)
|
||
cv2.circle(frame, center, 5, (255, 0, 0), -1)
|
||
|
||
# 写入处理后的帧
|
||
out.write(frame)
|
||
|
||
cap.release()
|
||
out.release()
|
||
cv2.destroyAllWindows()
|
||
print(f"已保存带注释的视频到: {output_video_path}")
|
||
|
||
def filter_and_merge_boxes(boxes):
|
||
filtered_boxes = []
|
||
threshold = 0.5 # IOU阈值
|
||
|
||
def iou(box1, box2):
|
||
x1, y1, x2, y2 = box1
|
||
x1_, y1_, x2_, y2_ = box2
|
||
ix1, iy1 = max(x1, x1_), max(y1, y1_)
|
||
ix2, iy2 = min(x2, x2_), min(y2, y2_)
|
||
iw = max(ix2 - ix1 + 1, 0)
|
||
ih = max(iy2 - iy1 + 1, 0)
|
||
inter = iw * ih
|
||
ua = (x2 - x1 + 1) * (y2 - y1 + 1) + (x2_ - x1_ + 1) * (y2_ - y1_) - inter
|
||
return inter / ua
|
||
|
||
for i, box1 in enumerate(boxes):
|
||
keep = True
|
||
for j, box2 in enumerate(filtered_boxes):
|
||
if iou(box1[:4], box2[:4]) > threshold:
|
||
keep = False
|
||
break
|
||
if keep:
|
||
filtered_boxes.append(box1)
|
||
|
||
return filtered_boxes
|
||
|
||
# 使用实际视频路径进行推理,并指定输出文件夹
|
||
infer_and_draw_video(r'视频路径', r'输出路径')
|
||
#infer_and_draw_video(r'摄像头网络串流', r'输出路径')
|