# 138 lines
# 4.4 KiB
# Python
# 138 lines
# 4.4 KiB
# Python
import os
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from tqdm import tqdm
|
|
|
|
from dataloader import BPDataLoader
|
|
from models.lstm import LSTMModel
|
|
|
|
# Model under training.
model = LSTMModel()

# Training hyper-parameters.
max_epochs = 100
batch_size = 1024
warmup_epochs = 10
lr = 5e-4

# Loss function (mean-squared error) and Adam optimizer over every model
# parameter, created with the configured learning rate.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)

# Cosine-annealing LR schedule whose period covers only the epochs after
# warm-up (T_max = max_epochs - warmup_epochs), decaying towards eta_min=1e-6.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=max_epochs - warmup_epochs,
    eta_min=1e-6,
)

# TensorBoard event writer, using SummaryWriter's default log directory.
writer = SummaryWriter()


# Training function
def train(model, dataloader, epoch, device, batch_size):
    """Run one training epoch and return the mean per-sample training loss.

    Relies on the module-level ``optimizer``, ``criterion``, ``scheduler``,
    ``writer`` and ``warmup_epochs``. The per-batch loss is the SBP criterion
    value plus the DBP criterion value.

    Args:
        model: Network whose forward pass returns ``(sbp_outputs, dbp_outputs)``;
            each is squeezed on dim 1, presumably from ``(batch, 1)`` -- confirm
            against the model definition.
        dataloader: Iterable of ``(inputs, sbp_labels, dbp_labels)`` batches
            that also supports ``len()``.
        epoch: Zero-based epoch index. Used for the progress-bar label, as the
            TensorBoard step, and to decide whether the LR scheduler advances.
        device: Device that each batch's tensors are moved to.
        batch_size: Unused. It is kept so existing call sites keep working.
            The real size of each batch is read from its labels instead.

    Returns:
        Mean training loss over all samples seen this epoch, or ``nan`` when
        the dataloader yields no samples.
    """
    model.train()
    total_loss = 0.0
    seen = 0
    pbar = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch+1}/{max_epochs}")
    for inputs, sbp_labels, dbp_labels in pbar:
        inputs = inputs.to(device)
        sbp_labels = sbp_labels.to(device)
        dbp_labels = dbp_labels.to(device)

        optimizer.zero_grad()
        sbp_outputs, dbp_outputs = model(inputs)
        # Drop dim 1 so the predictions match the label shape.
        sbp_outputs = sbp_outputs.squeeze(1)
        dbp_outputs = dbp_outputs.squeeze(1)

        loss = criterion(sbp_outputs, sbp_labels) + criterion(dbp_outputs, dbp_labels)
        loss.backward()
        optimizer.step()

        # The criterion (nn.MSELoss, default reduction) already averages over
        # the batch. Re-weight by the batch size so a short final batch does
        # not skew the epoch mean. Do NOT divide by batch_size again.
        n = sbp_labels.size(0)
        total_loss += loss.item() * n
        seen += n
        pbar.set_postfix(loss=total_loss / seen)

    # During warm-up the main loop sets the LR by hand. Stepping the cosine
    # schedule then would offset it by warmup_epochs and run it past T_max,
    # where the cosine LR starts rising again.
    if epoch >= warmup_epochs:
        scheduler.step()

    epoch_loss = total_loss / seen if seen else float("nan")
    writer.add_scalar("Loss/train", epoch_loss, epoch)
    return epoch_loss


