"""Real-time age and gender estimation from a webcam feed.

Faces are located with OpenCV's Haar cascade detector. Each face crop is then
passed through a ResNet-50 whose final layer was replaced by a 3-unit head:
two gender logits followed by one age value.
"""

import cv2
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms


class AgeGenderPredictor:
    """Predict gender, age and a coarse age group for a cropped face image."""

    def __init__(self, model_path):
        """Load the fine-tuned network from *model_path* and prepare inference.

        Args:
            model_path: Path to a ``state_dict`` checkpoint matching the
                ResNet-50 + 3-output head built in ``load_model``.
        """
        self.gender_labels = ['Female', 'Male']
        # Choose the device once, instead of re-checking CUDA and re-moving
        # the model on every predict() call.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.load_model(model_path).to(self.device)
        # The transform pipeline is stateless, so build it once rather than
        # once per face per frame.
        self._transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # ImageNet channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def load_model(self, model_path):
        """Build the ResNet-50 architecture and load the trained weights into it.

        Returns:
            The model in eval mode with its parameters on the CPU; the caller
            moves it to the inference device.
        """
        # No pretrained weights are requested. Every parameter and buffer is
        # overwritten by the strict load_state_dict() below, so downloading the
        # ImageNet weights would only add network I/O.
        model = models.resnet50(weights=None)
        num_ftrs = model.fc.in_features
        # 3 outputs: indices 0-1 are gender logits, index 2 is the age estimate.
        model.fc = nn.Linear(num_ftrs, 3)
        # map_location='cpu' lets a checkpoint saved on a GPU load on a
        # CPU-only machine.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def preprocess_image(self, image):
        """Convert an OpenCV BGR image into a normalized 1x3x224x224 batch tensor."""
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        input_tensor = self._transform(Image.fromarray(rgb))
        # Add the batch dimension expected by the model.
        return input_tensor.unsqueeze(0)

    def predict(self, face):
        """Classify a single BGR face crop.

        Returns:
            A tuple ``(gender_label, age, age_group)``. ``age`` is the raw
            regression output as a float (presumably in years; confirm
            against the training setup).
        """
        input_batch = self.preprocess_image(face).to(self.device)
        with torch.no_grad():
            output = self.model(input_batch)
        gender = output[:, :2].argmax(dim=1).item()
        age = output[:, -1].item()
        return self.gender_labels[gender], age, self.age_group(age)

    def age_group(self, age):
        """Map a numeric age to 'Teenager' (<= 18), 'Adult' (<= 59) or 'Senior'.

        NOTE(review): every value <= 18, including children and any negative
        regression output, is labelled 'Teenager'.
        """
        if age <= 18:
            return 'Teenager'
        elif age <= 59:
            return 'Adult'
        else:
            return 'Senior'


def main(model_path='megaage_model_epoch99.pth', camera_index=0):
    """Run the webcam demo until the stream ends or 'q' is pressed.

    Raises:
        RuntimeError: If the Haar cascade cannot be loaded or the camera
            cannot be opened.
    """
    # Create the predictor instance.
    predictor = AgeGenderPredictor(model_path)
    face_cascade = cv2.CascadeClassifier(
        cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
    if face_cascade.empty():
        raise RuntimeError('Failed to load the Haar cascade for face detection')

    # Open the camera.
    cap = cv2.VideoCapture(camera_index)
    if not cap.isOpened():
        raise RuntimeError(f'Cannot open camera {camera_index}')

    try:
        while True:
            # Read one frame.
            ret, frame = cap.read()
            if not ret:
                break

            # Run face detection on the grayscale frame.
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = face_cascade.detectMultiScale(
                gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

            # Predict every face before drawing anything. Otherwise labels and
            # boxes drawn for one face would leak into the crop of an
            # adjacent face.
            results = [
                ((x, y, w, h), predictor.predict(frame[y:y + h, x:x + w]))
                for (x, y, w, h) in faces
            ]

            for (x, y, w, h), (gender, age, group) in results:
                # Put the label below the box when there is no room above it.
                text_y = y - 10 if y >= 20 else y + h + 20
                cv2.putText(frame, f'Gender: {gender}, Age: {int(age)}, Age Group: {group}',
                            (x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 2)
                cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)

            # Show the annotated frame.
            cv2.imshow('Webcam', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always release the camera and close all windows, even on error.
        cap.release()
        cv2.destroyAllWindows()


if __name__ == '__main__':
    main()