algorithm_system_server/algorithm/Unetliversegmaster/test.py

64 lines
1.9 KiB
Python
Raw Normal View History

2024-06-21 10:06:54 +08:00
import torch
import argparse
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
from common_tools import transform_invert
import PIL.Image as Image
# Validation interval; not referenced by the code visible in this script.
val_interval = 1

# Prefer CUDA when a GPU is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Inputs and masks are both grayscale images, so converting them to tensors
# is the only transform required.
x_transforms, y_transforms = transforms.ToTensor(), transforms.ToTensor()
# Run the model on one image and save its predicted output.
def test(args):
    """Segment a single validation image and write the prediction to disk.

    Loads U-Net weights from ``args.ckpt``, runs one forward pass on a
    hard-coded grayscale image, scales the output by 255 and saves it as
    ``./data/predict/predict_0_re.png``.

    Args:
        args: Parsed command-line namespace; only ``args.ckpt`` (path of the
            model weight file) is read.

    Raises:
        OSError: If OpenCV reports that the output image was not written.
    """
    model = Unet(1, 1)
    # Map the checkpoint onto the detected device instead of a hard-coded
    # 'cuda', so the weights also load on hosts without a GPU.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    model.to(device)
    model.eval()

    # NOTE(review): the input image is hard-coded; consider making it an argument.
    image_path = "/home/shared/wy/flask_web/Unet_liver_seg-master/data/val_data/Data/P1_T1_00038.png"
    # convert() returns a fully loaded copy, so the file can be closed right away.
    with Image.open(image_path) as raw_image:
        img_x = raw_image.convert('L')

    # ToTensor yields a (C, H, W) tensor; add a leading batch axis so the
    # network gets (N, C, H, W) -- presumably what Unet expects; confirm
    # against unet.py.
    x = x_transforms(img_x).float().unsqueeze(0).to(device)

    save_root = './data/predict'
    # cv2.imwrite does not create missing directories.
    os.makedirs(save_root, exist_ok=True)

    # Kept from the original; no figure is drawn in this function.
    plt.ion()

    with torch.no_grad():
        y = model(x)

    # Drop the singleton batch/channel axes and move the result to host memory.
    img_y = torch.squeeze(y).cpu().numpy()
    # Convert explicitly to 8-bit (round, then saturate to [0, 255]) instead
    # of relying on the image writer's implicit handling of float arrays.
    mask = np.clip(np.rint(img_y * 255), 0, 255).astype(np.uint8)

    index = 0
    save_path = os.path.join(save_root, "predict_%d_re.png" % index)
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(save_path, mask):
        raise OSError("failed to write prediction to %s" % save_path)
if __name__ == '__main__':
    # Parse command-line arguments.
    parse = argparse.ArgumentParser()
    parse.add_argument("--action", type=str, help="train, test or dice", default="test")
    parse.add_argument("--ckpt", type=str, help="the path of model weight file", default="./model/weights_100.pth")
    args = parse.parse_args()

    # Map each action to the name of the function that implements it. The
    # handler is resolved by name at run time because this script does not
    # itself define every handler; a missing one then produces a clear
    # command-line error instead of a NameError.
    handler_names = {"train": "train", "test": "test", "dice": "dice_calc"}
    handler = globals().get(handler_names.get(args.action, ""))
    if not callable(handler):
        parse.error("action %r is not available in this script" % args.action)
    handler(args)