Sleeping-post-detection-fir.../fps04.py

126 lines
4.8 KiB
Python
Raw Permalink Normal View History

2024-09-08 14:08:46 +08:00
#优化了多人头代码,目前可以持续运作了
import os
import cv2
import numpy as np
from collections import deque
from fps02 import filter_and_merge_boxes
from fps03 import tolerance_radius
from ultralytics import YOLO
from concurrent.futures import ThreadPoolExecutor
# --- YOLOv8 model -----------------------------------------------------------
# Trained weights, loaded once when this module is imported.
# NOTE(review): "etect\..." looks like it may be a truncated "detect\..." (the
# usual runs\detect\train\weights layout) -- confirm the path.
model_path: str = r'etect\train\weights\best.pt'
model = YOLO(model_path)

# --- Dynamic detection parameters -------------------------------------------
tracking_window_size: int = 250  # consecutive centres that must stay within tolerance to flag "SLEEP"
center_history: deque = deque(maxlen=tracking_window_size)  # bounded window of head centres
tracking_initialized: bool = False  # module-level "trackers are live" flag
tracker_list: list = []  # module-level list for OpenCV tracker objects
detection_interval: int = 10  # run the YOLO model on every N-th frame
new_width: int = 640  # target frame width in pixels
class _Track:
    """One tracked head: its KCF tracker, latest box and recent centre history."""

    __slots__ = ("tracker", "box", "history", "lost")

    def __init__(self, tracker, box, history):
        self.tracker = tracker  # OpenCV KCF tracker seeded on a detection box
        self.box = box  # latest (x, y, w, h) in resized-frame pixels
        self.history = history  # deque of recent (cx, cy) centres, maxlen=tracking_window_size
        self.lost = False  # set when tracker.update() reports failure


def _distance(a, b):
    """Euclidean distance between two (x, y) points."""
    return float(np.hypot(a[0] - b[0], a[1] - b[1]))


def _detect_boxes(frame):
    """Run ``model`` on *frame*; return boxes as ``(x1, y1, x2, y2, conf)`` after merging.

    The filtering/merging rules live in ``filter_and_merge_boxes`` (fps02).
    """
    detected = []
    for result in model(frame) or []:
        if result.boxes is None or len(result.boxes.xyxy) == 0:
            continue
        boxes = result.boxes.xyxy.cpu().numpy()
        confidences = result.boxes.conf.cpu().numpy()
        for i, box in enumerate(boxes):
            x1, y1, x2, y2 = map(int, box[:4])
            conf = confidences[i] if len(confidences) > i else 0.0
            detected.append((x1, y1, x2, y2, conf))
    return filter_and_merge_boxes(detected)


