53 lines
1.6 KiB
Python
53 lines
1.6 KiB
Python
|
import torch
|
|||
|
import argparse
|
|||
|
import cv2
|
|||
|
import os
|
|||
|
import numpy as np
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
from torch.utils.data import DataLoader
|
|||
|
from torch import nn, optim
|
|||
|
from torchvision.transforms import transforms
|
|||
|
from unet import Unet
|
|||
|
from dataset import LiverDataset
|
|||
|
from common_tools import transform_invert
|
|||
|
import PIL.Image as Image
|
|||
|
from datetime import datetime
|
|||
|
|
|||
|
class ImageSegmentation:
    """Single-image segmentation inference with a pretrained ``Unet(1, 1)``.

    Class attributes:
        val_interval: validation interval; not referenced by ``test``.
        device: CUDA device when available, otherwise CPU. Used both for
            loading the weights and for running the model.
        x_transforms / y_transforms: input / mask transforms. All images are
            grayscale, so both only convert a PIL image to a tensor.
    """

    # NOTE(review): not referenced anywhere in this file; presumably consumed
    # by a training/validation loop elsewhere — confirm before removing.
    val_interval = 1

    # Use CUDA if it is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # All images are grayscale, so the only transform needed is conversion
    # to a tensor.
    x_transforms = transforms.ToTensor()
    y_transforms = transforms.ToTensor()

    # Run the model on one image and save its output as a PNG.
    @classmethod
    def test(
        cls,
        x_path="/home/shared/wy/flask_web/Unet_liver_seg-master/data/val/Data/P8_T1_00070.png",
        weights_path="./model/weights_100.pth",
        save_root="./data/predict",
    ):
        """Segment one grayscale image and write the prediction to disk.

        Declared as a classmethod so it works both as
        ``ImageSegmentation.test()`` and on an instance, and so it can reach
        the class-level ``device`` and ``x_transforms``. (A plain ``def
        test():`` can see neither.)

        Args:
            x_path: path of the input image; it is converted to grayscale.
            weights_path: path of the saved ``Unet(1, 1)`` state dict.
            save_root: output directory; created if it does not exist.

        Returns:
            The path of the written PNG,
            ``<save_root>/predict_<YYYYmmdd_HHMMSS>_.png``.

        Raises:
            OSError: if OpenCV fails to write the output image.
        """
        model = Unet(1, 1)
        # Map stored tensors onto the detected device instead of a hard-coded
        # 'cuda', so checkpoints also load on CPU-only machines.
        state_dict = torch.load(weights_path, map_location=cls.device)
        model.load_state_dict(state_dict)
        model.to(cls.device)
        model.eval()

        # Grayscale ('L') image -> tensor, then add a leading batch dimension.
        with Image.open(x_path) as src:
            gray = src.convert("L")
        img_x = cls.x_transforms(gray)
        x = img_x.unsqueeze(0).to(cls.device, dtype=torch.float32)

        with torch.no_grad():
            y = model(x)

        # .numpy() only works on CPU tensors, so move the output back first.
        img_y = torch.squeeze(y).cpu().numpy()

        os.makedirs(save_root, exist_ok=True)
        # Timestamp the file name (second resolution) to keep earlier outputs.
        time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_path = os.path.join(save_root, "predict_%s_.png" % time_str)

        # Scale to 0..255 and convert to uint8 explicitly rather than relying
        # on OpenCV's implicit conversion of float arrays.
        # NOTE(review): assumes the model output lies in [0, 1] (e.g. a final
        # sigmoid); values outside that range are clipped — confirm against Unet.
        img_out = np.clip(img_y * 255, 0, 255).round().astype(np.uint8)
        if not cv2.imwrite(save_path, img_out):
            raise OSError("failed to write prediction image to %s" % save_path)
        return save_path
|