# (Source listing metadata from the file viewer: 53 lines, 1.6 KiB, Python — not part of the program.)
import torch
|
||
import argparse
|
||
import cv2
|
||
import os
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from torch.utils.data import DataLoader
|
||
from torch import nn, optim
|
||
from torchvision.transforms import transforms
|
||
from unet import Unet
|
||
from dataset import LiverDataset
|
||
from common_tools import transform_invert
|
||
import PIL.Image as Image
|
||
from datetime import datetime
|
||
|
||
class ImageSegmentation:
    """Shared settings for running the UNet segmentation model.

    Holds the validation interval, the torch device to run on, and the
    tensor transforms for input images (``x_transforms``) and masks
    (``y_transforms``).
    """

    # Presumably the number of epochs between validation passes, read by
    # training code outside this file — TODO confirm.
    val_interval = 1

    # Use CUDA when available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Images and masks are both grayscale, so they only need converting to tensors.
    x_transforms = transforms.ToTensor()
    y_transforms = transforms.ToTensor()

    @staticmethod
    def test(weights_path='./model/weights_100.pth',
             x_path="/home/shared/wy/flask_web/Unet_liver_seg-master/data/val/Data/P8_T1_00070.png",
             save_root='./data/predict'):
        """Segment one grayscale image with the UNet model and save the mask.

        Args:
            weights_path: state-dict checkpoint loaded into ``Unet(1, 1)``.
            x_path: input image; converted to single-channel ('L') first.
            save_root: output directory; created if it does not exist.

        Returns:
            Path of the written PNG, named ``predict_<YYYYmmdd_HHMMSS>_.png``
            using the time the call started.

        Raises:
            OSError: if ``cv2.imwrite`` reports that the file was not written.
        """
        device = ImageSegmentation.device

        # Timestamp taken before inference, as in the original, for the file name.
        time_str = datetime.now().strftime("%Y%m%d_%H%M%S")

        model = Unet(1, 1)
        # map_location=device (not a hard-coded 'cuda') so CPU-only hosts can
        # still load the checkpoint.
        model.load_state_dict(torch.load(weights_path, map_location=device))
        model.to(device)
        model.eval()

        # Close the underlying file handle once the image is decoded.
        with Image.open(x_path) as img:
            img_x = img.convert('L')
        # Add a batch dimension and move the input next to the model.
        x = ImageSegmentation.x_transforms(img_x).unsqueeze(0)
        x = x.to(device, dtype=torch.float32)

        with torch.no_grad():
            y = model(x)

        # Drop size-1 dims and bring the result back to host memory for numpy.
        img_y = torch.squeeze(y).cpu().numpy()
        # Assumes the model outputs values in [0, 1] (e.g. a final sigmoid) —
        # TODO confirm. Scale to 0-255, round and clip, making the 8-bit
        # conversion explicit instead of passing float data to cv2.imwrite.
        mask = np.clip(np.rint(img_y * 255), 0, 255).astype(np.uint8)

        # cv2.imwrite does not create directories and signals failure only via
        # its return value, so create the directory and check the result.
        os.makedirs(save_root, exist_ok=True)
        save_path = os.path.join(save_root, "predict_%s_.png" % time_str)
        if not cv2.imwrite(save_path, mask):
            raise OSError("failed to write prediction to %s" % save_path)
        return save_path


# Keep the plain `test()` entry point working for callers that used the
# function without going through the class.
test = ImageSegmentation.test