196 lines
6.6 KiB
Python
196 lines
6.6 KiB
Python
|
import torch
|
|||
|
import argparse
|
|||
|
import cv2
|
|||
|
import os
|
|||
|
import numpy as np
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
from torch.utils.data import DataLoader
|
|||
|
from torch import nn, optim
|
|||
|
from torchvision.transforms import transforms
|
|||
|
from unet import Unet
|
|||
|
from dataset import LiverDataset
|
|||
|
from common_tools import transform_invert
|
|||
|
|
|||
|
|
|||
|
def makedir(dir):
    """Create the directory ``dir``, including any missing parents.

    Does nothing if ``dir`` already exists.

    Args:
        dir: Path of the directory to create. The name shadows the builtin
            ``dir``; it is kept so keyword callers keep working.
    """
    if not os.path.exists(dir):
        # makedirs (unlike mkdir) also creates missing parent directories.
        # exist_ok avoids a crash if another process creates the directory
        # between the existence check and this call.
        os.makedirs(dir, exist_ok=True)
|
|||
|
|
|||
|
|
|||
|
# Validation interval in epochs (checked by train_model before each validation pass).
val_interval = 1

# Whether to use CUDA: pick the GPU when it is available, otherwise the CPU.
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
|
|||
|
|
|||
|
# Images and masks are both grayscale, so each only needs converting to a tensor.
x_transforms, y_transforms = transforms.ToTensor(), transforms.ToTensor()

# Loss histories appended to by train_model and plotted at the end of training:
# train_curve gets one entry per training batch, valid_curve one per validation pass.
train_curve, valid_curve = [], []
|
|||
|
|
|||
|
|
|||
|
def _train_one_epoch(model, criterion, optimizer, dataload):
    """Run one optimisation pass over ``dataload`` and return the mean batch loss.

    Every batch loss is also appended to the module-level ``train_curve``.
    Returns ``float('nan')`` when the loader yields no batches.
    """
    # Re-enter training mode every epoch: the validation pass switches the
    # model to eval mode, which would otherwise persist into later epochs.
    model.train()
    # Batches per epoch (ceiling division); used only for progress output.
    num_batches = (len(dataload.dataset) - 1) // dataload.batch_size + 1
    epoch_loss = 0.0
    step = 0
    for step, (x, y) in enumerate(dataload, start=1):
        inputs = x.to(device)
        labels = y.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        epoch_loss += batch_loss
        train_curve.append(batch_loss)
        print("%d/%d,train_loss:%0.3f" % (step, num_batches, batch_loss))
    return epoch_loss / step if step else float("nan")


def _validate(model, criterion, valid_loader):
    """Evaluate ``model`` on ``valid_loader`` without gradients; return the mean batch loss.

    Returns ``float('nan')`` when the loader yields no batches.
    """
    model.eval()
    loss_val = 0.0
    step_val = 0
    with torch.no_grad():
        for step_val, (x, y) in enumerate(valid_loader, start=1):
            inputs = x.to(device=device, dtype=torch.float32)
            labels = y.to(device)
            outputs = model(inputs)
            loss_val += criterion(outputs, labels).item()
    return loss_val / step_val if step_val else float("nan")


def _plot_loss_curves(train_iters):
    """Plot the per-batch training loss and the per-validation mean loss on one axis."""
    train_x = range(len(train_curve))
    # valid_curve holds one value per validation pass (every ``val_interval``
    # epochs); convert each entry's position to an iteration count.
    valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
    plt.plot(train_x, train_curve, label='Train')
    plt.plot(valid_x, valid_curve, label='Valid')
    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.show()


