# Source: tjy/Emotion/FacialEmotion/predict_api.py (Python, 69 lines, 2.3 KiB)

import os
import json
import uuid
import cv2
import torch
from PIL import Image
from torchvision import transforms
from model import mobile_vit_small as create_model
class ImagePredictor:
    """Top-1 image classifier wrapping a MobileViT-small checkpoint.

    The class-name mapping and the model weights are loaded once at
    construction time; ``predict`` can then be called repeatedly on images
    that were decoded with OpenCV.
    """

    def __init__(self, model_path, class_indices_path, img_size=224, num_classes=9):
        """Load the class-name mapping and the model weights.

        Args:
            model_path: Path to a state-dict checkpoint readable by
                ``torch.load``.
            class_indices_path: Path to a UTF-8 JSON file whose keys are class
                indices written as strings and whose values are class names.
            img_size: Side length of the square center crop fed to the network.
            num_classes: Number of output classes the model is built with; it
                has to match the checkpoint. Defaults to 9, the value that was
                previously hard-coded.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.img_size = img_size
        self.num_classes = num_classes

        # Resize to 1.14x the crop size, then center-crop, so the crop keeps
        # the middle of the picture. The normalization constants are the
        # commonly used ImageNet channel means / standard deviations.
        self.data_transform = transforms.Compose([
            transforms.Resize(int(self.img_size * 1.14)),
            transforms.CenterCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        # Load class indices
        with open(class_indices_path, "r", encoding="utf-8") as f:
            self.class_indict = json.load(f)

        # Load model
        self.model = self.load_model(model_path)

    def load_model(self, model_path):
        """Build the network, load weights from ``model_path`` and return it in eval mode.

        Args:
            model_path: Path to the state-dict checkpoint.

        Returns:
            The model, moved to ``self.device`` and switched to ``eval()``.
        """
        model = create_model(num_classes=self.num_classes).to(self.device)
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted source (newer PyTorch releases accept
        # weights_only=True to restrict what may be unpickled).
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def predict(self, cv2_image):
        """Classify one image and return the top-1 prediction.

        Args:
            cv2_image: Image array in OpenCV's BGR channel order (it is
                converted with ``cv2.COLOR_BGR2RGB`` before preprocessing).

        Returns:
            A dict ``{"result": {"name", "score", "label"}, "log_id"}`` where
            ``name`` is looked up in the class-index mapping, ``score`` is the
            softmax probability of the winning class, ``label`` is its integer
            index and ``log_id`` is a fresh ``uuid1`` string.

        Raises:
            ValueError: If ``cv2_image`` is ``None`` (which is what
                ``cv2.imread`` returns when a file cannot be read).
        """
        if cv2_image is None:
            raise ValueError(
                "cv2_image is None; the image could not be read or decoded"
            )

        # Convert cv2 image (BGR ndarray) to a PIL RGB image for torchvision.
        rgb = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)
        img = self.data_transform(Image.fromarray(rgb))
        img = torch.unsqueeze(img, dim=0)  # add the batch dimension

        # Predict class. The forward pass runs exactly once; the original
        # code repeated this whole section and threw the first result away.
        with torch.no_grad():
            output = torch.squeeze(self.model(img.to(self.device))).cpu()
            probabilities = torch.softmax(output, dim=0)
            top_prob, top_catid = torch.topk(probabilities, 1)

        # Top 1 result
        label = top_catid[0].item()
        result = {
            "name": self.class_indict[str(label)],
            "score": top_prob[0].item(),
            "label": label,
        }

        # Results dictionary
        return {"result": result, "log_id": str(uuid.uuid1())}
# Example usage (predict() expects an image array decoded by OpenCV, not a file path):
# predictor = ImagePredictor(model_path="./weights/best_model.pth", class_indices_path="./class_indices.json")
# image = cv2.imread("../tulip.jpg")
# result = predictor.predict(image)
# print(result)