from collections import OrderedDict

import numpy as np
import pywt
import torch
from matplotlib import pyplot as plt
from scipy.signal import butter, filtfilt

from models.lstm import LSTMModel


class BPModel:
    """Run an ``LSTMModel`` on the green-channel trace extracted from video frames.

    Each frame is reduced to its mean green intensity. The resulting trace is
    detrended with a wavelet decomposition, band-pass filtered, and passed to
    the model, which returns two outputs (named ``sbp``/``dbp``; presumably
    systolic and diastolic blood pressure).
    """

    def __init__(self, model_path, fps=30):
        """
        Args:
            model_path: checkpoint path readable by ``torch.load``. It may be a
                state dict (optionally from a DataParallel-wrapped model), a
                mapping with a ``'module'`` key, or a whole pickled model object.
            fps: frame rate of the frame sequences later passed to ``predict``.
        """
        self.fps = fps
        # Choose the device before loading so the checkpoint can be mapped onto
        # it. Otherwise a GPU-saved checkpoint cannot be loaded on a CPU-only host.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = LSTMModel()
        self.load_model(model_path)
        self.model = self.model.to(self.device)
        self.model.eval()
        self.warmup()

    def predict(self, frames):
        """Run the model on one frame sequence.

        Args:
            frames: sequence of frames (numpy arrays). Each frame is averaged over
                its first two axes and must yield exactly three channel means.

        Returns:
            tuple ``(sbp, dbp)`` of Python floats, one per model output.
        """
        yg, _, _ = self.process_frame_sequence(frames, self.fps)
        # Shape (1, time, 1). Presumably (batch, seq, feature), matching the
        # (10, 250, 1) dummy input used in warmup().
        inputs = torch.tensor(yg.reshape(1, -1, 1).copy(), dtype=torch.float32)
        inputs = inputs.to(self.device)
        with torch.no_grad():
            sbp_outputs, dbp_outputs = self.model(inputs)
        sbp = sbp_outputs.cpu().detach().numpy().item()
        dbp = dbp_outputs.cpu().detach().numpy().item()
        return sbp, dbp

    def load_model(self, model_path):
        """Load weights, or a complete model object, from ``model_path``.

        Handled layouts:
          * a mapping with a ``'module'`` key whose value is the state dict;
          * a state dict whose keys may carry the DataParallel ``module.`` prefix;
          * anything that is not a mapping is assumed to be a full model object
            and replaces ``self.model`` as-is.
        """
        # map_location moves tensors saved on another device (e.g. CUDA) onto
        # this instance's device. None keeps torch's default behaviour.
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Newer torch versions default to weights_only=True, which
        # rejects whole pickled models; confirm which layout is actually used.
        checkpoint = torch.load(model_path, map_location=getattr(self, "device", None))

        # OrderedDict is a dict subclass, so this covers both plain dicts and
        # the OrderedDict returned by state_dict().
        if not isinstance(checkpoint, dict):
            self.model = checkpoint
            return

        if 'module' in checkpoint:
            self.model.load_state_dict(checkpoint['module'])
            return

        # Strip the prefix DataParallel adds, but only when it is actually a
        # prefix. Keys that merely contain "module." elsewhere stay intact.
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint.items()
        }
        self.model.load_state_dict(state_dict)

    def warmup(self):
        """Model warm-up: run one dummy forward pass of shape (10, 250, 1)."""
        inputs = torch.randn(10, 250, 1)
        inputs = inputs.to(self.device)
        with torch.no_grad():
            self.model(inputs)

    def wavelet_detrend(self, signal, wavelet='sym6', level=6):
        """Remove baseline drift with a wavelet decomposition.

        Args:
            signal: 1-D input signal (array-like).
            wavelet: PyWavelets wavelet name, default 'sym6'.
            level: number of decomposition levels, default 6.

        Returns:
            numpy.ndarray: the signal rebuilt with the coarsest approximation
            coefficients (treated as baseline drift) set to zero. It has the
            same length as ``signal``.
        """
        coeffs = pywt.wavedec(signal, wavelet, level=level)
        # coeffs[0] is the level-`level` approximation, i.e. the slow drift.
        coeffs[0] = np.zeros_like(coeffs[0])
        detrended_signal = pywt.waverec(coeffs, wavelet)
        # waverec may return one extra sample (e.g. for odd-length input).
        # Trim it so the output lines up sample-for-sample with the input
        # and with the time axis built by the caller.
        return detrended_signal[:len(signal)]

    def butter_bandpass(self, lowcut, highcut, fs, order=5):
        """Design a Butterworth band-pass filter.

        Args:
            lowcut, highcut: pass-band edges in Hz.
            fs: sampling rate in Hz. Both edges must lie below fs / 2.
            order: filter order.

        Returns:
            ``(b, a)`` numerator/denominator coefficients.
        """
        nyq = 0.5 * fs
        low = lowcut / nyq
        high = highcut / nyq
        b, a = butter(order, [low, high], btype='band')
        return b, a

    def butter_bandpass_filter(self, data, lowcut, highcut, fs, order=5):
        """Band-pass ``data`` with a zero-phase (forward-backward) Butterworth filter."""
        b, a = self.butter_bandpass(lowcut, highcut, fs, order=order)
        y = filtfilt(b, a, data)
        return y

    def process_frame_sequence(self, frames, fps):
        """Turn a frame sequence into a cleaned green-channel trace.

        Args:
            frames: iterable of frames (numpy.ndarray). Each frame is averaged
                over its first two axes and must yield three channel means;
                the middle one (index 1) is taken as green.
            fps: frame rate, used for the time axis and the band-pass filter.

        Returns:
            yg (numpy.ndarray): detrended, 0.6-8 Hz band-pass filtered trace.
            green (list): raw per-frame mean green values.
            t (list): timestamps in seconds, starting at 0.
        """
        green = []
        for frame in frames:
            _, g, _ = frame.mean(axis=0).mean(axis=0)
            green.append(g)
        t = [i / fps for i in range(len(green))]

        g_detrended = self.wavelet_detrend(green)
        lowcut = 0.6
        highcut = 8
        yg = self.butter_bandpass_filter(g_detrended, lowcut, highcut, fps, order=4)
        return yg, green, t

    def plot(self, yg, t, title, color='green', figsize=(30, 10)):
        """Debug helper: plot a signal ``yg`` against times ``t`` and show it."""
        plt.figure(figsize=figsize)
        plt.plot(t, yg, label=title, color=color)
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude')
        plt.legend()
        plt.show()