更新Demo

main
PEAandDA 2024-07-22 18:18:12 +08:00
parent 39da5d7df5
commit 7121d24a90
3 changed files with 89 additions and 33 deletions

View File

@ -1,58 +1,98 @@
from ultralytics import YOLO
import cv2
import torch
class FaceDetector:
def __init__(self, cascade_path='haarcascade_frontalface_default.xml'):
self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + cascade_path)
def __init__(self, model_path='../weights/yolov10s_face.pt'):
try:
self.model = YOLO(model_path)
except FileNotFoundError:
print("ERROR: Could not load the YOLO model")
exit()
self.class_names_dict = self.model.model.names
def find_faces(self, img, draw=True):
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
faces = self.face_cascade.detectMultiScale(gray, 1.3, 5)
def find_faces(self, img):
original_img = img.copy()
results = self.model(img, verbose=False)[0]
detections = results.boxes.data
face_detections = []
other_detections = []
bboxs = []
for detection in detections:
x1, y1, x2, y2, confidence, class_id = detection
x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
class_id = int(class_id)
# Convert to xywh format
w = x2 - x1
h = y2 - y1
bboxs.append([class_id, x1, y1, w, h, confidence])
if self.class_names_dict[class_id] == 'face':
face_detections.append((x1, y1, x2, y2, confidence))
else:
other_detections.append((x1, y1, x2, y2, class_id))
bboxs = []
for i, (x, y, w, h) in enumerate(faces):
bbox = (x, y, w, h)
bboxs.append([i, bbox, None]) # 'None' 用于保持与原代码结构一致因为OpenCV不提供置信度分数
if draw:
img = self.fancy_draw(img, bbox)
# Find the largest face
if face_detections:
largest_face = max(face_detections, key=lambda x: (x[2] - x[0]) * (x[3] - x[1]))
bboxs.append(largest_face)
original_img = self.draw_face(original_img, largest_face)
return img, bboxs
def fancy_draw(self, img, bbox, l=30, t=5, rt=1):
x, y, w, h = bbox
x1, y1 = x + w, y + h
# 绘制矩形
cv2.rectangle(img, bbox, (0, 255, 255), rt)
# Modify pixel values for other facial features
for x1, y1, x2, y2, _ in other_detections:
img[y1:y2, x1:x2] = 125
# 绘制角落
# 左上角
cv2.line(img, (x, y), (x+l, y), (255, 0, 255), t)
cv2.line(img, (x, y), (x, y+l), (255, 0, 255), t)
# 右上角
cv2.line(img, (x1, y), (x1-l, y), (255, 0, 255), t)
cv2.line(img, (x1, y), (x1, y+l), (255, 0, 255), t)
# 左下角
cv2.line(img, (x, y1), (x+l, y1), (255, 0, 255), t)
cv2.line(img, (x, y1), (x, y1-l), (255, 0, 255), t)
# 右下角
cv2.line(img, (x1, y1), (x1-l, y1), (255, 0, 255), t)
cv2.line(img, (x1, y1), (x1, y1-l), (255, 0, 255), t)
return original_img, img, bboxs
def draw_face(self, img, face, l=30, t=2, rt=1):
x1, y1, x2, y2, confidence = face
# Draw rectangle
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 255), rt)
# Draw label
# label = f"Face {confidence:.2f}"
# cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 2)
# Draw corners
# Top left
cv2.line(img, (x1, y1), (x1 + l, y1), (255, 0, 255), t)
cv2.line(img, (x1, y1), (x1, y1 + l), (255, 0, 255), t)
# Top right
cv2.line(img, (x2, y1), (x2 - l, y1), (255, 0, 255), t)
cv2.line(img, (x2, y1), (x2, y1 + l), (255, 0, 255), t)
# Bottom left
cv2.line(img, (x1, y2), (x1 + l, y2), (255, 0, 255), t)
cv2.line(img, (x1, y2), (x1, y2 - l), (255, 0, 255), t)
# Bottom right
cv2.line(img, (x2, y2), (x2 - l, y2), (255, 0, 255), t)
cv2.line(img, (x2, y2), (x2, y2 - l), (255, 0, 255), t)
return img
# 使用示例
if __name__ == "__main__":
print(torch.cuda.is_available())
detector = FaceDetector()
cap = cv2.VideoCapture(0) # 使用默认摄像头
while True:
success, img = cap.read()
img, bboxs = detector.find_faces(img)
success, frame = cap.read()
iii, img, bboxs = detector.find_faces(frame)
cv2.imshow("Image", img)
cv2.imshow("Image", iii)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
cv2.destroyAllWindows()

16
demo/requirements.txt Normal file
View File

@ -0,0 +1,16 @@
PyQt5==5.15.9
matplotlib==3.8.3
numpy==1.24.3
opencv-contrib-python==4.9.0.80
opencv-python==4.9.0.80
opencv-python-headless==4.9.0.80
pillow==10.3.0
tqdm==4.66.2
scikit-learn==1.4.1.post1
scipy==1.12.0
torch --index-url https://download.pytorch.org/whl/cu118
torchvision --index-url https://download.pytorch.org/whl/cu118
torchaudio --index-url https://download.pytorch.org/whl/cu118
ultralytics==8.2.60
PyWavelets==1.4.1
charset-normalizer==3.1.0

Binary file not shown.