import torch
import argparse
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
from common_tools import transform_invert
import PIL.Image as Image

val_interval = 1

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# All images are grayscale, so they only need to be converted to tensors.
x_transforms = transforms.ToTensor()
y_transforms = transforms.ToTensor()

# Defaults used by test() when the caller's args do not supply them.
DEFAULT_IMAGE_PATH = "/home/shared/wy/flask_web/Unet_liver_seg-master/data/val_data/Data/P1_T1_00038.png"
DEFAULT_SAVE_ROOT = './data/predict'


def test(args):
    """Segment one grayscale image with the U-Net and save the predicted mask.

    Args:
        args: namespace with ``ckpt`` (path to a saved ``state_dict``).
            Optional attributes ``image`` (input image path) and
            ``save_root`` (output directory) fall back to
            ``DEFAULT_IMAGE_PATH`` / ``DEFAULT_SAVE_ROOT`` when missing or
            empty, so older callers that only set ``ckpt`` keep working.

    Returns:
        The path of the written PNG (``<save_root>/predict_0_re.png``).

    Raises:
        OSError: if OpenCV fails to write the output image.
    """
    image_path = getattr(args, "image", None) or DEFAULT_IMAGE_PATH
    save_root = getattr(args, "save_root", None) or DEFAULT_SAVE_ROOT

    model = Unet(1, 1)
    # map_location=device lets a checkpoint saved on GPU load on CPU-only hosts.
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    model.to(device)
    model.eval()

    img_x = Image.open(image_path).convert('L')
    # ToTensor turns a single-channel image into (1, H, W); add the batch
    # axis so the network receives (1, 1, H, W).
    x = x_transforms(img_x).float().unsqueeze(0).to(device)

    with torch.no_grad():
        y = model(x)

    # NOTE(review): assumes the model output lies in [0, 1] (e.g. a sigmoid
    # head), so scaling by 255 yields an 8-bit mask — TODO confirm in unet.py.
    img_y = torch.squeeze(y).cpu().numpy()
    mask = np.clip(np.rint(img_y * 255), 0, 255).astype(np.uint8)

    # cv2.imwrite returns False (without raising) when the folder is missing.
    os.makedirs(save_root, exist_ok=True)
    save_path = os.path.join(save_root, "predict_0_re.png")
    if not cv2.imwrite(save_path, mask):
        raise OSError(f"failed to write prediction to {save_path}")
    return save_path


if __name__ == '__main__':
    # Parse command-line arguments.
    parse = argparse.ArgumentParser()
    parse.add_argument("--action", type=str, help="train, test or dice", default="test",
                       choices=["train", "test", "dice"])
    parse.add_argument("--ckpt", type=str, help="the path of model weight file",
                       default="./model/weights_100.pth")
    parse.add_argument("--image", type=str, help="path of the grayscale image to segment",
                       default=DEFAULT_IMAGE_PATH)
    parse.add_argument("--save_root", type=str, help="directory the predicted mask is written to",
                       default=DEFAULT_SAVE_ROOT)
    args = parse.parse_args()

    if args.action == "test":
        test(args)
    else:
        # train() and dice_calc() are not defined in this script; calling them
        # would raise NameError, so report a usage error instead.
        parse.error(f"action '{args.action}' is not implemented in this script")