91 lines
3.0 KiB
Python
91 lines
3.0 KiB
Python
import os
|
|
import json
|
|
import uuid
|
|
import torch
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
from model import mobile_vit_small as create_model
|
|
|
|
class ImagePredictor:
    """Image classifier wrapper around the ``mobile_vit_small`` model.

    On construction the predictor picks a device, builds the preprocessing
    pipeline, reads the class-index JSON file and loads the model weights.
    Use :meth:`predict_img` for an image file on disk and :meth:`predict`
    for an image that is already in memory as a numpy array.
    """

    def __init__(self, model_path, class_indices_path, img_size=224, num_classes=3):
        """Set up device, preprocessing, class names and the model.

        Args:
            model_path: Path to the saved ``state_dict`` checkpoint.
            class_indices_path: Path to a UTF-8 JSON file whose keys are
                class indices as strings (looked up via ``str(label)``).
            img_size: Side length of the centre crop fed to the model.
            num_classes: Number of output classes the model is built with.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.img_size = img_size
        self.data_transform = transforms.Compose([
            # Resize slightly larger than the crop, then centre-crop.
            transforms.Resize(int(self.img_size * 1.14)),
            transforms.CenterCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        # Load class indices
        with open(class_indices_path, "r", encoding="utf-8") as f:
            self.class_indict = json.load(f)

        # Load model
        self.model = self.load_model(model_path, num_classes=num_classes)

    def load_model(self, model_path, num_classes=3):
        """Build the model, load its weights and switch it to eval mode.

        Args:
            model_path: Path to the saved ``state_dict`` checkpoint.
            num_classes: Number of output classes; must match the checkpoint.

        Returns:
            The model on ``self.device`` in evaluation mode.
        """
        model = create_model(num_classes=num_classes).to(self.device)
        # NOTE: torch.load unpickles the file; only load checkpoints that
        # come from a trusted source.
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def _classify(self, img, top_k):
        """Run the model on one preprocessed image tensor and rank classes.

        Args:
            img: Preprocessed image tensor without a batch dimension.
            top_k: Maximum number of ranked classes to return.

        Returns:
            ``{"result": [...], "log_id": str}`` where each result entry has
            the keys ``"name"``, ``"score"`` and ``"label"``, ordered from
            most to least probable.
        """
        batch = torch.unsqueeze(img, dim=0)

        with torch.no_grad():
            output = torch.squeeze(self.model(batch.to(self.device))).cpu()
            probabilities = torch.softmax(output, dim=0)
            # torch.topk raises when k exceeds the number of classes, so
            # never ask for more entries than the model actually outputs.
            k = min(top_k, probabilities.numel())
            top_prob, top_catid = torch.topk(probabilities, k)

        ranked = [
            {
                "name": self.class_indict[str(label)],
                "score": score,
                "label": label,
            }
            for score, label in zip(top_prob.tolist(), top_catid.tolist())
        ]

        return {"result": ranked, "log_id": str(uuid.uuid1())}

    def predict_img(self, image_path):
        """Classify an image file and return up to the five best classes.

        Args:
            image_path: Path to the image file.

        Returns:
            Result dictionary as produced by :meth:`_classify`.
        """
        # NOTE: assert is stripped under ``python -O``; it is kept so the
        # exception type seen by existing callers does not change.
        assert os.path.exists(image_path), f"file: '{image_path}' does not exist."

        # convert() returns an independent in-memory copy, so the file
        # handle can be released as soon as the conversion is done.
        with Image.open(image_path) as source:
            img = source.convert('RGB')

        return self._classify(self.data_transform(img), top_k=5)

    def predict(self, np_image):
        """Classify an in-memory image and return only the best class.

        Args:
            np_image: Image as a numpy array accepted by ``Image.fromarray``.

        Returns:
            Result dictionary as produced by :meth:`_classify`, with a
            single entry in ``"result"``.
        """
        # Convert numpy image to PIL image
        img = Image.fromarray(np_image).convert('RGB')

        return self._classify(self.data_transform(img), top_k=1)
|
|
|
|
# Example usage:
|
|
# predictor = ImagePredictor(model_path="./weights/best_model.pth", class_indices_path="./class_indices.json")
|
|
# result = predictor.predict_img("../tulip.jpg")
|
|
# print(result)
|