# Evaluation function
def evaluate(model, dataloader, device, batch_size):
    """Compute mean validation losses for the SBP and DBP outputs without training.

    Runs the model in eval mode under ``torch.no_grad()`` and scores both
    outputs with the module-level ``criterion``.

    Args:
        model: Network whose forward pass returns ``(sbp_outputs, dbp_outputs)``;
            each is squeezed on dim 1, presumably from ``(batch, 1)`` -- confirm
            against the model definition.
        dataloader: Iterable of ``(inputs, sbp_labels, dbp_labels)`` batches.
        device: Device that each batch's tensors are moved to.
        batch_size: Unused. It is kept so existing call sites keep working.
            The real size of each batch is read from its labels instead.

    Returns:
        ``(sbp_loss, dbp_loss)``: per-sample mean loss of each output over the
        whole dataloader, or ``(nan, nan)`` when it yields no samples.
    """
    model.eval()
    sum_sbp = 0.0
    sum_dbp = 0.0
    seen = 0
    with torch.no_grad():
        for inputs, sbp_labels, dbp_labels in dataloader:
            inputs = inputs.to(device)
            sbp_labels = sbp_labels.to(device)
            dbp_labels = dbp_labels.to(device)

            sbp_outputs, dbp_outputs = model(inputs)
            # Drop dim 1 so the predictions match the label shape.
            sbp_outputs = sbp_outputs.squeeze(1)
            dbp_outputs = dbp_outputs.squeeze(1)

            # The criterion already averages over the batch. Re-weight by the
            # batch size so the result is a true per-sample mean even when the
            # final batch is short. Do NOT divide by batch_size again.
            n = sbp_labels.size(0)
            sum_sbp += criterion(sbp_outputs, sbp_labels).item() * n
            sum_dbp += criterion(dbp_outputs, dbp_labels).item() * n
            seen += n

    if not seen:
        return float("nan"), float("nan")
    return sum_sbp / seen, sum_dbp / seen


# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

data_type = 'MIMIC_full'

# Checkpoints go to weights/<data_type>/. exist_ok=True creates any missing
# directories and is a no-op when they already exist.
save_dir = os.path.join('weights', data_type)
os.makedirs(save_dir, exist_ok=True)

data_loader = BPDataLoader(data_dir='data', val_split=0.2, batch_size=batch_size, data_type=data_type)
train_dataloader, val_dataloader = data_loader.get_dataloaders()

best_val_loss_sbp = float('inf')
best_val_loss_dbp = float('inf')

# LR at the very start of the linear warm-up ramp.
warmup_start_lr = 1e-6

try:
    for epoch in range(max_epochs):
        if epoch < warmup_epochs:
            # Linear warm-up that reaches the configured peak `lr` on the last
            # warm-up epoch. After that the cosine scheduler takes over.
            warmup_lr = warmup_start_lr + (epoch + 1) * (lr - warmup_start_lr) / warmup_epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = warmup_lr

        train_loss = train(model, train_dataloader, epoch, device, batch_size)
        val_loss_sbp, val_loss_dbp = evaluate(model, val_dataloader, device, batch_size)

        writer.add_scalar("Loss/val_sbp", val_loss_sbp, epoch)
        writer.add_scalar("Loss/val_dbp", val_loss_dbp, epoch)

        print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss SBP: {val_loss_sbp:.4f}, Val Loss DBP: {val_loss_dbp:.4f}")

        # Save a checkpoint whenever either output reaches a new best. Each
        # best is tracked independently with min(). Overwriting both with the
        # current values would let a regression on one output raise its
        # recorded best, so later epochs that only partly recover would be
        # saved as "best".
        if val_loss_sbp < best_val_loss_sbp or val_loss_dbp < best_val_loss_dbp:
            best_val_loss_sbp = min(best_val_loss_sbp, val_loss_sbp)
            best_val_loss_dbp = min(best_val_loss_dbp, val_loss_dbp)
            torch.save(
                model.state_dict(),
                os.path.join(save_dir, f'best_{epoch}_lstm_model_sbp{val_loss_sbp:.4f}_dbp{val_loss_dbp:.4f}.pth'),
            )

        # Always keep the most recent weights as well.
        torch.save(model.state_dict(), os.path.join(save_dir, 'last.pth'))
finally:
    # Flush and close the TensorBoard writer even if training is interrupted.
    writer.close()