150 lines
4.5 KiB
Python
150 lines
4.5 KiB
Python
|
from collections import OrderedDict
|
|||
|
|
|||
|
import torch
|
|||
|
import numpy as np
|
|||
|
from matplotlib import pyplot as plt
|
|||
|
from scipy.signal import butter, filtfilt
|
|||
|
import pywt
|
|||
|
from models.lstm import LSTMModel
|
|||
|
|
|||
|
|
|||
|
class BPModel:
    """Run an ``LSTMModel`` on a green-channel signal extracted from video frames.

    The pipeline is: per-frame mean of the green channel -> wavelet baseline
    removal -> Butterworth band-pass filter -> model forward pass.  The model
    is expected to return two single-element tensors, which the code names
    ``sbp`` and ``dbp`` (presumably systolic / diastolic blood pressure --
    confirm against ``LSTMModel``).
    """

    def __init__(self, model_path, fps=30):
        """Build the model, load its weights, move it to a device and warm it up.

        Args:
            model_path: Checkpoint location handed to ``torch.load``.
            fps: Frame rate of the frame sequences later passed to
                ``predict``; used as the sampling frequency of the filter.
        """
        self.fps = fps

        self.model = LSTMModel()

        # May replace ``self.model`` entirely when the checkpoint is a pickled model.
        self.load_model(model_path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = self.model.to(self.device)
        self.model.eval()
        self.warmup()

    def predict(self, frames):
        """Predict the two model outputs for a sequence of frames.

        Args:
            frames: Sequence of frames; see ``process_frame_sequence``.

        Returns:
            tuple: ``(sbp, dbp)`` as Python scalars.
        """
        yg, _green, _t = self.process_frame_sequence(frames, self.fps)

        # One batch entry, one feature per time step.  The copy gives torch a
        # freshly allocated array (presumably needed because the filtered
        # array may not be laid out in a way ``torch.tensor`` accepts).
        batch = yg.reshape(1, -1, 1).copy()
        inputs = torch.tensor(batch, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            sbp_outputs, dbp_outputs = self.model(inputs)

        sbp = sbp_outputs.cpu().detach().numpy().item()
        dbp = dbp_outputs.cpu().detach().numpy().item()
        return sbp, dbp

    def load_model(self, model_path):
        """Load weights (or a whole pickled model) from ``model_path``.

        Supported checkpoint layouts:

        * a non-dict object: taken to be a complete model and used as-is;
        * a dict with a ``'module'`` entry: that entry is loaded as the
          state dict;
        * a state dict, optionally with a leading ``module.`` on its keys
          (as produced by multi-GPU wrappers), which is stripped.

        Args:
            model_path: Path or file-like object accepted by ``torch.load``.
        """
        # NOTE(security): torch.load unpickles data; only load trusted checkpoints.
        # Always land on the CPU first: the target device is chosen later in
        # __init__, and a GPU-saved checkpoint must still load on a CPU-only host.
        checkpoint = torch.load(model_path, map_location="cpu")

        # Anything that is not a mapping is treated as a complete model object.
        # ``dict`` (rather than ``OrderedDict``) so plain-dict checkpoints are
        # recognised as state dicts too.
        if not isinstance(checkpoint, dict):
            self.model = checkpoint
            return

        # Checkpoint that wraps the weights under a 'module' entry.
        if 'module' in checkpoint:
            self.model.load_state_dict(checkpoint['module'])
            return

        # Strip a *leading* 'module.' only; keys that merely contain the
        # substring (e.g. 'submodule.weight') must be left untouched.
        prefix = 'module.'
        new_state_dict = {}
        for key, value in checkpoint.items():
            name = key[len(prefix):] if key.startswith(prefix) else key
            new_state_dict[name] = value
        self.model.load_state_dict(new_state_dict)

    def warmup(self):
        """Run one throw-away forward pass on a random (10, 250, 1) input."""
        inputs = torch.randn(10, 250, 1).to(self.device)
        with torch.no_grad():
            self.model(inputs)

    def wavelet_detrend(self, signal, wavelet='sym6', level=6):
        """Remove baseline drift by zeroing the coarsest wavelet approximation.

        Args:
            signal: Input signal (array-like).
            wavelet (str): Wavelet name, default ``'sym6'``.
            level (int): Decomposition depth, default 6.

        Returns:
            numpy.ndarray: Detrended signal with the same length as ``signal``
            along the last axis.
        """
        n_samples = np.shape(signal)[-1]

        # Multi-level decomposition; coeffs[0] is the deepest approximation,
        # which is treated as the baseline drift.
        coeffs = pywt.wavedec(signal, wavelet, level=level)
        coeffs[0] = np.zeros_like(coeffs[0])

        detrended_signal = pywt.waverec(coeffs, wavelet)

        # Reconstruction can come back one sample longer than the input
        # (odd-length signals); trim so downstream lengths stay aligned.
        return detrended_signal[..., :n_samples]

    def butter_bandpass(self, lowcut, highcut, fs, order=5):
        """Design a Butterworth band-pass filter.

        Args:
            lowcut: Lower cut-off frequency, in the same unit as ``fs``.
            highcut: Upper cut-off frequency, in the same unit as ``fs``.
            fs: Sampling frequency.
            order (int): Order passed to ``scipy.signal.butter``.

        Returns:
            tuple: ``(b, a)`` filter coefficients.
        """
        nyq = 0.5 * fs
        # butter expects the band edges normalised by the Nyquist frequency.
        band = [lowcut / nyq, highcut / nyq]
        b, a = butter(order, band, btype='band')
        return b, a

    def butter_bandpass_filter(self, data, lowcut, highcut, fs, order=5):
        """Band-pass ``data`` with ``filtfilt`` using ``butter_bandpass``.

        Args:
            data: Signal to filter.
            lowcut: Lower cut-off frequency.
            highcut: Upper cut-off frequency.
            fs: Sampling frequency.
            order (int): Filter order.

        Returns:
            numpy.ndarray: Filtered signal.
        """
        b, a = self.butter_bandpass(lowcut, highcut, fs, order=order)
        return filtfilt(b, a, data)

    def process_frame_sequence(self, frames, fps):
        """Turn a frame sequence into a detrended, band-passed green signal.

        Args:
            frames (list): All frames, each a ``numpy.ndarray`` whose mean over
                the first two axes yields exactly three channel values; the
                second one is used as green (assumes R, G, B channel order --
                TODO confirm against the caller).
            fps: Frame rate, used for the time axis and as the filter's
                sampling frequency.

        Returns:
            tuple: ``(yg, green, t)`` where

            * ``yg`` (numpy.ndarray): green signal after detrending and a
              0.6-8 band-pass of order 4;
            * ``green`` (list): raw per-frame green-channel means;
            * ``t`` (list): time stamps in seconds, starting at 0.
        """
        green = []
        for frame in frames:
            # Average over the first two axes; unpacking enforces 3 channels.
            _, g, _ = frame.mean(axis=0).mean(axis=0)
            green.append(g)

        t = [i / fps for i in range(len(frames))]

        g_detrended = self.wavelet_detrend(green)

        lowcut = 0.6
        highcut = 8
        yg = self.butter_bandpass_filter(g_detrended, lowcut, highcut, fps, order=4)

        return yg, green, t

    def plot(self, yg, t, title, color='green', figsize=(30, 10)):
        """Show ``yg`` against ``t`` in a new matplotlib figure.

        Args:
            yg: Values to plot.
            t: Time stamps (x axis).
            title: Legend label of the curve.
            color: Line colour.
            figsize: Figure size passed to ``plt.figure``.
        """
        plt.figure(figsize=figsize)
        plt.plot(t, yg, label=title, color=color)
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude')
        plt.legend()
        plt.show()
|
|||
|
|
|||
|
|