152 lines
4.7 KiB
Python
152 lines
4.7 KiB
Python
import os
|
||
import sys
|
||
|
||
# Prepend the directory containing this file to the import search path so
# that modules stored alongside it can be imported by their bare names.
current_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.insert(0, current_dir)
|
||
|
||
from collections import OrderedDict
|
||
import torch
|
||
import numpy as np
|
||
from matplotlib import pyplot as plt
|
||
from scipy.signal import butter, filtfilt
|
||
import pywt
|
||
from .models.lstm import LSTMModel
|
||
|
||
|
||
class BPModel:
    """Wrap an ``LSTMModel`` that maps a frame sequence to (SBP, DBP) outputs.

    Each frame is reduced to its mean green value. That signal is detrended
    with a wavelet decomposition, band-pass filtered, and fed to the model.
    """

    def __init__(self, model_path, fps=30):
        """Build the model, load its weights, move it to a device, and warm it up.

        Args:
            model_path: Path or file-like object accepted by ``torch.load``.
                It may hold a state dict or a whole pickled model
                (see ``load_model``).
            fps: Frame rate of the sequences given to ``predict``. It is used
                as the sampling rate for the time axis and the band-pass filter.
        """
        self.fps = fps
        # Pick the device before loading so load_model can map checkpoint
        # tensors onto it. A GPU-saved checkpoint would otherwise fail to
        # load on a CPU-only machine.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = LSTMModel()
        self.load_model(model_path)
        self.model = self.model.to(self.device)
        self.model.eval()
        self.warmup()

    def predict(self, frames):
        """Run the model on one frame sequence.

        Args:
            frames: Frame sequence, as accepted by ``process_frame_sequence``.

        Returns:
            tuple: ``(sbp, dbp)`` as Python floats. ``.item()`` requires each
            model output to contain exactly one element.
        """
        yg, _, _ = self.process_frame_sequence(frames, self.fps)
        # Shape (1, seq_len, 1). Presumably (batch, seq, feature) for a
        # batch-first LSTM; confirm against LSTMModel.
        # .copy() hands torch a fresh, positively-strided buffer.
        inputs = torch.tensor(yg.reshape(1, -1, 1).copy(), dtype=torch.float32)
        inputs = inputs.to(self.device)
        with torch.no_grad():
            sbp_outputs, dbp_outputs = self.model(inputs)
        return sbp_outputs.item(), dbp_outputs.item()

    def load_model(self, model_path):
        """Load a checkpoint from *model_path* into ``self.model``.

        Accepted checkpoint contents:
          * anything that is not a dict, e.g. a whole pickled model. It
            replaces ``self.model`` as-is.
          * ``{"module": state_dict}``.
          * a plain state dict (``dict`` or ``OrderedDict``).

        A leading ``"module."``, as added to keys by ``nn.DataParallel`` /
        ``DistributedDataParallel``, is stripped from every key that starts
        with it. Keys that merely contain that text elsewhere are left intact.

        Args:
            model_path: Path or file-like object accepted by ``torch.load``.
        """
        # SECURITY: torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device)

        # OrderedDict is a dict subclass, so every state-dict form falls
        # through to the branch below. Anything else is a whole model.
        if not isinstance(checkpoint, dict):
            self.model = checkpoint
            return

        # Checkpoint saved from a multi-GPU wrapper as {"module": ...}.
        state_dict = checkpoint["module"] if "module" in checkpoint else checkpoint

        prefix = "module."
        if any(key.startswith(prefix) for key in state_dict):
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
        self.model.load_state_dict(state_dict)

    def warmup(self, batch_size=10, seq_len=250):
        """Warm up the model with a single forward pass on random input.

        Args:
            batch_size: Number of dummy sequences (default 10).
            seq_len: Length of each dummy sequence (default 250).
        """
        inputs = torch.randn(batch_size, seq_len, 1).to(self.device)
        with torch.no_grad():
            self.model(inputs)

    def wavelet_detrend(self, signal, wavelet='sym6', level=6):
        """Remove baseline drift with a wavelet decomposition.

        The coarsest approximation coefficients are zeroed and the signal is
        reconstructed from the remaining detail coefficients.

        Args:
            signal: 1-D sequence of samples.
            wavelet: Name of the wavelet basis (default ``'sym6'``).
            level: Number of decomposition levels (default 6).

        Returns:
            numpy.ndarray: The detrended signal, the same length as *signal*.
        """
        signal = np.asarray(signal, dtype=float)
        coeffs = pywt.wavedec(signal, wavelet, level=level)
        # coeffs[0] is the level-`level` approximation, treated as the drift.
        coeffs[0] = np.zeros_like(coeffs[0])
        detrended_signal = pywt.waverec(coeffs, wavelet)
        # waverec can return one extra sample for odd-length input. Trim it
        # so the result stays aligned with the input and its time axis.
        return detrended_signal[:len(signal)]

    def butter_bandpass(self, lowcut, highcut, fs, order=5):
        """Design a digital Butterworth band-pass filter.

        Args:
            lowcut: Lower pass-band edge in Hz.
            highcut: Upper pass-band edge in Hz.
            fs: Sampling rate in Hz.
            order: Filter order.

        Returns:
            tuple: ``(b, a)``, the numerator and denominator coefficients.
            scipy raises ValueError when an edge lies outside ``(0, fs / 2)``.
        """
        nyq = 0.5 * fs
        b, a = butter(order, [lowcut / nyq, highcut / nyq], btype='band')
        return b, a

    def butter_bandpass_filter(self, data, lowcut, highcut, fs, order=5):
        """Apply the Butterworth band-pass to *data* with zero-phase filtfilt.

        filtfilt raises ValueError if *data* is not longer than its default
        padding length (three times the filter's coefficient count).
        """
        b, a = self.butter_bandpass(lowcut, highcut, fs, order=order)
        return filtfilt(b, a, data)

    def process_frame_sequence(self, frames, fps, lowcut=0.6, highcut=8):
        """Turn a frame sequence into a cleaned green-channel signal.

        Args:
            frames: Non-empty sequence of frames. Each is an array whose first
                two axes are spatial and whose remaining axis holds colour
                channels; channel index 1 is used as green.
            fps: Frame rate in Hz, used for the time axis and the filter.
            lowcut: Lower band-pass edge in Hz (default 0.6).
            highcut: Upper band-pass edge in Hz (default 8).

        Returns:
            tuple: ``(yg, green, t)``, where
              * yg (numpy.ndarray) is the detrended, band-passed green signal;
              * green (list) is the raw per-frame mean green value;
              * t (list[float]) is each frame's time in seconds, starting at 0.

        Raises:
            ValueError: if *frames* is empty.
        """
        if len(frames) == 0:
            raise ValueError("frames must contain at least one frame")

        # Average over rows, then columns, leaving per-channel means. Index 1
        # is green in both RGB and BGR channel order.
        green = [np.asarray(frame).mean(axis=0).mean(axis=0)[1] for frame in frames]
        t = [i / fps for i in range(len(frames))]

        g_detrended = self.wavelet_detrend(green)
        yg = self.butter_bandpass_filter(g_detrended, lowcut, highcut, fps, order=4)
        return yg, green, t

    def plot(self, yg, t, title, color='green', figsize=(30, 10)):
        """Show *yg* against time *t* in a new matplotlib figure (debug aid)."""
        plt.figure(figsize=figsize)
        plt.plot(t, yg, label=title, color=color)
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude')
        plt.legend()
        plt.show()