import torch
import argparse
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
from common_tools import transform_invert


def makedir(dir):
    if not os.path.exists(dir):
        os.mkdir(dir)


val_interval = 1
# Use CUDA if it is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Images and masks are both grayscale, so converting them to tensors is enough
x_transforms = transforms.ToTensor()
y_transforms = transforms.ToTensor()

train_curve = list()
valid_curve = list()


def train_model(model, criterion, optimizer, dataload, num_epochs=100):
    makedir('./model')
    model_path = "./model/weights_100.pth"
    if os.path.exists(model_path):
        # Resume from a hardcoded epoch when saved weights exist
        model.load_state_dict(torch.load(model_path, map_location=device))
        start_epoch = 20
        print('Checkpoint loaded successfully!')
    else:
        start_epoch = 0
        print('No saved model found, training from scratch!')

    for epoch in range(start_epoch + 1, num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        # Make sure the model is back in training mode (validation below switches it to eval)
        model.train()
        dt_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            train_curve.append(loss.item())
            print("%d/%d, train_loss: %0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
        print("epoch %d loss: %0.3f" % (epoch, epoch_loss / step))
        if (epoch + 1) % 50 == 0:
            torch.save(model.state_dict(), './model/weights_%d.pth' % (epoch + 1))

        # Validate the model
        valid_dataset = LiverDataset("data/val", transform=x_transforms, target_transform=y_transforms)
        valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=True)
        if (epoch + 2) % val_interval == 0:
            loss_val = 0.
            model.eval()
            with torch.no_grad():
                step_val = 0
                for x, y in valid_loader:
                    step_val += 1
                    x = x.type(torch.FloatTensor)
                    inputs = x.to(device)
                    labels = y.to(device)
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss_val += loss.item()

            valid_curve.append(loss_val)
            print("epoch %d valid_loss: %0.3f" % (epoch, loss_val / step_val))

    train_x = range(len(train_curve))
    train_y = train_curve

    train_iters = len(dataload)
    # valid_curve records one loss value per validation epoch, so convert its
    # record points to iteration indices to plot it on the same x-axis
    valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
    valid_y = valid_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.show()
    return model


# Train the model
def train(args):
    model = Unet(1, 1).to(device)
    batch_size = args.batch_size
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters())
    liver_dataset = LiverDataset("./data/train", transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    train_model(model, criterion, optimizer, dataloaders)


# Show the model's predictions (saved to ./data/predict)
def test(args):
    model = Unet(1, 1)
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    liver_dataset = LiverDataset("data/val", transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)

    save_root = './data/predict'
    makedir(save_root)

    model.eval()
    plt.ion()
    index = 0
    with torch.no_grad():
        for x, ground in dataloaders:
            x = x.type(torch.FloatTensor)
            y = model(x)
            x = torch.squeeze(x)
            x = x.unsqueeze(0)
            ground = torch.squeeze(ground)
            ground = ground.unsqueeze(0)
            img_ground = transform_invert(ground, y_transforms)
            img_x = transform_invert(x, x_transforms)
            # Note: since training uses nn.BCEWithLogitsLoss, y may be raw logits
            # (depending on the Unet definition); if so, applying torch.sigmoid
            # here would give a proper probability map before saving.
            img_y = torch.squeeze(y).numpy()
            # cv2.imshow('img', img_y)
            src_path = os.path.join(save_root, "predict_%d_s.png" % index)
            save_path = os.path.join(save_root, "predict_%d_o.png" % index)
            ground_path = os.path.join(save_root, "predict_%d_g.png" % index)
            img_ground.save(ground_path)
            # img_x.save(src_path)
            cv2.imwrite(save_path, img_y * 255)
            index = index + 1
            # plt.imshow(img_y)
            # plt.pause(0.5)
            # plt.show()


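# If the Unet output is indeed raw logits (see the note in test() above), a small
# helper along these lines could turn it into a saved binary mask. This is only a
# sketch of one possible approach; `binarize_logits` is a hypothetical helper and
# not part of the original script.
def binarize_logits(logits, threshold=0.5):
    """Convert a single-image logit tensor into a uint8 mask with values 0 or 255."""
    probs = torch.sigmoid(torch.squeeze(logits)).cpu().numpy()
    return (probs > threshold).astype(np.uint8) * 255
# Example use inside test(): cv2.imwrite(save_path, binarize_logits(y))
# instead of cv2.imwrite(save_path, img_y * 255).

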
# Compute the Dice coefficient
def dice_calc(args):
    root = './data/predict'
    # Three files are expected per case (source, output, ground truth)
    nums = len(os.listdir(root)) // 3
    dice = list()
    dice_mean = 0
    for i in range(nums):
        ground_path = os.path.join(root, "predict_%d_g.png" % i)
        predict_path = os.path.join(root, "predict_%d_o.png" % i)
        img_ground = cv2.imread(ground_path)
        img_predict = cv2.imread(predict_path)
        intersec = 0
        x = 0
        y = 0
        # Images are assumed to be 256x256; pixel values are rescaled from [0, 255] to [0, 1]
        for w in range(256):
            for h in range(256):
                intersec += img_ground.item(w, h, 1) * img_predict.item(w, h, 1) / (255 * 255)
                x += img_ground.item(w, h, 1) / 255
                y += img_predict.item(w, h, 1) / 255
        if x + y == 0:
            current_dice = 1
        else:
            current_dice = round(2 * intersec / (x + y), 3)
        dice_mean += current_dice
        dice.append(current_dice)
    dice_mean /= len(dice)
    print(dice)
    print(round(dice_mean, 3))


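# For reference, the per-pixel loop above implements the Dice coefficient
# dice = 2 * |A ∩ B| / (|A| + |B|). A vectorized NumPy version of the same
# computation could look like the sketch below; `dice_from_arrays` is a
# hypothetical helper, not part of the original script, and it assumes the two
# images are already loaded as BGR arrays with values in [0, 255].
def dice_from_arrays(img_ground, img_predict):
    g = img_ground[..., 1].astype(np.float64) / 255
    p = img_predict[..., 1].astype(np.float64) / 255
    denom = g.sum() + p.sum()
    if denom == 0:
        return 1.0
    return 2 * (g * p).sum() / denom

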
if __name__ == '__main__':
    # Argument parsing
    parse = argparse.ArgumentParser()
    parse.add_argument("--action", type=str, help="train, test or dice", default="test")
    parse.add_argument("--batch_size", type=int, default=4)
    parse.add_argument("--ckpt", type=str, help="the path of model weight file", default="./model/weights_100.pth")
    # parse.add_argument("--ckpt", type=str, help="the path of model weight file")
    args = parse.parse_args()

    if args.action == "train":
        train(args)
    elif args.action == "test":
        test(args)
    elif args.action == "dice":
        dice_calc(args)
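
# Example usage (assuming this script is saved as main.py; adjust the file name
# to match the repository):
#   python main.py --action train --batch_size 4
#   python main.py --action test --ckpt ./model/weights_100.pth
#   python main.py --action dice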