import os
import json
import uuid

import torch
from PIL import Image
from torchvision import transforms

from model import mobile_vit_small as create_model


class ImagePredictor:
    """Classify images with a MobileViT-small model and a label lookup table.

    The predictor owns the preprocessing pipeline, the loaded model and the
    mapping from class index (as a string key) to class name read from a
    JSON file.
    """

    def __init__(self, model_path, class_indices_path, img_size=224, num_classes=3):
        """Build the preprocessing pipeline, read the labels and load the model.

        Args:
            model_path: Path to the saved state dict passed to ``torch.load``.
            class_indices_path: Path to a UTF-8 JSON file whose keys are class
                indices written as strings and whose values are class names.
            img_size: Side length of the square crop fed to the model.
            num_classes: Number of output classes the model is built with.
                Defaults to 3, the value that used to be hard-coded.
        """
        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.img_size = img_size
        self.num_classes = num_classes

        # Resize to 1.14x the crop size, centre-crop, convert to a tensor and
        # apply per-channel mean/std normalisation.
        self.data_transform = transforms.Compose([
            transforms.Resize(int(self.img_size * 1.14)),
            transforms.CenterCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        # Load class indices
        with open(class_indices_path, "r", encoding="utf-8") as f:
            self.class_indict = json.load(f)

        # Load model
        self.model = self.load_model(model_path)

    def load_model(self, model_path):
        """Create the model, load its weights from *model_path*, set eval mode.

        Args:
            model_path: Path to the saved state dict.

        Returns:
            The model on ``self.device`` in evaluation mode.
        """
        # Fall back to 3 classes if the attribute was never set (e.g. the
        # method is called on an instance created without ``__init__``).
        num_classes = getattr(self, "num_classes", 3)
        model = create_model(num_classes=num_classes).to(self.device)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        model.load_state_dict(torch.load(model_path, map_location=self.device))
        model.eval()
        return model

    def _classify(self, img, top_k):
        """Run the model on one RGB PIL image and return the ranked results.

        Args:
            img: PIL image already converted to RGB.
            top_k: Maximum number of classes to report.

        Returns:
            ``{"result": [...], "log_id": str}`` where each result entry has
            the keys ``"name"``, ``"score"`` and ``"label"``, ordered from the
            highest probability down.
        """
        batch = torch.unsqueeze(self.data_transform(img), dim=0)

        # Predict class
        with torch.no_grad():
            output = torch.squeeze(self.model(batch.to(self.device))).cpu()
        probabilities = torch.softmax(output, dim=0)

        # torch.topk raises when k exceeds the number of entries, so never ask
        # for more classes than the model actually outputs.
        k = min(top_k, probabilities.numel())
        top_prob, top_catid = torch.topk(probabilities, k)

        ranked = [
            {"name": self.class_indict[str(label)], "score": score, "label": label}
            for score, label in zip(top_prob.tolist(), top_catid.tolist())
        ]

        # Results dictionary
        return {"result": ranked, "log_id": str(uuid.uuid1())}

    def predict_img(self, image_path):
        """Classify the image stored at *image_path*.

        Args:
            image_path: Path to an image file readable by PIL.

        Returns:
            A results dictionary with up to five entries under ``"result"``
            (fewer when the model has fewer classes) and a ``"log_id"``.

        Raises:
            AssertionError: If *image_path* does not exist.
        """
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is kept as
        # AssertionError so existing callers see the same error.
        if not os.path.exists(image_path):
            raise AssertionError(f"file: '{image_path}' does not exist.")

        # Load image; the context manager releases the file handle promptly.
        with Image.open(image_path) as source:
            img = source.convert('RGB')

        return self._classify(img, top_k=5)

    def predict(self, np_image):
        """Classify an in-memory image and return only the best match.

        Args:
            np_image: Array accepted by ``PIL.Image.fromarray``.

        Returns:
            A results dictionary with a single entry under ``"result"`` and a
            ``"log_id"``.
        """
        # Convert numpy image to PIL image
        img = Image.fromarray(np_image).convert('RGB')
        return self._classify(img, top_k=1)


# Example usage:
# predictor = ImagePredictor(model_path="./weights/best_model.pth", class_indices_path="./class_indices.json")
# result = predictor.predict_img("../tulip.jpg")
# print(result)