tjy/BloodPressure/BPApi.py

150 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from collections import OrderedDict
import torch
import numpy as np
from matplotlib import pyplot as plt
from scipy.signal import butter, filtfilt
import pywt
from models.lstm import LSTMModel
class BPModel:
    """Blood-pressure estimator driven by a sequence of video frames.

    The per-frame mean of the green channel is detrended with a wavelet
    decomposition, band-pass filtered, and fed to an ``LSTMModel`` whose
    forward pass returns ``(sbp, dbp)``.
    """

    def __init__(self, model_path, fps=30):
        """Build the model, load its weights, move it to a device and warm it up.

        Args:
            model_path: Path to a file written by ``torch.save`` holding either
                a state dict or a whole model object (see ``load_model``).
            fps: Frame rate of the sequences given to ``predict``; used as the
                sampling rate (Hz) of the band-pass filter.
        """
        self.fps = fps
        self.model = LSTMModel()
        self.load_model(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()
        self.warmup()

    def predict(self, frames):
        """Estimate blood pressure for one frame sequence.

        Args:
            frames: Sequence of frames accepted by ``process_frame_sequence``.

        Returns:
            tuple: ``(sbp, dbp)`` as Python scalars, taken from the model's two
            outputs.
        """
        yg, _, _ = self.process_frame_sequence(frames, self.fps)
        # Shape (1, seq_len, 1): presumably batch-first (batch, time, feature),
        # matching the (10, 250, 1) dummy input used by warmup().
        inputs = torch.tensor(
            yg.reshape(1, -1, 1).copy(), dtype=torch.float32, device=self.device
        )
        with torch.no_grad():
            sbp_outputs, dbp_outputs = self.model(inputs)
        # Tensor.item() copies the single-element result to a host scalar.
        return sbp_outputs.item(), dbp_outputs.item()

    def load_model(self, model_path):
        """Load weights (or a whole pickled model) from ``model_path``.

        Supported file contents:
          * anything that is not a dict -> treated as the model object itself
            and assigned to ``self.model``;
          * a dict with a ``'module'`` entry -> that entry is loaded as the
            state dict;
          * any other dict -> loaded as a state dict after stripping a leading
            ``'module.'`` from each key (the prefix ``nn.DataParallel`` adds).

        Args:
            model_path: Path to a file written by ``torch.save``.
        """
        # Load onto the CPU so checkpoints saved on a GPU also work on CPU-only
        # hosts; __init__ moves the model to the target device afterwards.
        # NOTE(review): loading a whole model unpickles arbitrary objects, so
        # only load trusted files. torch >= 2.6 defaults to weights_only=True,
        # which rejects whole-model pickles.
        checkpoint = torch.load(model_path, map_location="cpu")

        if not isinstance(checkpoint, dict):
            # Not a state dict (OrderedDict is a dict subclass): assume the
            # file holds the model object itself.
            self.model = checkpoint
            return

        # Checkpoint that wraps its state dict under a 'module' key.
        if 'module' in checkpoint:
            self.model.load_state_dict(checkpoint['module'])
            return

        # Weights saved from a multi-GPU (DataParallel) model carry a leading
        # 'module.' prefix. Strip it only when it is an actual prefix, so keys
        # that merely contain 'module.' (a submodule named 'module') survive.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint.items()
        }
        self.model.load_state_dict(state_dict)

    def warmup(self):
        """Warm up the model with one dummy forward pass (output discarded)."""
        # Dummy batch: 10 random sequences of 250 steps with 1 feature each.
        inputs = torch.randn(10, 250, 1, device=self.device)
        with torch.no_grad():
            self.model(inputs)

    def wavelet_detrend(self, signal, wavelet='sym6', level=6):
        """Remove baseline drift by zeroing the coarsest wavelet approximation.

        Args:
            signal: 1-D input signal (list or numpy.ndarray).
            wavelet (str): PyWavelets wavelet name, default ``'sym6'``.
            level (int): Number of decomposition levels, default 6.

        Returns:
            numpy.ndarray: The detrended signal, same length as ``signal``.
        """
        # NOTE(review): for short signals (e.g. ~250 samples with 'sym6') level 6
        # exceeds pywt.dwt_max_level and pywt warns about boundary effects.
        coeffs = pywt.wavedec(signal, wavelet, level=level)
        # coeffs[0] is the level-`level` approximation, treated as the drift.
        coeffs[0] = np.zeros_like(coeffs[0])
        detrended_signal = pywt.waverec(coeffs, wavelet)
        # waverec may return one extra sample for odd-length input; trim it so
        # the result stays aligned with the input and its time axis.
        return detrended_signal[:len(signal)]

    def butter_bandpass(self, lowcut, highcut, fs, order=5):
        """Design a Butterworth band-pass filter.

        Args:
            lowcut, highcut: Pass-band edges in Hz.
            fs: Sampling rate in Hz. The edges are normalized by the Nyquist
                frequency ``fs / 2``, and scipy requires both to end up in (0, 1).
            order: Filter order.

        Returns:
            tuple: ``(b, a)`` transfer-function coefficients.
        """
        nyq = 0.5 * fs
        b, a = butter(order, [lowcut / nyq, highcut / nyq], btype='band')
        return b, a

    def butter_bandpass_filter(self, data, lowcut, highcut, fs, order=5):
        """Apply a Butterworth band-pass forward and backward (zero phase) via filtfilt.

        Returns:
            numpy.ndarray: The filtered signal, same length as ``data``.
        """
        b, a = self.butter_bandpass(lowcut, highcut, fs, order=order)
        return filtfilt(b, a, data)

    def process_frame_sequence(self, frames, fps):
        """Turn a frame sequence into a filtered green-channel signal.

        Args:
            frames: Non-empty sequence of H x W x C numpy frames. Channel index 1
                is used as green, which is correct for both RGB and BGR order.
            fps (float): Frame rate in Hz, used as the filter's sampling rate.

        Returns:
            tuple: ``(yg, green, t)`` where
                yg (numpy.ndarray): detrended, 0.6-8 Hz band-passed green signal;
                green (list): raw per-frame mean of the green channel;
                t (list): frame timestamps in seconds, starting at 0.

        Raises:
            ValueError: If ``frames`` is empty.
        """
        if len(frames) == 0:
            raise ValueError("frames must contain at least one frame")

        green = [frame[..., 1].mean() for frame in frames]
        t = [i / fps for i in range(len(green))]

        lowcut = 0.6
        highcut = 8
        g_detrended = self.wavelet_detrend(green)
        yg = self.butter_bandpass_filter(g_detrended, lowcut, highcut, fps, order=4)
        return yg, green, t

    def plot(self, yg, t, title, color='green', figsize=(30, 10)):
        """Plot a signal against time in a new figure and show it (debug helper).

        Args:
            yg: Signal values.
            t: Time stamps in seconds, same length as ``yg``.
            title (str): Legend label for the curve.
            color (str): Line color.
            figsize (tuple): Figure size in inches.
        """
        plt.figure(figsize=figsize)
        plt.plot(t, yg, label=title, color=color)
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude')
        plt.legend()
        plt.show()