83 lines
2.5 KiB
Python
83 lines
2.5 KiB
Python
|
import torch
|
||
|
import argparse
|
||
|
import cv2
|
||
|
import os
|
||
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
from torch.utils.data import DataLoader
|
||
|
from torch import nn, optim
|
||
|
from torchvision.transforms import transforms
|
||
|
from algorithm.Unetliversegmaster.unet import Unet
|
||
|
from algorithm.Unetliversegmaster.dataset import LiverDataset
|
||
|
from algorithm.Unetliversegmaster.common_tools import transform_invert
|
||
|
import PIL.Image as Image
|
||
|
from datetime import datetime
|
||
|
from read_data import LoadImages, LoadStreams
|
||
|
import torch.backends.cudnn as cudnn
|
||
|
import argparse
|
||
|
|
||
|
|
||
|
class ImageSegmentation:
    """Run a single-channel U-Net over one grayscale image and return the mask as PNG bytes.

    The constructor builds ``Unet(1, 1)``, loads a checkpoint into it and reads
    the input image with OpenCV.  ``get_frame`` performs the forward pass and
    encodes the result.
    """

    # Checkpoint used when the caller does not pass ``weights_path``.
    DEFAULT_WEIGHTS = '/home/shared/wy/flask_web/algorithm/Unetliversegmaster/model/weights_100.pth'
    # Input used when the caller does not pass ``video_path``.
    DEFAULT_SOURCE = 'vid2.mp4'

    def __init__(self, video_path=None, weights_path=None):
        """Load the model weights and read the input image.

        Args:
            video_path: Path handed to ``cv2.imread`` (read as grayscale).
                Falls back to ``DEFAULT_SOURCE`` when ``None``.
            weights_path: Path of the state-dict checkpoint.  Falls back to
                ``DEFAULT_WEIGHTS`` when ``None``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.x_transforms = transforms.ToTensor()

        if weights_path is None:
            weights_path = self.DEFAULT_WEIGHTS

        self.model = Unet(1, 1)
        # Map the checkpoint onto the device that is actually available; a
        # hard-coded 'cuda' target makes loading fail on CPU-only hosts.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)

        if video_path is None:
            video_path = self.DEFAULT_SOURCE
        self.video_name = video_path

        # cv2.imread reports an unreadable/missing file by returning None, not
        # by raising.  That case is surfaced in get_frame() so that building
        # the object and then switching to use_webcam() keeps working.
        self.dataset = cv2.imread(self.video_name, cv2.IMREAD_GRAYSCALE)

    def use_webcam(self, source):
        """Replace the current input with a ``LoadStreams`` source.

        Sets ``self.imgsz`` to 640, enables ``cudnn.benchmark``, stores
        ``LoadStreams(source, img_size=640)`` in ``self.dataset`` and sets
        ``self.flag`` to 1.

        Args:
            source: Passed through unchanged to ``LoadStreams``.

        NOTE(review): ``get_frame`` feeds ``self.dataset`` to
        ``Image.fromarray``, which presumably will not accept a
        ``LoadStreams`` object -- confirm how frames are meant to be pulled
        from the stream.
        """
        self.imgsz = 640
        cudnn.benchmark = True
        self.dataset = LoadStreams(source, img_size=self.imgsz)
        self.flag = 1
        # The previous trailing ``return model`` referenced an undefined name
        # and raised NameError on every call; the method now returns None.

    def class_to_label(self, x):
        """Return ``self.classes[int(x)]``.

        NOTE(review): ``self.classes`` is not assigned anywhere in this
        class; it has to be set on the instance before this is called.
        """
        return self.classes[int(x)]

    def get_frame(self):
        """Segment the loaded image and return it PNG-encoded.

        Returns:
            A 2-tuple ``(png_bytes, "")``: the encoded prediction and an
            empty string.

        Raises:
            RuntimeError: If no image was loaded (``cv2.imread`` returned
                ``None``) or if PNG encoding fails.
        """
        img = self.dataset
        if img is None:
            raise RuntimeError(
                "no image loaded: cv2.imread could not read %r" % (self.video_name,)
            )

        # Round-trip through PIL to force a single-channel ('L') image before
        # converting it to a tensor.
        gray = Image.fromarray(img).convert('L')
        batch = self.x_transforms(gray).unsqueeze(0).float().to(self.device)

        self.model.eval()
        with torch.no_grad():
            prediction = self.model(batch)

        # Bring the result back to host memory before handing it to NumPy.
        mask = torch.squeeze(prediction).detach().cpu().numpy() * 255

        # NOTE(review): debug dump written to the working directory on every
        # call; kept because other code may read this file -- confirm and drop.
        cv2.imwrite('test111111111111.png', mask)

        ok, png = cv2.imencode(".png", mask)
        if not ok:
            raise RuntimeError("cv2.imencode failed to encode the prediction as PNG")

        return png.tobytes(), ""
|
||
|
|