64 lines
1.9 KiB
Python
64 lines
1.9 KiB
Python
|
import torch
|
|||
|
import argparse
|
|||
|
import cv2
|
|||
|
import os
|
|||
|
import numpy as np
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
from torch.utils.data import DataLoader
|
|||
|
from torch import nn, optim
|
|||
|
from torchvision.transforms import transforms
|
|||
|
from unet import Unet
|
|||
|
from dataset import LiverDataset
|
|||
|
from common_tools import transform_invert
|
|||
|
import PIL.Image as Image
|
|||
|
|
|||
|
# NOTE(review): not referenced anywhere in the visible code; presumably the
# number of epochs between validation runs -- confirm before relying on it.
val_interval = 1

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Inputs and masks are both grayscale images, so converting them to tensors
# is the only preprocessing applied.
x_transforms, y_transforms = transforms.ToTensor(), transforms.ToTensor()
|
|||
|
|
|||
|
|
|||
|
# 显示模型的输出结果
|
|||
|
def test(args):
    """Run the U-Net on one validation image and save the prediction as a PNG.

    The weights are loaded from ``args.ckpt``, a single forward pass is run on
    the grayscale image at the hard-coded ``image_path``, and the model
    output multiplied by 255 is written to
    ``./data/predict/predict_0_re.png``.

    Args:
        args: Parsed command-line namespace. Only ``args.ckpt`` (path of the
            model weight file) is read.

    Raises:
        IOError: If OpenCV reports that the prediction could not be written.
    """
    image_path = "/home/shared/wy/flask_web/Unet_liver_seg-master/data/val_data/Data/P1_T1_00038.png"
    save_root = './data/predict'
    index = 0  # Only one image is processed, so the output index stays 0.

    model = Unet(1, 1)
    # Load onto the device that is actually available; a hard-coded 'cuda'
    # location makes torch.load fail on CPU-only machines.
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    model.to(device)
    model.eval()

    # convert() returns an independent image, so the file handle can be
    # released as soon as the conversion is done.
    with Image.open(image_path) as source_image:
        img_x = source_image.convert('L')

    # ToTensor yields a (C, H, W) tensor; add the batch dimension so the
    # network receives (N, C, H, W) input (assumed to be what Unet expects,
    # as with standard torch.nn 2-D layers).
    x = x_transforms(img_x).float().unsqueeze(0).to(device)

    plt.ion()
    with torch.no_grad():
        y = model(x)

    # Drop the singleton dimensions and move the result back to host memory
    # before handing it to NumPy / OpenCV.
    img_y = torch.squeeze(y).cpu().numpy()

    # cv2.imwrite does not create missing directories and signals failure
    # only through its return value, so handle both explicitly.
    os.makedirs(save_root, exist_ok=True)
    save_path = os.path.join(save_root, "predict_%d_re.png" % index)
    if not cv2.imwrite(save_path, img_y * 255):
        raise IOError("failed to write prediction to %s" % save_path)
|
|||
|
|
|||
|
if __name__ == '__main__':
    # Command-line argument parsing.
    parse = argparse.ArgumentParser()
    # Restricting the choices makes a mistyped action fail loudly instead of
    # silently doing nothing.
    parse.add_argument("--action", type=str, choices=("train", "test", "dice"),
                       help="train, test or dice", default="test")
    parse.add_argument("--ckpt", type=str, help="the path of model weight file", default="./model/weights_100.pth")
    args = parse.parse_args()

    # Each action maps to the name of the module-level function handling it.
    # Handlers are resolved by name so that an action whose function is not
    # defined in this module produces a clear CLI error rather than a bare
    # NameError.
    handler_names = {"train": "train", "test": "test", "dice": "dice_calc"}
    handler_name = handler_names[args.action]
    handler = globals().get(handler_name)
    if not callable(handler):
        parse.error("action '%s' is not available: %s() is not defined in this module"
                    % (args.action, handler_name))
    handler(args)
|