# Listing metadata from the original paste: 64 lines, 1.9 KiB, Python.
import torch
|
||
import argparse
|
||
import cv2
|
||
import os
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from torch.utils.data import DataLoader
|
||
from torch import nn, optim
|
||
from torchvision.transforms import transforms
|
||
from unet import Unet
|
||
from dataset import LiverDataset
|
||
from common_tools import transform_invert
|
||
import PIL.Image as Image
|
||
|
||
# Validation interval; not read by the inference path defined in this file.
val_interval = 1

# Select the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Images and masks are both grayscale, so converting to a tensor is all that is
# needed. Two independent ToTensor instances are kept: one for inputs, one for targets.
x_transforms, y_transforms = transforms.ToTensor(), transforms.ToTensor()


# Run the trained model on a single grayscale image and save the predicted mask to disk.
def test(args):
    """Predict a mask for one image and write it as ``predict_0_re.png``.

    Loads ``Unet(1, 1)`` weights from ``args.ckpt`` onto the module-level
    ``device``. Runs the model on the image at ``args.image``; when that
    attribute is missing or empty, it falls back to the original hard-coded
    sample path. The squeezed prediction is scaled by 255, rounded and
    saturated to uint8, and written under ``./data/predict/``. That directory
    is created if it does not exist.

    Args:
        args: parsed command-line namespace. It must provide ``ckpt`` and may
            optionally provide ``image``.

    Raises:
        OSError: if OpenCV reports that the output image could not be written.
    """
    default_image = "/home/shared/wy/flask_web/Unet_liver_seg-master/data/val_data/Data/P1_T1_00038.png"
    image_path = getattr(args, "image", None) or default_image
    save_root = './data/predict'
    index = 0

    model = Unet(1, 1)
    # Map the checkpoint onto whatever device is available instead of forcing
    # CUDA, so the script also works on CPU-only machines.
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    model.to(device)
    model.eval()

    # Use a context manager so the source file handle is released promptly.
    with Image.open(image_path) as src:
        img_x = src.convert('L')

    # ToTensor turns a single-channel image into (1, H, W); add a leading
    # batch dimension so the network receives (N, C, H, W).
    x = x_transforms(img_x).float().unsqueeze(0).to(device)

    with torch.no_grad():
        y = model(x)

    # Drop the size-1 dims and bring the result back to host memory for NumPy.
    img_y = torch.squeeze(y).cpu().numpy()
    # Scale to 0-255, then round and saturate explicitly rather than relying
    # on OpenCV's implicit depth conversion of a float array.
    mask = np.clip(np.rint(img_y * 255), 0, 255).astype(np.uint8)

    # cv2.imwrite fails silently (returns False) when the directory is
    # missing, so create it first and surface any write failure.
    os.makedirs(save_root, exist_ok=True)
    save_path = os.path.join(save_root, "predict_%d_re.png" % index)
    if not cv2.imwrite(save_path, mask):
        raise OSError("failed to write prediction to %s" % save_path)


if __name__ == '__main__':
    # Command-line argument parsing.
    parse = argparse.ArgumentParser()
    parse.add_argument("--action", type=str, help="train, test or dice", default="test",
                       choices=["train", "test", "dice"])
    parse.add_argument("--ckpt", type=str, help="the path of model weight file",
                       default="./model/weights_100.pth")
    args = parse.parse_args()

    # Map each action to the name of the function that implements it. This
    # file only defines `test`; `train` and `dice_calc` were called without
    # being defined, which crashed with a NameError. Resolve them by name so
    # they work if they are added, and otherwise report a clean CLI error.
    handler_names = {"train": "train", "test": "test", "dice": "dice_calc"}
    handler = globals().get(handler_names[args.action])
    if handler is None:
        parse.error("action %r is not implemented in this script" % args.action)
    handler(args)