# algorithm_system_server/algorithm/Unetliversegmaster/main_wy.py
# (53 lines, 1.6 KiB, Python; revision dated 2024-06-21 10:06:54 +08:00)
import torch
import argparse
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
from common_tools import transform_invert
import PIL.Image as Image
from datetime import datetime
class ImageSegmentation:
    """Run single-image segmentation inference with a ``Unet`` model.

    The class holds the shared inference configuration (device and image
    transforms) and exposes :meth:`test`, which loads trained weights,
    predicts a mask for one image and writes it to disk as a PNG.

    NOTE(review): the indentation of this file was lost; this layout assumes
    the configuration attributes and ``test`` all belong to the class --
    confirm against the callers.
    """

    # Not referenced anywhere in the visible code; kept for compatibility.
    val_interval = 1
    # Use CUDA when it is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Inputs and masks are grayscale images, so converting them to tensors
    # is the only preprocessing required.
    x_transforms = transforms.ToTensor()
    y_transforms = transforms.ToTensor()

    @staticmethod
    def test(
        x_path="/home/shared/wy/flask_web/Unet_liver_seg-master/data/val/Data/P8_T1_00070.png",
        weights_path='./model/weights_100.pth',
        save_root='./data/predict',
    ):
        """Predict a mask for one image and save it as a timestamped PNG.

        Args:
            x_path: Path of the input image; it is converted to grayscale
                (mode ``'L'``) before inference.
            weights_path: Path of the state dict loaded into ``Unet(1, 1)``.
            save_root: Directory that receives the prediction; it is created
                when missing.

        Returns:
            The path of the written file, ``<save_root>/predict_<time>_.png``
            where ``<time>`` is the current local time formatted as
            ``%Y%m%d_%H%M%S``. Two calls within the same second therefore
            produce the same path.

        Raises:
            OSError: If OpenCV reports that the image could not be written.
        """
        device = ImageSegmentation.device

        model = Unet(1, 1)
        # Map the checkpoint onto the selected device so that CPU-only hosts
        # can load weights that were saved from a GPU.
        model.load_state_dict(torch.load(weights_path, map_location=device))
        model.to(device)
        model.eval()

        # The context manager closes the underlying file once the grayscale
        # copy has been turned into a tensor.
        with Image.open(x_path) as source_image:
            img_x = ImageSegmentation.x_transforms(source_image.convert('L'))

        # Name the output after the current local time.
        time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_path = os.path.join(save_root, "predict_%s_.png" % time_str)

        with torch.no_grad():
            # Add the leading batch dimension the model is called with.
            x = img_x.unsqueeze(0).to(device=device, dtype=torch.float32)
            y = model(x)
        # Drop singleton dimensions and move to host memory for NumPy/OpenCV.
        # NOTE(review): assumes the squeezed output is a 2-D map with values
        # in [0, 1] -- confirm against the Unet definition.
        img_y = torch.squeeze(y).cpu().numpy()

        os.makedirs(save_root, exist_ok=True)
        # Scale to the 8-bit range explicitly instead of relying on OpenCV's
        # implicit conversion of float images.
        mask = np.clip(np.rint(img_y * 255), 0, 255).astype(np.uint8)
        # cv2.imwrite signals failure through its return value, not an
        # exception, so check it instead of losing the result silently.
        if not cv2.imwrite(save_path, mask):
            raise OSError("failed to write prediction image: %s" % save_path)
        return save_path