Bioprobes_data/python_server/wavelet_denoisy.py

67 lines
2.7 KiB
Python
Raw Permalink Normal View History

2024-07-23 16:39:18 +08:00
import os
import pandas as pd
import numpy as np
import pywt
import matplotlib.pyplot as plt
# Read a CSV file and extract several columns from it.
def read_csv_columns(csv_file_path, header_names):
    """Load *csv_file_path* and return the columns named in *header_names*.

    Args:
        csv_file_path: path of the CSV file to read (first row is the header).
        header_names: column names to select; when this is a list the result
            is 2-D with columns in the same order as the list.

    Returns:
        numpy.ndarray holding the selected column values.

    Raises:
        KeyError: if a requested name is not present in the CSV header.
    """
    frame = pd.read_csv(csv_file_path)
    selected = frame[header_names]
    return selected.values
# Wavelet-threshold denoising of a 1-D signal.
def denoise_wavelet(data, wavelet='db4', level=4):
    """Denoise a 1-D signal by soft-thresholding its wavelet detail coefficients.

    The signal is decomposed with ``pywt.wavedec``. A universal threshold
    ``sigma * sqrt(2 * ln(N))`` is applied to every *detail* band, where
    ``sigma`` is the standard deviation of the finest detail band and ``N``
    is the signal length. The approximation band is left untouched so the
    low-frequency content of the signal is not shrunk toward zero.

    Args:
        data: 1-D sequence of samples.
        wavelet: wavelet name understood by PyWavelets.
        level: decomposition depth passed to ``pywt.wavedec``.

    Returns:
        numpy.ndarray with the same length as *data*.
    """
    data = np.asarray(data)
    coeffs = pywt.wavedec(data, wavelet, level=level)
    # coeffs[-1] is the finest detail band, used as the noise-level estimate.
    threshold = np.std(coeffs[-1]) * np.sqrt(2 * np.log(len(data)))
    # coeffs[0] is the approximation band: thresholding it would bias the
    # signal's trend, so only the detail bands coeffs[1:] are shrunk.
    coeffs_thresh = [coeffs[0]] + [
        pywt.threshold(c, threshold, mode='soft') for c in coeffs[1:]
    ]
    denoised = pywt.waverec(coeffs_thresh, wavelet)
    # waverec can return one extra sample for odd-length input; trim so the
    # output lines up sample-for-sample with the input.
    return denoised[:len(data)]
# Save the processed data to a new CSV file.
def save_to_csv(data, header_names, output_file):
    """Write *data* to *output_file* as CSV with *header_names* as the header row.

    Args:
        data: 2-D array-like whose columns correspond to *header_names*.
        header_names: labels written as the header row.
        output_file: destination path for the CSV file.

    The DataFrame index is not written to the file.
    """
    table = pd.DataFrame(data, columns=header_names)
    table.to_csv(output_file, index=False)
# Draw the original-vs-denoised comparison chart and save it.
def plot_comparison_and_save(original_data, denoised_data, header_names, save_path):
    """Overlay each original column with its denoised version and save the figure.

    Args:
        original_data: 2-D array indexed as ``[sample, column]``.
        denoised_data: 2-D array with the same column order as *original_data*.
        header_names: one label per column, used in the legend.
        save_path: output image path passed to ``Figure.savefig``.

    The figure is closed after saving, so calling this once per input file
    does not accumulate open figures in pyplot's global figure registry.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        for i, col_name in enumerate(header_names):
            ax.plot(original_data[:, i], label='Original ' + col_name)
            ax.plot(denoised_data[:, i], label='Denoised ' + col_name,
                    linestyle='--', linewidth=2)
        ax.set_title('Wavelet Denoising Comparison')
        # The x-axis is the row (sample) index of the arrays.
        ax.set_xlabel('Time')
        ax.set_ylabel('Value')
        ax.legend(loc='lower left')
        fig.tight_layout()
        fig.savefig(save_path)  # save the chart to a file
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)
def processALL_Wavalet(csv_file_directory, output_file_directory, output_png_direcotry, columns_to_process):
    """Denoise every CSV file in a directory and write CSV and PNG results.

    For each ``.csv`` file (extension matched case-insensitively) directly
    inside *csv_file_directory*, processed in sorted filename order, the
    columns in *columns_to_process* are denoised one at a time with
    ``denoise_wavelet``. Results go to
    ``<output_file_directory>/<stem>_wavelet_denoised.csv``, and a comparison
    plot goes to ``<output_png_direcotry>/<stem>_wavelet_denoised.png``.

    Args:
        csv_file_directory: directory scanned (non-recursively) for CSV files.
        output_file_directory: directory for denoised CSV files; created if missing.
        output_png_direcotry: directory for comparison plots; created if missing.
            (The parameter name's spelling is kept for caller compatibility.)
        columns_to_process: names of the columns to read, denoise and write.
    """
    # to_csv / savefig fail if the target directory does not exist.
    os.makedirs(output_file_directory, exist_ok=True)
    os.makedirs(output_png_direcotry, exist_ok=True)

    for filename in sorted(os.listdir(csv_file_directory)):
        if not filename.lower().endswith('.csv'):
            continue
        # splitext keeps dots inside the base name ("run.v2.csv" -> "run.v2"),
        # avoiding output-name collisions that split('.')[0] would cause.
        stem = os.path.splitext(filename)[0]
        csv_file_path = os.path.join(csv_file_directory, filename)
        save_path = os.path.join(output_file_directory, stem + '_wavelet_denoised.csv')
        output_plot_file = os.path.join(output_png_direcotry, stem + '_wavelet_denoised.png')

        data = read_csv_columns(csv_file_path, columns_to_process)
        # Denoise column by column (iterate over the transpose), then restore
        # the [sample, column] layout once for both outputs.
        denoised_data = np.array([denoise_wavelet(col) for col in data.T]).T
        save_to_csv(denoised_data, columns_to_process, save_path)
        plot_comparison_and_save(data, denoised_data, columns_to_process, output_plot_file)
if __name__ == "__main__":
    import argparse

    # Defaults preserve the original hard-coded locations; each may be
    # overridden positionally so the script also runs on other machines.
    parser = argparse.ArgumentParser(
        description='Wavelet-denoise selected columns of every CSV file in a directory.')
    parser.add_argument('input_csv_directory', nargs='?',
                        default=r'D:\pycharmProjects\python_server\data')
    parser.add_argument('output_csv_directory', nargs='?',
                        default=r'D:\pycharmProjects\python_server\Wavelet_DenoisedData')
    parser.add_argument('output_plot_directory', nargs='?',
                        default=r'D:\pycharmProjects\python_server\Wavelet_DenoisedPng')
    args = parser.parse_args()

    # Names of the columns to process; each must exist in every input CSV header.
    columns_to_process = ['accX', 'accY', 'accZ', 'gyroX', 'gyroY', 'gyroZ', 'magX', 'magY', 'magZ']
    processALL_Wavalet(args.input_csv_directory, args.output_csv_directory,
                       args.output_plot_directory, columns_to_process)