def train_model(model, criterion, optimizer, dataload, num_epochs=100,
                resume_path="./model/weights_100.pth", resume_epoch=20,
                save_every=50):
    """Train ``model``, validate it periodically, then plot the loss curves.

    Epochs are numbered 1..``num_epochs``. If ``resume_path`` exists, its
    weights are loaded and training continues from epoch ``resume_epoch + 1``.
    Otherwise training starts from epoch 1. Every ``save_every`` epochs the
    weights are saved to ``./model/weights_<epoch>.pth``. Every
    ``val_interval`` epochs the model is evaluated on ``data/val`` and the
    mean validation loss is appended to ``valid_curve``.

    Args:
        model: network to train (already on ``device``).
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        dataload: training DataLoader yielding ``(inputs, labels)`` batches.
        num_epochs: last epoch number to train (inclusive).
        resume_path: checkpoint to resume from when it exists.
        resume_epoch: epoch count assumed to be completed by ``resume_path``.
        save_every: checkpoint interval in epochs.

    Returns:
        The trained model.
    """
    makedir('./model')
    if os.path.exists(resume_path):
        model.load_state_dict(torch.load(resume_path, map_location=device))
        # NOTE(review): resume_epoch defaults to the historical hard-coded 20,
        # which does not match the epoch in the default checkpoint name
        # (weights_100). Confirm which epoch the checkpoint actually holds.
        start_epoch = resume_epoch
        print('加载成功!')
    else:
        start_epoch = 0
        print('无保存模型,将从头开始训练!')

    # Build the validation data once instead of on every epoch.
    valid_dataset = LiverDataset("data/val", transform=x_transforms, target_transform=y_transforms)
    valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=True)

    # 1-based and inclusive, so exactly num_epochs - start_epoch epochs run
    # and weights_<N> is saved after N completed epochs.
    for epoch in range(start_epoch + 1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        train_loss = _train_one_epoch(model, criterion, optimizer, dataload)
        print("epoch %d loss:%0.3f" % (epoch, train_loss))

        if epoch % save_every == 0:
            torch.save(model.state_dict(), './model/weights_%d.pth' % epoch)

        if epoch % val_interval == 0:
            val_loss = _validate(model, criterion, valid_loader)
            # Record the mean, not the sum, so the curve is on the same scale
            # as the per-batch training loss it is plotted against.
            valid_curve.append(val_loss)
            print("epoch %d valid_loss:%0.3f" % (epoch, val_loss))

    _plot_loss_curves(len(dataload))
    return model
|
|||
|
|
|||
|
|
|||
|
# Train the model.
def train(args):
    """Build a U-Net and train it on ./data/train.

    Args:
        args: parsed command-line namespace; only ``args.batch_size`` is read.
    """
    # Unet(1, 1): presumably 1 input channel and 1 output channel (grayscale in, one mask out).
    net = Unet(1, 1).to(device)
    # BCEWithLogitsLoss applies the sigmoid internally, so it takes the
    # network's raw outputs.
    loss_fn = nn.BCEWithLogitsLoss()
    adam = optim.Adam(net.parameters())
    loader = DataLoader(
        LiverDataset("./data/train", transform=x_transforms, target_transform=y_transforms),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
    )
    train_model(net, loss_fn, adam, loader)
|
|||
|
|
|||
|
|
|||
|
# Run the model on the validation set and save its outputs.
def test(args):
    """Run the checkpoint ``args.ckpt`` over data/val and save the results.

    For each validation sample ``i``, two files are written to ./data/predict:
    ``predict_<i>_g.png`` (the ground-truth mask) and ``predict_<i>_o.png``
    (the model output multiplied by 255).

    Args:
        args: parsed command-line namespace; only ``args.ckpt`` is read.

    Raises:
        OSError: if a prediction image cannot be written.
    """
    model = Unet(1, 1).to(device)
    # map_location=device (not a hard-coded 'cuda') lets the checkpoint load on CPU-only machines.
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    liver_dataset = LiverDataset("data/val", transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)

    save_root = './data/predict'
    # Create the output directory up front; otherwise the first save fails.
    os.makedirs(save_root, exist_ok=True)

    model.eval()
    with torch.no_grad():
        for index, (x, ground) in enumerate(dataloaders):
            x = x.to(device=device, dtype=torch.float32)
            # Move back to the CPU so the result can be converted to numpy.
            y = model(x).cpu()
            # NOTE(review): training uses BCEWithLogitsLoss, so unless Unet
            # applies a sigmoid itself these values are raw logits, not
            # [0, 1] probabilities. Confirm before reading them as a mask.
            img_y = torch.squeeze(y).numpy()

            # Drop all size-1 dims, then restore one leading dim for transform_invert.
            ground = torch.squeeze(ground).unsqueeze(0)
            # transform_invert's result is expected to have .save(path) (presumably a PIL image).
            img_ground = transform_invert(ground, y_transforms)
            img_ground.save(os.path.join(save_root, "predict_%d_g.png" % index))

            save_path = os.path.join(save_root, "predict_%d_o.png" % index)
            # cv2.imwrite reports failure by returning False rather than raising.
            if not cv2.imwrite(save_path, img_y * 255):
                raise OSError("failed to write %s" % save_path)
|
|||
|
|
|||
|
|
|||
|
# Compute the Dice coefficient.
def dice_calc(args, root='./data/predict'):
    """Print the Dice coefficient of every saved prediction and their mean.

    For every index ``i`` found in ``root``, ``predict_<i>_g.png`` (ground
    truth) is compared with ``predict_<i>_o.png`` (prediction). Channel 1 of
    each colour-loaded image is used, with pixel values scaled from [0, 255]
    to [0, 1]. A pair whose masks are both empty scores 1. Each score is
    rounded to 3 decimals; the mean of the rounded scores is printed, also
    rounded to 3 decimals.

    Args:
        args: parsed command-line namespace (unused; kept for the CLI dispatch).
        root: directory containing the ground-truth/prediction image pairs.

    Raises:
        FileNotFoundError: if an expected image cannot be read.
        ValueError: if a ground-truth and prediction image differ in shape.
    """
    # Count the prediction images directly rather than dividing the directory
    # size by 3: test() writes only the _g and _o files per sample, so the
    # division silently skipped samples.
    nums = sum(1 for name in os.listdir(root)
               if name.startswith("predict_") and name.endswith("_o.png"))
    if nums == 0:
        print("no predictions found in %s" % root)
        return

    dice = []
    for i in range(nums):
        ground_path = os.path.join(root, "predict_%d_g.png" % i)
        predict_path = os.path.join(root, "predict_%d_o.png" % i)
        img_ground = cv2.imread(ground_path)
        img_predict = cv2.imread(predict_path)
        # cv2.imread returns None (instead of raising) when a file is missing or unreadable.
        if img_ground is None:
            raise FileNotFoundError("cannot read %s" % ground_path)
        if img_predict is None:
            raise FileNotFoundError("cannot read %s" % predict_path)

        g = img_ground[:, :, 1].astype(np.float64)
        p = img_predict[:, :, 1].astype(np.float64)
        if g.shape != p.shape:
            raise ValueError("shape mismatch between %s and %s" % (ground_path, predict_path))

        # Vectorised sums over the whole image (soft Dice on [0, 1]-scaled pixels).
        intersec = float((g * p).sum()) / (255 * 255)
        x = float(g.sum()) / 255
        y = float(p.sum()) / 255
        if x + y == 0:
            current_dice = 1
        else:
            current_dice = round(2 * intersec / (x + y), 3)
        dice.append(current_dice)

    dice_mean = sum(dice) / len(dice)
    print(dice)
    print(round(dice_mean, 3))
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
    # Command-line argument parsing.
    parse = argparse.ArgumentParser()
    # `choices` makes argparse reject an unknown action with an error instead
    # of the script silently doing nothing.
    parse.add_argument("--action", type=str, choices=["train", "test", "dice"],
                       help="train, test or dice", default="test")
    parse.add_argument("--batch_size", type=int, default=4,
                       help="batch size used for training")
    parse.add_argument("--ckpt", type=str, help="the path of model weight file",
                       default="./model/weights_100.pth")
    args = parse.parse_args()

    # Dispatch to the selected action.
    actions = {"train": train, "test": test, "dice": dice_calc}
    actions[args.action](args)
|