tjy/BloodPressure/BPApi.py

150 lines
4.5 KiB
Python
Raw Normal View History

2024-06-20 18:22:33 +08:00
from collections import OrderedDict
import torch
import numpy as np
from matplotlib import pyplot as plt
from scipy.signal import butter, filtfilt
import pywt
from models.lstm import LSTMModel
class BPModel:
    """Estimate blood pressure from a sequence of video frames.

    The per-frame green-channel mean is detrended, band-pass filtered and fed
    to an ``LSTMModel`` that returns a systolic and a diastolic estimate.
    """

    def __init__(self, model_path, fps=30):
        """Build the model, load its weights and prepare it for inference.

        Args:
            model_path: Anything ``torch.load`` accepts (path or file-like
                object) holding either a pickled model or a state dict.
            fps: Frame rate of the input frames; used as the sampling rate
                of the extracted signal.
        """
        self.fps = fps
        # Choose the device before loading so checkpoint tensors can be mapped
        # onto it; without map_location a checkpoint saved on GPU cannot be
        # deserialized on a CPU-only machine.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = LSTMModel()
        self.load_model(model_path)
        self.model = self.model.to(self.device)
        self.model.eval()
        self.warmup()

    def predict(self, frames):
        """Predict systolic and diastolic blood pressure for one frame sequence.

        Args:
            frames: Frame sequence, as accepted by ``process_frame_sequence``.

        Returns:
            tuple[float, float]: ``(sbp, dbp)`` as Python floats. Units follow
            the model's training targets (presumably mmHg — TODO confirm).
        """
        filtered, _, _ = self.process_frame_sequence(frames, self.fps)
        # Model input layout: (batch=1, seq_len, features=1). copy() yields a
        # fresh C-contiguous buffer for torch.tensor.
        inputs = torch.tensor(
            filtered.reshape(1, -1, 1).copy(),
            dtype=torch.float32,
            device=self.device,
        )
        with torch.no_grad():
            sbp, dbp = self.model(inputs)
        return sbp.item(), dbp.item()

    def load_model(self, model_path):
        """Load weights, or a whole model object, from ``model_path``.

        Supported checkpoint formats:

        * a pickled model object (anything that is not a dict): it replaces
          ``self.model`` as-is;
        * a dict holding the state dict under the key ``'module'``;
        * a plain state dict (``dict`` or ``OrderedDict``), whose keys may
          carry a ``'module.'`` prefix from multi-GPU training.

        A leading ``'module.'`` is stripped from every key before
        ``load_state_dict`` is called.
        """
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which
        # rejects pickled model objects; confirm which checkpoint formats are
        # actually deployed.
        checkpoint = torch.load(model_path, map_location=self.device)

        # Any dict (OrderedDict is a subclass) is a state dict; checking for
        # OrderedDict only would mistake a plain-dict checkpoint for a model.
        if not isinstance(checkpoint, dict):
            self.model = checkpoint
            return

        state_dict = checkpoint["module"] if "module" in checkpoint else checkpoint

        # Strip only a *leading* prefix so keys such as 'sub.module.weight'
        # are left intact.
        prefix = "module."
        cleaned = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        self.model.load_state_dict(cleaned)

    def warmup(self):
        """Run one dummy forward pass (batch 10, 250 steps, 1 feature).

        Presumably meant to absorb one-time setup cost before the first real
        prediction.
        """
        inputs = torch.randn(10, 250, 1, device=self.device)
        with torch.no_grad():
            self.model(inputs)

    def wavelet_detrend(self, signal, wavelet='sym6', level=6):
        """Remove baseline drift by zeroing the coarsest wavelet approximation.

        Args:
            signal: 1-D input signal (sequence or numpy.ndarray).
            wavelet: PyWavelets wavelet name, default ``'sym6'``.
            level: Decomposition level, default 6. For short signals pywt
                warns that this exceeds the maximum useful level.

        Returns:
            numpy.ndarray: The detrended signal, same length as ``signal``.
        """
        n_samples = len(signal)
        coeffs = pywt.wavedec(signal, wavelet, level=level)
        # coeffs[0] is the approximation at the deepest level: the slow baseline.
        coeffs[0] = np.zeros_like(coeffs[0])
        # waverec can return one extra sample for odd-length input; trim it so
        # the output lines up with the per-frame time axis.
        return pywt.waverec(coeffs, wavelet)[:n_samples]

    def butter_bandpass(self, lowcut, highcut, fs, order=5):
        """Design a Butterworth band-pass filter.

        Args:
            lowcut: Lower cut-off frequency in Hz.
            highcut: Upper cut-off frequency in Hz; must be below ``fs / 2``.
            fs: Sampling rate in Hz.
            order: Filter order.

        Returns:
            tuple: ``(b, a)`` transfer-function coefficients from
            ``scipy.signal.butter``.
        """
        nyquist = 0.5 * fs
        # scipy expects cut-offs normalized to the Nyquist frequency.
        return butter(order, [lowcut / nyquist, highcut / nyquist], btype='band')

    def butter_bandpass_filter(self, data, lowcut, highcut, fs, order=5):
        """Apply a zero-phase Butterworth band-pass filter to ``data``.

        Uses ``scipy.signal.filtfilt``, which requires ``len(data)`` to exceed
        its default padding of ``3 * (2 * order + 1)`` samples.

        Returns:
            numpy.ndarray: The filtered signal, same length as ``data``.
        """
        b, a = self.butter_bandpass(lowcut, highcut, fs, order=order)
        return filtfilt(b, a, data)

    def process_frame_sequence(self, frames, fps, lowcut=0.6, highcut=8, order=4):
        """Turn a frame sequence into a detrended, band-passed green signal.

        Args:
            frames: Sequence of frames laid out as (rows, cols, channels).
                Channel index 1 is taken as green, which holds for both RGB
                and BGR ordering.
            fps: Frame rate in Hz, i.e. the sampling rate of the signal.
            lowcut: Lower band-pass edge in Hz (default 0.6).
            highcut: Upper band-pass edge in Hz (default 8).
            order: Butterworth filter order (default 4).

        Returns:
            tuple: ``(yg, green, t)``, where:

            * ``yg`` (numpy.ndarray) is the filtered green signal;
            * ``green`` (list) holds the raw per-frame green means;
            * ``t`` (list) holds the time stamps in seconds, starting at 0.

        Raises:
            ValueError: If ``frames`` is empty, or if ``filtfilt`` rejects a
                sequence that is too short.
        """
        if len(frames) == 0:
            raise ValueError("frames must contain at least one frame")
        # Mean over rows, then over columns: one value per channel; keep green.
        green = [frame.mean(axis=0).mean(axis=0)[1] for frame in frames]
        t = [i / fps for i in range(len(frames))]
        detrended = self.wavelet_detrend(green)
        yg = self.butter_bandpass_filter(detrended, lowcut, highcut, fps, order=order)
        return yg, green, t

    def plot(self, yg, t, title, color='green', figsize=(30, 10)):
        """Plot signal ``yg`` against time ``t`` and show it (debug helper)."""
        plt.figure(figsize=figsize)
        plt.plot(t, yg, label=title, color=color)
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude')
        plt.legend()
        plt.show()