89 lines
3.0 KiB
Python
89 lines
3.0 KiB
Python
|
import cv2
|
||
|
import torch
|
||
|
from PIL import Image
|
||
|
from torchvision import transforms, models
|
||
|
import torch.nn as nn
|
||
|
|
||
|
class AgeGenderPredictor:
    """Predict gender, age and an age bracket for a cropped face image.

    The model is a ResNet-50 whose final layer is replaced by a 3-unit linear
    head. The first two outputs are gender logits (index 0 -> 'Female',
    index 1 -> 'Male'). The last output is read directly as the predicted age.
    """

    def __init__(self, model_path, device=None):
        """Load the fine-tuned weights and prepare the model for inference.

        Args:
            model_path: Path to a state dict for the 3-output ResNet-50 built
                by ``load_model``.
            device: Optional torch device (or device string) to run on.
                Defaults to 'cuda' when available, otherwise 'cpu'.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        # Move the model to its device once here, instead of on every
        # predict() call.
        self.model = self.load_model(model_path).to(self.device)
        self.gender_labels = ['Female', 'Male']
        # Build the preprocessing pipeline once and reuse it for every frame.
        # The normalization constants are the usual ImageNet mean/std
        # (presumably matching the training pipeline -- confirm).
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def load_model(self, model_path):
        """Build the 3-output ResNet-50 and load its weights from *model_path*.

        Returns:
            The model in eval mode, with parameters on the CPU.
        """
        # No pretrained ImageNet weights are fetched. The strict
        # load_state_dict below overwrites every parameter anyway, so
        # downloading them is wasted work (and fails offline).
        model = models.resnet50(weights=None)
        num_ftrs = model.fc.in_features
        # Output layout: [gender logit 'Female', gender logit 'Male', age].
        model.fc = nn.Linear(num_ftrs, 3)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # host. NOTE: torch.load unpickles the file; only load trusted
        # checkpoints.
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def preprocess_image(self, image):
        """Convert a BGR image array (OpenCV channel order) into a model batch.

        Returns:
            A float tensor of shape (1, 3, 224, 224) on the CPU.
        """
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        input_tensor = self.preprocess(Image.fromarray(rgb))
        # Add the leading batch dimension expected by the model.
        return input_tensor.unsqueeze(0)

    def predict(self, face):
        """Predict attributes for a single BGR face crop.

        Returns:
            A tuple ``(gender_label, age, age_group_label)``, where ``age`` is
            the raw float regression output.
        """
        input_batch = self.preprocess_image(face).to(self.device)
        with torch.no_grad():
            output = self.model(input_batch)
        # The first two columns are gender logits; the last column is age.
        gender = output[:, :2].argmax(dim=1).item()
        age = output[:, -1].item()
        return self.gender_labels[gender], age, self.age_group(age)

    def age_group(self, age):
        """Map a numeric age to a coarse bracket label.

        Ages <= 18 (including children) map to 'Teenager', ages <= 59 to
        'Adult', and anything older to 'Senior'.
        """
        if age <= 18:
            return 'Teenager'
        elif age <= 59:
            return 'Adult'
        else:
            return 'Senior'
|
||
|
|
||
|
if __name__ == '__main__':
    # Create the AgeGenderPredictor instance (loads the fine-tuned checkpoint).
    predictor = AgeGenderPredictor('megaage_model_epoch99.pth')
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
    # An empty classifier means the cascade XML failed to load; detection
    # would silently find nothing.
    if face_cascade.empty():
        raise SystemExit('Failed to load haarcascade_frontalface_default.xml')

    # Open the default camera.
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise SystemExit('Could not open camera 0')

    try:
        while True:
            # Read one frame; stop when the stream ends or the read fails.
            ret, frame = cap.read()
            if not ret:
                break

            # Run face detection on a grayscale copy of the frame.
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

            # Annotate every detected face.
            for (x, y, w, h) in faces:
                # Extract the face ROI.
                face = frame[y:y + h, x:x + w]
                gender, age, age_group = predictor.predict(face)

                # Keep the label on-screen for faces near the top edge.
                text_y = max(y - 10, 15)
                cv2.putText(frame, f'Gender: {gender}, Age: {int(age)}, Age Group: {age_group}', (x, text_y),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 2)
                cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)

            # Display the annotated frame; press 'q' to quit.
            cv2.imshow('Webcam', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the camera and close all windows even if an error occurred.
        cap.release()
        cv2.destroyAllWindows()
|