import torch
import argparse
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
from common_tools import transform_invert
import PIL.Image as Image
from datetime import datetime


class ImageSegmentation:
    """Run single-image inference with a 1-in / 1-out channel ``Unet``."""

    # Validation interval; not used by ``test``.
    val_interval = 1
    # Use CUDA when it is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Images are grayscale, so only a conversion to tensor is needed.
    x_transforms = transforms.ToTensor()
    y_transforms = transforms.ToTensor()

    @staticmethod
    def test(
        weights_path='./model/weights_100.pth',
        x_path="/home/shared/wy/flask_web/Unet_liver_seg-master/data/val/Data/P8_T1_00070.png",
        save_root='./data/predict',
        device=None,
    ):
        """Segment one grayscale image and save the model output as a PNG.

        Args:
            weights_path: Path to a ``state_dict`` checkpoint for ``Unet(1, 1)``.
            x_path: Path of the input image; it is converted to grayscale ('L').
            save_root: Output directory; created if it does not exist.
            device: Device to run on (``torch.device`` or string such as
                ``"cpu"``). Defaults to ``ImageSegmentation.device``.

        Returns:
            The path of the written PNG,
            ``<save_root>/predict_<YYYYmmdd_HHMMSS>_.png``.

        Raises:
            OSError: If OpenCV fails to write the output image.
        """
        # Class attributes are not in scope inside a method body, so they
        # must be accessed through the class explicitly.
        if device is None:
            device = ImageSegmentation.device
        device = torch.device(device)

        model = Unet(1, 1)
        # Map the checkpoint to the target device instead of a hard-coded
        # 'cuda', so CPU-only machines can load it too.
        model.load_state_dict(torch.load(weights_path, map_location=device))
        model.to(device)
        model.eval()

        # The context manager closes the underlying file handle.
        with Image.open(x_path) as img:
            img_x = ImageSegmentation.x_transforms(img.convert('L'))

        # Timestamp (second resolution) used to give each output a distinct name.
        time_str = datetime.now().strftime("%Y%m%d_%H%M%S")

        with torch.no_grad():
            # Add a leading batch dimension before the forward pass.
            x = img_x.unsqueeze(0).float().to(device)
            y = model(x)

        # .cpu() is required before .numpy() when inference ran on the GPU.
        img_y = torch.squeeze(y).cpu().numpy()
        # NOTE(review): assumes the model output lies in [0, 1] (e.g. sigmoid);
        # values outside are saturated, as OpenCV's implicit 8-bit conversion did.
        img_out = np.clip(np.rint(img_y * 255), 0, 255).astype(np.uint8)

        os.makedirs(save_root, exist_ok=True)
        save_path = os.path.join(save_root, "predict_%s_.png" % time_str)
        # cv2.imwrite reports failure by returning False rather than raising.
        if not cv2.imwrite(save_path, img_out):
            raise OSError("failed to write prediction to %s" % save_path)
        return save_path


# NOTE(review): the original indentation was lost, so it is unclear whether
# callers use ``test()`` or ``ImageSegmentation.test()``; this alias keeps both working.
test = ImageSegmentation.test