# Added centre-point drawing and fixed the head-box flickering problem; the script can
# still only track a single person, so solving the multi-target case is the next step.

import os
import cv2
import numpy as np
from collections import deque
from ultralytics import YOLO

# Path to the trained YOLOv8 weights
model_path = r'detect\train\weights\best.pt'
model = YOLO(model_path)

# Parameters for the stillness detection
tracking_window_size = 250   # number of frames corresponding to 10 s of video
center_history = deque(maxlen=tracking_window_size)
tracker = None               # OpenCV tracker instance
tracking_initialized = False

# Tolerance radius (in pixels) for the head centre point
tolerance_radius = 20        # adjust to the actual scene as needed
detection_interval = 10      # run YOLO detection every N frames
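
# Note (assumption): 250 frames corresponds to 10 s only at 25 fps. If the stream rate
# differs, the window could instead be derived from the opened capture, for example:
#
#     fps = cap.get(cv2.CAP_PROP_FPS) or 25             # fall back to 25 fps if unknown
#     tracking_window_size = int(fps * 10)              # roughly 10 s worth of frames
#     center_history = deque(maxlen=tracking_window_size)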

def infer_and_draw_video(video_path, output_folder):
    # Open the video stream (file path or RTSP URL)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: could not open the video stream.")
        return

    # Read the stream properties
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Create the video writer
    output_video_path = os.path.join(output_folder, 'output_video.mp4')
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # mp4v codec
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

    global tracking_initialized
    global center_history
    global tracker

    frame_counter = 0          # frame counter
    last_detection_boxes = []  # boxes from the most recent detection pass

    while True:
        ret, frame = cap.read()
        if not ret:
            print("End of video stream or read error.")
            break

        frame_counter += 1

        filtered_boxes = []  # make sure this exists before the detection branch

        if frame_counter % detection_interval == 0:  # run detection every N frames
            # Run YOLOv8 detection on the current frame
            results = model(frame)
            detected_boxes = []
            if results:
                for result in results:
                    if result.boxes is not None and len(result.boxes.xyxy) > 0:
                        boxes = result.boxes.xyxy.cpu().numpy()
                        confidences = result.boxes.conf.cpu().numpy()

                        for i, box in enumerate(boxes):
                            x1, y1, x2, y2 = map(int, box[:4])
                            conf = confidences[i] if len(confidences) > i else 0.0
                            detected_boxes.append((x1, y1, x2, y2, conf))

            filtered_boxes = filter_and_merge_boxes(detected_boxes)

            if filtered_boxes:
                # Track the detection with the highest confidence
                filtered_boxes = sorted(filtered_boxes, key=lambda b: b[4], reverse=True)
                x1, y1, x2, y2, _ = filtered_boxes[0]
                tracking_bbox = (x1, y1, x2 - x1, y2 - y1)
                tracker = cv2.TrackerCSRT_create()
                tracker.init(frame, tracking_bbox)
                tracking_initialized = True
                last_detection_boxes = filtered_boxes
            else:
                # No detections: reset the tracker
                tracking_initialized = False
        elif tracking_initialized:
            # Update the tracker on frames between detections
            success, bbox = tracker.update(frame)
            if success:
                x, y, w, h = map(int, bbox)
                center = (x + w // 2, y + h // 2)
                center_history.append(center)
                cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
                cv2.circle(frame, center, 5, (255, 0, 0), -1)

                # Check whether the centre point has stayed within the tolerance radius
                if len(center_history) == tracking_window_size:
                    initial_center = center_history[0]
                    stable = all(np.linalg.norm(np.array(center) - np.array(initial_center)) <= tolerance_radius
                                 for center in center_history)
                    if stable:
                        cv2.putText(frame, "SLEEP", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
            else:
                # Tracking failed: reset the tracker
                tracking_initialized = False

        # If tracking is not active but recent detections exist, draw those instead
        if not tracking_initialized and last_detection_boxes:
            for box in last_detection_boxes:
                x1, y1, x2, y2, _ = box
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                center = (x1 + (x2 - x1) // 2, y1 + (y2 - y1) // 2)
                cv2.circle(frame, center, 5, (255, 0, 0), -1)

        # Write the annotated frame
        out.write(frame)

    cap.release()
    out.release()
    cv2.destroyAllWindows()
    print(f"Saved the annotated video to: {output_video_path}")

def filter_and_merge_boxes(boxes):
    filtered_boxes = []
    threshold = 0.5  # IoU threshold

    def iou(box1, box2):
        x1, y1, x2, y2 = box1
        x1_, y1_, x2_, y2_ = box2
        ix1, iy1 = max(x1, x1_), max(y1, y1_)
        ix2, iy2 = min(x2, x2_), min(y2, y2_)
        iw = max(ix2 - ix1 + 1, 0)
        ih = max(iy2 - iy1 + 1, 0)
        inter = iw * ih
        # Union = area(box1) + area(box2) - intersection
        ua = (x2 - x1 + 1) * (y2 - y1 + 1) + (x2_ - x1_ + 1) * (y2_ - y1_ + 1) - inter
        return inter / ua

    # Keep a box only if it does not overlap an already-kept box beyond the threshold
    for box1 in boxes:
        keep = True
        for box2 in filtered_boxes:
            if iou(box1[:4], box2[:4]) > threshold:
                keep = False
                break
        if keep:
            filtered_boxes.append(box1)

    return filtered_boxes
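
# A quick illustration of filter_and_merge_boxes (hypothetical values): two heavily
# overlapping detections collapse to the one seen first, a distant box is kept.
#
#     filter_and_merge_boxes([(10, 10, 50, 50, 0.9),
#                             (12, 12, 52, 52, 0.8),       # IoU > 0.5 with the first -> dropped
#                             (200, 200, 240, 240, 0.7)])
#     -> [(10, 10, 50, 50, 0.9), (200, 200, 240, 240, 0.7)]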

# Run inference on a real video file and write the annotated result to the output folder
infer_and_draw_video(r'video path', r'output folder')
# infer_and_draw_video(r'camera RTSP stream', r'output folder')