def _start_tracks(frame, boxes, previous):
    """Seed one fresh KCF tracker per detected box, carrying centre history over.

    Each new box inherits the history of the nearest previous track (not yet
    claimed by another box) whose last centre lies within ``tolerance_radius``
    pixels of the box centre. Otherwise it starts an empty history. This keeps
    one history per person instead of interleaving everyone's centres.
    """
    unclaimed = list(previous)
    tracks = []
    for box in boxes:
        x1, y1, x2, y2 = (int(v) for v in box[:4])
        center = (x1 + (x2 - x1) // 2, y1 + (y2 - y1) // 2)
        bbox = (x1, y1, x2 - x1, y2 - y1)
        tracker = cv2.TrackerKCF_create()
        tracker.init(frame, bbox)

        nearest = min(unclaimed, key=lambda t: _distance(t.history[-1], center), default=None)
        if nearest is not None and _distance(nearest.history[-1], center) <= tolerance_radius:
            unclaimed.remove(nearest)
            history = nearest.history
        else:
            history = deque(maxlen=tracking_window_size)
        history.append(center)
        tracks.append(_Track(tracker, bbox, history))
    return tracks


def _update_tracks(frame, tracks):
    """Advance every live tracker by one frame and record its new centre.

    A tracker that fails is marked lost rather than deleted. Its history then
    stays available for re-association at the next detection frame.
    """
    for track in tracks:
        if track.lost:
            continue
        success, bbox = track.tracker.update(frame)
        if not success:
            track.lost = True
            continue
        x, y, w, h = map(int, bbox)
        track.box = (x, y, w, h)
        track.history.append((x + w // 2, y + h // 2))


def _is_stationary(history):
    """Return True when *history* is full and every centre is within tolerance of the oldest."""
    if len(history) < tracking_window_size:
        return False
    origin = history[0]
    return all(_distance(c, origin) <= tolerance_radius for c in history)


def _draw_track(frame, track):
    """Draw the box and centre of *track* on *frame*, plus "SLEEP" when it has not moved."""
    x, y, w, h = track.box
    cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
    cv2.circle(frame, track.history[-1], 5, (255, 0, 0), -1)
    if _is_stationary(track.history):
        cv2.putText(frame, "SLEEP", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)


def infer_and_draw_video(video_path, output_folder):
    """Detect stationary ("sleeping") heads in a video and save an annotated copy.

    Every ``detection_interval``-th frame (starting with the first) is passed
    to ``model``. Each box returned by ``_detect_boxes`` seeds a KCF tracker.
    On the frames in between, the trackers are updated. Each person keeps a
    history of at most ``tracking_window_size`` centres. When a full history
    stays within ``tolerance_radius`` pixels of its oldest centre, "SLEEP" is
    drawn above that box.

    Frames are resized to ``new_width`` pixels wide, keeping the aspect ratio,
    and written to ``<output_folder>/output_video.mp4``. The folder is created
    if it does not exist.

    Args:
        video_path: Anything accepted by ``cv2.VideoCapture``, such as a file
            path or a stream URL.
        output_folder: Directory for the output video.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("错误:无法打开视频流。")
        return

    output_video_path = os.path.join(output_folder, 'output_video.mp4')
    out = None
    try:
        ok, frame = cap.read()
        if not ok:
            print("视频流读取结束或出错。")
            return

        # Derive the output size from a real frame. CAP_PROP_FRAME_WIDTH/HEIGHT
        # can be 0 for streams, and cap.set() is ignored for files/streams. Every
        # frame is resized explicitly because VideoWriter drops frames whose size
        # differs from the size it was opened with.
        src_h, src_w = frame.shape[:2]
        out_size = (new_width, max(1, int(round(new_width * src_h / src_w))))

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # streams often report 0 (or NaN); VideoWriter needs a positive rate
            fps = 25.0

        if output_folder:
            os.makedirs(output_folder, exist_ok=True)
        out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, out_size)
        if not out.isOpened():
            print("错误:无法创建输出视频文件。")
            return

        # Per-call state, so consecutive calls never share trackers or history.
        tracks = []
        frame_index = 0
        while ok:
            frame = cv2.resize(frame, out_size)
            if frame_index % detection_interval == 0:
                boxes = _detect_boxes(frame)
                print(f"检测到的框:{boxes}")
                tracks = _start_tracks(frame, boxes, tracks)
            else:
                _update_tracks(frame, tracks)
            for track in tracks:
                if not track.lost:
                    _draw_track(frame, track)
            out.write(frame)
            frame_index += 1
            ok, frame = cap.read()
        print("视频流读取结束或出错。")
    finally:
        # No window is ever opened, so cv2.destroyAllWindows() is not called
        # (it raises on headless OpenCV builds).
        cap.release()
        if out is not None:
            out.release()
    print(f"已保存带注释的视频到: {output_video_path}")
# Script entry point (guarded so importing this module does not start processing).
# Replace the placeholders: the first argument is the network stream URL (or a
# local video file path), the second is the output directory.
if __name__ == "__main__":
    infer_and_draw_video(r'网络串流', r'输出路径')