import argparse
import os
from datetime import datetime

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
import torch
import torch.backends.cudnn as cudnn
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

from algorithm.Unetliversegmaster.common_tools import transform_invert
from algorithm.Unetliversegmaster.dataset import LiverDataset
from algorithm.Unetliversegmaster.unet import Unet
from read_data import LoadImages, LoadStreams

# Default checkpoint for the 1-in / 1-out channel U-Net used below.
DEFAULT_WEIGHTS_PATH = (
    '/home/shared/wy/flask_web/algorithm/Unetliversegmaster/model/weights_100.pth'
)


class ImageSegmentation():
    """Run a U-Net segmentation model on a grayscale image and PNG-encode the result.

    Attributes:
        device: ``cuda`` when available, otherwise ``cpu``; the model and its
            inputs are placed here.
        x_transforms: converts a PIL image into a float tensor.
        model: ``Unet(1, 1)`` loaded from a checkpoint, in eval mode.
        video_name: path that was passed to ``cv2.imread``.
        dataset: the ndarray returned by ``cv2.imread`` (``None`` if reading
            failed), or the object created by ``use_webcam``.
    """

    def __init__(self, video_path=None, weights_path=DEFAULT_WEIGHTS_PATH):
        """Load the model weights and read the input image as grayscale.

        Args:
            video_path: path of the image to segment; defaults to ``'vid2.mp4'``.
            weights_path: path to the ``Unet(1, 1)`` state dict; defaults to
                ``DEFAULT_WEIGHTS_PATH``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.x_transforms = transforms.ToTensor()
        self.model = Unet(1, 1)
        # Map the checkpoint onto the selected device so CPU-only hosts can load
        # it (the previous hard-coded map_location='cuda' failed without a GPU).
        self.model.load_state_dict(torch.load(weights_path, map_location=self.device))
        self.model.to(self.device)
        # Inference only: switch to eval mode once rather than on every frame.
        self.model.eval()
        if video_path is not None:
            self.video_name = video_path
        else:
            # NOTE(review): cv2.imread decodes still images only, so this default
            # most likely yields None — confirm the intended default input.
            self.video_name = 'vid2.mp4'
        # cv2.imread returns None (it does not raise) when the file cannot be read;
        # get_frame() reports that case explicitly.
        self.dataset = cv2.imread(self.video_name, cv2.IMREAD_GRAYSCALE)

    def use_webcam(self, source, img_size=640):
        """Replace the input with a ``LoadStreams`` source.

        Args:
            source: forwarded unchanged to ``LoadStreams``.
            img_size: forwarded as ``img_size``; defaults to 640.

        Returns:
            The newly created ``LoadStreams`` object (the original ``return
            model`` referenced an undefined name and always raised NameError).
        """
        self.imgsz = img_size
        # Let cuDNN pick the fastest kernels; note this is a process-wide setting.
        cudnn.benchmark = True
        self.dataset = LoadStreams(source, img_size=self.imgsz)
        self.flag = 1
        # NOTE(review): get_frame() only handles ndarray input, so frames from
        # this stream are not segmented by it — confirm how the stream is consumed.
        return self.dataset

    def class_to_label(self, x):
        """Return ``self.classes[int(x)]``.

        NOTE(review): ``self.classes`` is not assigned anywhere in this class, so
        this raises AttributeError unless a caller sets it first — verify usage.
        """
        return self.classes[int(x)]

    def get_frame(self):
        """Segment ``self.dataset`` and return the mask PNG-encoded.

        Returns:
            tuple[bytes, str]: PNG bytes of the model output scaled by 255 and
            saturated to uint8, plus an empty string.

        Raises:
            ValueError: if ``self.dataset`` is not an image array (e.g. the file
                could not be read, or ``use_webcam`` replaced it with a stream).
            RuntimeError: if OpenCV fails to encode the mask.
        """
        img = self.dataset
        if not isinstance(img, np.ndarray):
            raise ValueError(
                f"No image loaded from {self.video_name!r}; get_frame() requires "
                "a still image readable by cv2.imread"
            )
        gray = Image.fromarray(img).convert('L')
        # ToTensor turns an 'L' image into a (1, H, W) float tensor; add a batch axis.
        x = self.x_transforms(gray).unsqueeze(0).to(self.device, dtype=torch.float32)
        with torch.no_grad():
            y = self.model(x)
        img_y = torch.squeeze(y).cpu().numpy()
        # Make the float -> uint8 conversion explicit (cv2 would otherwise
        # saturate-convert implicitly); clipping guards against out-of-range outputs.
        mask = np.clip(img_y * 255, 0, 255).round().astype(np.uint8)
        ok, png = cv2.imencode(".png", mask)
        if not ok:
            raise RuntimeError("cv2.imencode failed to encode the segmentation mask as PNG")
        return png.tobytes(), ""