algorithm_system_server/algorithm/image_segmentation.py

83 lines
2.5 KiB
Python

import torch
import argparse
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.transforms import transforms
from algorithm.Unetliversegmaster.unet import Unet
from algorithm.Unetliversegmaster.dataset import LiverDataset
from algorithm.Unetliversegmaster.common_tools import transform_invert
import PIL.Image as Image
from datetime import datetime
from read_data import LoadImages, LoadStreams
import torch.backends.cudnn as cudnn
import argparse
class ImageSegmentation:
    """Run a U-Net segmentation model on a grayscale image and return a PNG mask.

    The network is ``Unet(1, 1)`` with weights loaded from ``weights_path``.
    ``get_frame`` runs inference on the image held in ``self.dataset`` and
    returns the model output (scaled by 255) encoded as PNG bytes.
    """

    # Weights file used when the caller does not pass ``weights_path``.
    DEFAULT_WEIGHTS = '/home/shared/wy/flask_web/algorithm/Unetliversegmaster/model/weights_100.pth'

    def __init__(self, video_path=None, weights_path=None):
        """Load the model and read the input image.

        Args:
            video_path: Path of the image to segment. Despite the name, it is
                read with ``cv2.imread(..., cv2.IMREAD_GRAYSCALE)``, so it must
                be an image file. Defaults to ``'vid2.mp4'``, which
                ``cv2.imread`` cannot decode; ``self.dataset`` is then ``None``
                until ``use_webcam`` replaces it.
            weights_path: Path of the U-Net ``state_dict``. Defaults to
                ``DEFAULT_WEIGHTS``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.x_transforms = transforms.ToTensor()
        self.model = Unet(1, 1)
        if weights_path is None:
            weights_path = self.DEFAULT_WEIGHTS
        # Map the checkpoint onto the device that actually exists; a hard-coded
        # 'cuda' makes loading fail on CPU-only hosts.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        # Inference only: switch to eval mode once instead of on every frame.
        self.model.eval()
        self.video_name = video_path if video_path is not None else 'vid2.mp4'  # A default video file
        # cv2.imread returns None (it does not raise) when the file cannot be read.
        self.dataset = cv2.imread(self.video_name, cv2.IMREAD_GRAYSCALE)

    def use_webcam(self, source):
        """Switch the input to a stream opened with ``LoadStreams``.

        NOTE(review): ``get_frame`` passes ``self.dataset`` straight to
        ``Image.fromarray``; it presumably does not handle the ``LoadStreams``
        object stored here — confirm before relying on this path.

        Args:
            source: Stream source, forwarded unchanged to ``LoadStreams``.

        Returns:
            The instance's segmentation model.
        """
        self.imgsz = 640
        # Process-wide cuDNN setting (affects every model in this process).
        cudnn.benchmark = True
        self.dataset = LoadStreams(source, img_size=self.imgsz)
        self.flag = 1
        # The original returned the undefined global ``model`` (NameError).
        return self.model

    def class_to_label(self, x):
        """Return ``self.classes[int(x)]``.

        NOTE(review): ``self.classes`` is never assigned in this class; it must
        be set by the caller first, otherwise this raises AttributeError.
        """
        return self.classes[int(x)]

    def get_frame(self):
        """Segment ``self.dataset`` and return the mask encoded as PNG.

        Returns:
            ``(png_bytes, "")``: the PNG-encoded model output scaled by 255, and
            an empty string (presumably a placeholder for callers expecting a
            2-tuple — confirm against callers).

        Raises:
            ValueError: If no image is loaded (``cv2.imread`` could not read
                ``self.video_name``).
            RuntimeError: If ``cv2.imencode`` fails to encode the mask.
        """
        img = self.dataset
        if img is None:
            raise ValueError(
                f"no image loaded: cv2.imread could not read {self.video_name!r}"
            )
        img_x = self.x_transforms(Image.fromarray(img).convert('L'))
        # Add a leading batch dimension and match the model's device and dtype.
        x = img_x.unsqueeze(0).to(self.device, dtype=torch.float32)
        with torch.no_grad():
            y = self.model(x)
        img_y = torch.squeeze(y).cpu().numpy()
        mask = img_y * 255
        # NOTE(review): debug dump into the working directory on every call;
        # kept only to preserve existing side effects — consider removing.
        cv2.imwrite('test111111111111.png', mask)
        ok, png = cv2.imencode(".png", mask)
        if not ok:
            raise RuntimeError("cv2.imencode failed to encode the mask as PNG")
        return png.tobytes(), ""