tjy/BloodPressure/train.py

138 lines
4.4 KiB
Python
Raw Normal View History

2024-06-20 18:22:33 +08:00
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dataloader import BPDataLoader
from models.lstm import LSTMModel
# --- Model ---
model = LSTMModel()

# --- Training hyper-parameters ---
max_epochs = 100      # total number of epochs, warm-up included
batch_size = 1024
warmup_epochs = 10    # epochs whose LR is set by hand in the training loop
lr = 5e-4             # base learning rate handed to Adam

# --- Loss and optimiser ---
criterion = nn.MSELoss(reduction="mean")
optimizer = optim.Adam(params=model.parameters(), lr=lr)

# --- Learning-rate schedule ---
# Cosine annealing towards 1e-6; T_max covers the epochs after warm-up.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=max_epochs - warmup_epochs,
    eta_min=1e-6,
)

# --- TensorBoard ---
writer = SummaryWriter()  # default log directory
def train(model, dataloader, epoch, device, batch_size):
    """Train ``model`` for one epoch and return the mean per-sample loss.

    The optimised loss is ``criterion(sbp) + criterion(dbp)``. ``criterion``
    is ``nn.MSELoss()`` with its default ``reduction='mean'``, so every batch
    loss is already a batch mean. To get an exact per-sample mean for the
    epoch (also correct when the last batch is smaller), each batch loss is
    re-weighted by its sample count instead of being divided by
    ``batch_size`` a second time.

    Relies on the module-level ``optimizer``, ``criterion``, ``scheduler``,
    ``writer``, ``max_epochs`` and ``warmup_epochs``.

    The cosine scheduler is stepped only once warm-up is over. During warm-up
    the training loop assigns the LR by hand, so stepping the scheduler there
    would only advance its counter. That would shift the cosine phase and make
    the LR rise again after ``T_max`` steps, because CosineAnnealingLR is
    periodic.

    Args:
        model: network returning an ``(sbp, dbp)`` pair of predictions.
        dataloader: iterable of ``(inputs, sbp_labels, dbp_labels)`` batches.
        epoch: zero-based epoch index (progress bar, TensorBoard step,
            warm-up check).
        device: device the batch tensors are moved to.
        batch_size: unused; kept so existing call sites keep working.

    Returns:
        Mean per-sample training loss for the epoch, or ``nan`` when
        ``dataloader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0    # sum over samples of (sbp MSE + dbp MSE)
    num_samples = 0
    pbar = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch+1}/{max_epochs}")
    for inputs, sbp_labels, dbp_labels in pbar:
        inputs = inputs.to(device)
        sbp_labels = sbp_labels.to(device)
        dbp_labels = dbp_labels.to(device)

        optimizer.zero_grad()
        sbp_outputs, dbp_outputs = model(inputs)
        # (batch_size, 1) -> (batch_size,) so shapes match the labels
        sbp_outputs = sbp_outputs.squeeze(1)
        dbp_outputs = dbp_outputs.squeeze(1)
        loss = criterion(sbp_outputs, sbp_labels) + criterion(dbp_outputs, dbp_labels)
        loss.backward()
        optimizer.step()

        # assumes dim 0 of the label tensor is the batch dimension
        n = len(sbp_labels)
        loss_sum += loss.item() * n
        num_samples += n
        pbar.set_postfix(loss=loss_sum / num_samples)

    # Only advance the cosine schedule after the manual warm-up phase.
    if epoch >= warmup_epochs:
        scheduler.step()

    epoch_loss = loss_sum / num_samples if num_samples else float("nan")
    writer.add_scalar("Loss/train", epoch_loss, epoch)
    return epoch_loss
def evaluate(model, dataloader, device, batch_size):
    """Evaluate ``model`` and return the mean per-sample SBP and DBP losses.

    ``criterion`` (the module-level ``nn.MSELoss()``, default
    ``reduction='mean'``) already averages over each batch. Batch losses are
    therefore re-weighted by their sample count to obtain an exact per-sample
    mean. The original code divided by ``batch_size`` a second time, which
    under-reported the loss.

    Args:
        model: network returning an ``(sbp, dbp)`` pair of predictions.
        dataloader: iterable of ``(inputs, sbp_labels, dbp_labels)`` batches.
        device: device the batch tensors are moved to.
        batch_size: unused; kept so existing call sites keep working.

    Returns:
        ``(eval_loss_sbp, eval_loss_dbp)``, or ``(nan, nan)`` when
        ``dataloader`` yields no samples. ``nan`` never compares as an
        improvement, so no checkpoint is written in that case.
    """
    model.eval()
    sbp_sum = 0.0
    dbp_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for inputs, sbp_labels, dbp_labels in dataloader:
            inputs = inputs.to(device)
            sbp_labels = sbp_labels.to(device)
            dbp_labels = dbp_labels.to(device)

            sbp_outputs, dbp_outputs = model(inputs)
            # (batch_size, 1) -> (batch_size,) so shapes match the labels
            sbp_outputs = sbp_outputs.squeeze(1)
            dbp_outputs = dbp_outputs.squeeze(1)

            # assumes dim 0 of the label tensor is the batch dimension
            n = len(sbp_labels)
            sbp_sum += criterion(sbp_outputs, sbp_labels).item() * n
            dbp_sum += criterion(dbp_outputs, dbp_labels).item() * n
            num_samples += n

    if num_samples == 0:
        return float("nan"), float("nan")
    eval_loss_sbp = sbp_sum / num_samples
    eval_loss_dbp = dbp_sum / num_samples
    return eval_loss_sbp, eval_loss_dbp
# Training loop: prepare the checkpoint directory and data, then train and
# validate every epoch, saving "best" and "last" checkpoints.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
data_type = 'MIMIC_full'

# Checkpoints are written to weights/<data_type>/. exist_ok=True creates both
# levels when missing and is a no-op otherwise.
save_dir = os.path.join('weights', data_type)
os.makedirs(save_dir, exist_ok=True)

# val_split=0.2 is forwarded to BPDataLoader; get_dataloaders() returns the
# (train, validation) loader pair.
data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type=data_type)
train_dataloader, val_dataloader = data_loader.get_dataloaders()

# Linear warm-up ramps the LR from this value up to `lr` over `warmup_epochs`.
warmup_start_lr = 1e-6

best_val_loss_sbp = float('inf')
best_val_loss_dbp = float('inf')
try:
    for epoch in range(max_epochs):
        if epoch < warmup_epochs:
            # Use the configured base `lr` as the warm-up target rather than
            # a hard-coded copy of it.
            warmup_lr = warmup_start_lr + (epoch + 1) * (lr - warmup_start_lr) / warmup_epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = warmup_lr

        train_loss = train(model, train_dataloader, epoch, device, batch_size)
        val_loss_sbp, val_loss_dbp = evaluate(model, val_dataloader, device, batch_size)
        writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
        writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)
        print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")

        improved_sbp = val_loss_sbp < best_val_loss_sbp
        improved_dbp = val_loss_dbp < best_val_loss_dbp
        if improved_sbp or improved_dbp:
            # Track each best independently. Assigning both current values
            # would let an improvement on one metric overwrite the other's
            # best with a worse number.
            best_val_loss_sbp = min(best_val_loss_sbp, val_loss_sbp)
            best_val_loss_dbp = min(best_val_loss_dbp, val_loss_dbp)
            torch.save(
                model.state_dict(),
                os.path.join(save_dir, f'best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth'),
            )
        torch.save(model.state_dict(), os.path.join(save_dir, 'last.pth'))
finally:
    # Flush and close TensorBoard logs even if training is interrupted.
    writer.close()