67 lines
2.7 KiB
Python
67 lines
2.7 KiB
Python
|
import os
|
||
|
import pandas as pd
|
||
|
import numpy as np
|
||
|
import pywt
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
|
||
|
def read_csv_columns(csv_file_path, header_names):
    """Read a CSV file and return the requested columns as a NumPy array.

    Args:
        csv_file_path: path (or file-like object) handed to ``pd.read_csv``.
        header_names: list of column names to extract; the returned columns
            follow this order, not the order in the file.

    Returns:
        numpy.ndarray holding the selected columns (one array column per name).

    Raises:
        KeyError: if any name in ``header_names`` is not a column of the file
            (raised by pandas column indexing).
    """
    frame = pd.read_csv(csv_file_path)
    selected = frame[header_names]
    return selected.to_numpy()
|
||
|
|
||
|
|
||
|
def denoise_wavelet(data, wavelet='db4', level=4):
    """Denoise a 1-D signal by soft-thresholding its wavelet detail coefficients.

    The signal is decomposed with ``pywt.wavedec``. One universal threshold,
    ``sigma * sqrt(2 * ln(n))``, is applied to every detail band, where
    ``sigma`` is the standard deviation of the finest-level detail
    coefficients and ``n`` is the number of input samples.

    Args:
        data: 1-D sequence of samples.
        wavelet: wavelet name understood by PyWavelets (default ``'db4'``).
        level: decomposition depth passed to ``pywt.wavedec`` (default 4).

    Returns:
        numpy.ndarray with exactly ``len(data)`` samples.
    """
    n = len(data)
    coeffs = pywt.wavedec(data, wavelet, level=level)
    threshold = np.std(coeffs[-1]) * np.sqrt(2 * np.log(n))

    # coeffs[0] is the approximation band (the low-frequency trend, including
    # any constant offset). Soft-thresholding it would shrink the trend itself
    # and bias the signal's level, so only the detail bands are thresholded.
    approx, details = coeffs[0], coeffs[1:]
    coeffs_thresh = [approx] + [
        pywt.threshold(d, threshold, mode='soft') for d in details
    ]

    # waverec may return one extra sample when the input length is odd;
    # trim it so the output lines up sample-for-sample with the input.
    return pywt.waverec(coeffs_thresh, wavelet)[:n]
|
||
|
|
||
|
|
||
|
def save_to_csv(data, header_names, output_file):
    """Write a 2-D array to CSV with a header row and no index column.

    Args:
        data: 2-D array-like; column ``i`` is labelled ``header_names[i]``.
        header_names: column names written as the CSV header row.
        output_file: destination path (or writable buffer) for ``to_csv``.
    """
    pd.DataFrame(data, columns=header_names).to_csv(output_file, index=False)
|
||
|
|
||
|
|
||
|
def plot_comparison_and_save(original_data, denoised_data, header_names, save_path):
    """Plot each original column against its denoised version and save the figure.

    All columns are drawn on a single axes: original as solid lines, denoised
    as thicker dashed lines.

    Args:
        original_data: 2-D array indexed as ``[:, i]``; column ``i`` matches
            ``header_names[i]``.
        denoised_data: 2-D array with the same column layout as
            ``original_data``.
        header_names: column names used to build the legend labels.
        save_path: output file path passed to ``savefig``.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        for i, col_name in enumerate(header_names):
            ax.plot(original_data[:, i], label='Original ' + col_name)
            ax.plot(denoised_data[:, i], label='Denoised ' + col_name,
                    linestyle='--', linewidth=2)
        ax.set_title('Wavelet Denoising Comparison')
        ax.set_xlabel('Time')
        ax.set_ylabel('Value')
        ax.legend(loc='lower left')
        fig.tight_layout()
        fig.savefig(save_path)  # save the chart to a file
    finally:
        # pyplot keeps every figure open until it is closed explicitly; this
        # function runs once per input file, so close it to avoid piling up
        # figures (and memory) across a batch.
        plt.close(fig)
|
||
|
|
||
|
|
||
|
def processALL_Wavalet(csv_file_directory, output_file_directory, output_png_direcotry, columns_to_process):
    """Wavelet-denoise selected columns of every CSV file in a directory.

    For each file ending in ``.csv`` in ``csv_file_directory`` (not
    recursive), this writes ``<stem>_wavelet_denoised.csv`` into
    ``output_file_directory`` and a comparison plot
    ``<stem>_wavelet_denoised.png`` into ``output_png_direcotry``. Both output
    directories are created if they do not exist.

    Args:
        csv_file_directory: directory scanned for input CSV files.
        output_file_directory: directory for the denoised CSV files.
        output_png_direcotry: directory for the comparison plots (the
            parameter name's spelling is kept for backward compatibility).
        columns_to_process: column names to read, denoise and write out.
    """
    os.makedirs(output_file_directory, exist_ok=True)
    os.makedirs(output_png_direcotry, exist_ok=True)

    # Sorted so files are processed in a deterministic order.
    for filename in sorted(os.listdir(csv_file_directory)):
        if not filename.endswith('.csv'):
            continue

        # splitext strips only the final extension, so names like
        # "run.2024.csv" keep their full stem instead of collapsing to "run"
        # and overwriting each other's outputs.
        stem = os.path.splitext(filename)[0]
        csv_file_path = os.path.join(csv_file_directory, filename)
        save_path = os.path.join(output_file_directory, stem + '_wavelet_denoised.csv')
        output_plot_file = os.path.join(output_png_direcotry, stem + '_wavelet_denoised.png')

        data = read_csv_columns(csv_file_path, columns_to_process)
        # data.T iterates over columns; denoise each one, then transpose back
        # so rows are samples and columns follow columns_to_process.
        denoised_data = np.array([denoise_wavelet(col) for col in data.T]).T

        save_to_csv(denoised_data, columns_to_process, save_path)
        plot_comparison_and_save(data, denoised_data, columns_to_process, output_plot_file)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Column names to denoise in every input file.
    columns_to_process = [
        'accX', 'accY', 'accZ',
        'gyroX', 'gyroY', 'gyroZ',
        'magX', 'magY', 'magZ',
    ]

    # Input directory of raw CSV files, plus output directories for the
    # denoised CSV files and the comparison plots.
    input_csv_directory = r'D:\pycharmProjects\python_server\data'
    output_csv_directory = r'D:\pycharmProjects\python_server\Wavelet_DenoisedData'
    output_plot_directory = r'D:\pycharmProjects\python_server\Wavelet_DenoisedPng'

    processALL_Wavalet(
        csv_file_directory=input_csv_directory,
        output_file_directory=output_csv_directory,
        output_png_direcotry=output_plot_directory,
        columns_to_process=columns_to_process,
    )
|
||
|
|
||
|
|