# tjy/demo/apis/age/AgeGenderPredictor.py
#
# 89 lines
# 3.0 KiB
# Python

import cv2
import torch
from PIL import Image
from torchvision import transforms, models
import torch.nn as nn
class AgeGenderPredictor:
    """Predict gender, age and an age group for a cropped face image.

    The network is a ResNet-50 whose final fully connected layer is replaced by
    a 3-unit linear layer. The first two outputs are treated as gender scores
    (their argmax indexes ``gender_labels``). The last output is read as a
    scalar age.
    """

    def __init__(self, model_path, device=None):
        """Load the trained weights from *model_path*.

        Args:
            model_path: Path to a ``state_dict`` saved for the 3-output ResNet-50.
            device: Optional torch device (or device string) to run on. When
                omitted, CUDA is used if available, otherwise CPU.
        """
        self.model = self.load_model(model_path, device)
        self.gender_labels = ['Female', 'Male']

    def load_model(self, model_path, device=None):
        """Build the 3-output ResNet-50, load *model_path* and return it in eval mode.

        The model is moved to *device* once here, so ``predict`` never has to
        move it again. When *device* is None, CUDA is used if available,
        otherwise CPU.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        device = torch.device(device)
        # No pretrained weights: the strict load_state_dict below overwrites
        # every parameter and buffer, so downloading ImageNet weights is wasted.
        model = models.resnet50(weights=None)
        num_ftrs = model.fc.in_features
        # 3 outputs: 2 gender scores followed by 1 age value.
        model.fc = nn.Linear(num_ftrs, 3)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        return model

    def preprocess_image(self, image):
        """Convert a BGR OpenCV image into a normalized model input batch.

        Args:
            image: A BGR image array as produced by OpenCV (e.g. a face crop).

        Returns:
            A float tensor of shape (1, 3, 224, 224).

        Raises:
            ValueError: If *image* is None or has no pixels.
        """
        if image is None or image.size == 0:
            raise ValueError('image is empty')
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # ImageNet channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # OpenCV delivers BGR; PIL/torchvision expect RGB.
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        input_tensor = preprocess(image)
        input_batch = input_tensor.unsqueeze(0)
        return input_batch

    def predict(self, face):
        """Predict gender, age and age group for a single BGR face crop.

        Returns:
            A tuple ``(gender_label, age, age_group_label)``, where *age* is the
            raw float output of the model.
        """
        input_batch = self.preprocess_image(face)
        # Send the input to wherever the model's weights live instead of
        # moving the whole model on every call.
        input_batch = input_batch.to(self._model_device())
        with torch.no_grad():
            output = self.model(input_batch)
        gender_preds = output[:, :2]
        age_preds = output[:, -1]
        gender = gender_preds.argmax(dim=1).item()
        age = age_preds.item()
        return self.gender_labels[gender], age, self.age_group(age)

    def _model_device(self):
        """Return the device of the model's parameters (CPU if it has none)."""
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            return torch.device('cpu')

    def age_group(self, age):
        """Map a numeric age to a coarse label.

        Returns 'Teenager' for age <= 18, 'Adult' for 18 < age <= 59, and
        'Senior' otherwise.
        """
        if age <= 18:
            return 'Teenager'
        elif age <= 59:
            return 'Adult'
        else:
            return 'Senior'
if __name__ == '__main__':
    # Create the predictor.
    # NOTE(review): this weights path is resolved against the current working
    # directory, not this file's location.
    predictor = AgeGenderPredictor('../../weights/megaage_model_epoch99.pth')
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
    # Open the default camera.
    cap = cv2.VideoCapture(0)
    try:
        if not cap.isOpened():
            raise SystemExit('Could not open camera 0')
        while True:
            # Read one frame; stop when the stream ends or the camera fails.
            ret, frame = cap.read()
            if not ret:
                break
            # Detect faces on a grayscale version of the frame.
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
            # Run every prediction before drawing anything. The face crops are
            # views into `frame`, so labels and boxes drawn for one face could
            # otherwise leak into a neighbouring face's model input.
            results = [
                ((x, y, w, h), predictor.predict(frame[y:y + h, x:x + w]))
                for (x, y, w, h) in faces
            ]
            for (x, y, w, h), (gender, age, age_group) in results:
                # Keep the label inside the frame when the face touches the top edge.
                text_y = max(y - 10, 10)
                cv2.putText(frame, f'Gender: {gender}, Age: {int(age)}, Age Group: {age_group}', (x, text_y),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 2)
                cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
            # Show the annotated frame; press 'q' to quit.
            cv2.imshow('Webcam', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always release the camera and close all windows, even on errors.
        cap.release()
        cv2.destroyAllWindows()