tjy/BloodPressure/train.py

138 lines
4.4 KiB
Python

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dataloader import BPDataLoader
from models.lstm import LSTMModel
# Training hyperparameters
max_epochs = 100
batch_size = 1024
warmup_epochs = 10
lr = 5e-4

# Model under training
model = LSTMModel()

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Learning-rate scheduler: cosine annealing sized for the epochs that
# remain once the warm-up epochs are excluded
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=max_epochs - warmup_epochs,
    eta_min=1e-6,
)

# TensorBoard writer
writer = SummaryWriter()
# Training function
def train(model, dataloader, epoch, device, batch_size):
    """Run one training epoch and return the mean per-sample loss.

    Uses the module-level ``criterion``, ``optimizer``, ``scheduler`` and
    ``writer``. The loss of each batch is the sum of the SBP loss and the
    DBP loss.

    Args:
        model: Network returning a ``(sbp_outputs, dbp_outputs)`` pair.
        dataloader: Iterable of ``(inputs, sbp_labels, dbp_labels)`` batches.
        epoch: Zero-based index of the current epoch.
        device: Device the batches are moved to.
        batch_size: Unused; kept so existing callers keep working. The
            epoch loss is normalised by the number of samples actually
            seen instead.

    Returns:
        The epoch loss averaged over all samples seen.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    weighted_loss_sum = 0.0
    samples_seen = 0
    pbar = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch+1}/{max_epochs}")
    for inputs, sbp_labels, dbp_labels in pbar:
        inputs = inputs.to(device)
        sbp_labels = sbp_labels.to(device)
        dbp_labels = dbp_labels.to(device)

        optimizer.zero_grad()
        sbp_outputs, dbp_outputs = model(inputs)
        # Reshape the outputs from (batch_size, 1) to (batch_size,)
        sbp_outputs = sbp_outputs.squeeze(1)
        dbp_outputs = dbp_outputs.squeeze(1)
        loss_sbp = criterion(sbp_outputs, sbp_labels)
        loss_dbp = criterion(dbp_outputs, dbp_labels)
        loss = loss_sbp + loss_dbp
        loss.backward()
        optimizer.step()

        # The criterion already averages over the batch, so weight each batch
        # by its real size; a shorter final batch then counts proportionally
        # and no further division by batch_size is needed.
        n_samples = sbp_outputs.size(0)
        weighted_loss_sum += loss.item() * n_samples
        samples_seen += n_samples
        pbar.set_postfix(loss=weighted_loss_sum / max(samples_seen, 1))

    if samples_seen == 0:
        raise ValueError("train() received an empty dataloader")

    # The cosine schedule spans only the post-warm-up epochs
    # (T_max = max_epochs - warmup_epochs). Stepping it during warm-up would
    # push it past T_max near the end of training and raise the LR again.
    if epoch >= warmup_epochs:
        scheduler.step()

    epoch_loss = weighted_loss_sum / samples_seen
    writer.add_scalar("Loss/train", epoch_loss, epoch)
    return epoch_loss
# Evaluation function
def evaluate(model, dataloader, device, batch_size, loss_fn=None):
    """Compute the mean per-sample SBP and DBP losses over a dataloader.

    Args:
        model: Network returning a ``(sbp_outputs, dbp_outputs)`` pair.
        dataloader: Iterable of ``(inputs, sbp_labels, dbp_labels)`` batches.
        device: Device the batches are moved to.
        batch_size: Unused; kept so existing callers keep working. Losses
            are normalised by the number of samples actually seen instead.
        loss_fn: Optional loss callable. Defaults to the module-level
            ``criterion`` when omitted.

    Returns:
        A ``(eval_loss_sbp, eval_loss_dbp)`` tuple of floats.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = criterion

    model.eval()
    weighted_sum_sbp = 0.0
    weighted_sum_dbp = 0.0
    samples_seen = 0
    with torch.no_grad():
        for inputs, sbp_labels, dbp_labels in dataloader:
            inputs = inputs.to(device)
            sbp_labels = sbp_labels.to(device)
            dbp_labels = dbp_labels.to(device)

            sbp_outputs, dbp_outputs = model(inputs)
            # Reshape the outputs from (batch_size, 1) to (batch_size,)
            sbp_outputs = sbp_outputs.squeeze(1)
            dbp_outputs = dbp_outputs.squeeze(1)

            # Each loss is already a batch mean; weight it by the batch's
            # real size so a shorter final batch is not over-represented.
            n_samples = sbp_outputs.size(0)
            weighted_sum_sbp += loss_fn(sbp_outputs, sbp_labels).item() * n_samples
            weighted_sum_dbp += loss_fn(dbp_outputs, dbp_labels).item() * n_samples
            samples_seen += n_samples

    if samples_seen == 0:
        raise ValueError("evaluate() received an empty dataloader")

    eval_loss_sbp = weighted_sum_sbp / samples_seen
    eval_loss_dbp = weighted_sum_dbp / samples_seen
    return eval_loss_sbp, eval_loss_dbp
# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
data_type = 'MIMIC_full'

# Make sure weights/<data_type>/ exists; exist_ok covers both the missing
# parent directory and the already-present sub-directory cases.
os.makedirs(os.path.join('weights', data_type), exist_ok=True)

data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type=data_type)
train_dataloader, val_dataloader = data_loader.get_dataloaders()

best_val_loss_sbp = float('inf')
best_val_loss_dbp = float('inf')
warmup_start_lr = 1e-6

try:
    for epoch in range(max_epochs):
        if epoch < warmup_epochs:
            # Linear warm-up from warmup_start_lr up to the configured base lr
            warmup_lr = warmup_start_lr + (epoch + 1) * (lr - warmup_start_lr) / warmup_epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = warmup_lr

        train_loss = train(model, train_dataloader, epoch, device, batch_size)
        val_loss_sbp, val_loss_dbp = evaluate(model, val_dataloader, device, batch_size)
        writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
        writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)
        print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")

        # Save a "best" checkpoint when either metric improves, but track each
        # best independently so a regression in one metric never overwrites
        # its previously recorded best value.
        improved_sbp = val_loss_sbp < best_val_loss_sbp
        improved_dbp = val_loss_dbp < best_val_loss_dbp
        if improved_sbp or improved_dbp:
            best_val_loss_sbp = min(best_val_loss_sbp, val_loss_sbp)
            best_val_loss_dbp = min(best_val_loss_dbp, val_loss_dbp)
            torch.save(model.state_dict(), f'weights/{data_type}/best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth')

        # Always refresh the most recent checkpoint
        torch.save(model.state_dict(),
                   f'weights/{data_type}/last.pth')
finally:
    # Flush and close the TensorBoard writer even if training is interrupted
    writer.close()