# The Hidden Side of Savitzky-Golay Filtering: Real-Time Stream Smoothing in Python
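In quantitative trading, high-frequency price data always comes with market noise: small random fluctuations that often mask the real trend signal. Researchers in biosignal processing face the same problem. Whether it is EEG or ECG, the raw recordings are always contaminated by physiological and environmental noise. Traditional offline filtering works on a complete dataset, but it falls short when data arrives as a continuous real-time stream. You cannot wait until the market closes to find out whether today's strategy worked, and you cannot wait until a recording is finished to decide whether a patient's heart rate is abnormal.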
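The Savitzky-Golay filter (SG filter for short) has been a classic tool in spectroscopy and signal processing since it was introduced in 1964. Most tutorials stop at the basic call to `scipy.signal.savgol_filter` and show how to smooth a static array. The real challenge is the dynamic setting. When data keeps flowing in, how do you use SG filtering for **low-latency, efficient real-time processing**? Answering that means stepping outside the comfort zone of `scipy`, digging into the core of the algorithm, and combining it with modern Python data-processing techniques.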
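Today I want to share the "hidden" ways to use SG filtering in real time. We will bypass the `scipy` wrapper and implement a **sliding-window real-time SG filter** from first principles, using `numpy.lib.stride_tricks.sliding_window_view` to process the data efficiently. I will also cover how to reduce latency when streaming experimental data over TCP, and practical techniques for **real-time visual debugging** in Jupyter Notebook. I have refined these methods repeatedly in quantitative strategy development and biosignal processing projects, and in those settings they worked much better than simple offline filtering.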
## 1. The Core Challenges of Real-Time SG Filtering
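At its core, SG filtering fits a polynomial to the data inside a fixed-length sliding window by least squares. It then uses the fitted polynomial's value at the window center as the smoothed output. The algorithm is naturally suited to sliding-window processing, but `scipy`'s implementation is designed for complete arrays. Each call filters the entire sequence it is given, which causes a lot of redundant computation in real-time scenarios.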
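Suppose a real-time source produces 100 samples per second. If you call `savgol_filter` on the entire history every time a new sample arrives, the cost of each call grows linearly with the length of the history, so the total cost grows quadratically over time, and the system is soon overwhelmed. Worse, latency keeps increasing: you need enough history before filtering can start, and every call recomputes all the windows even though only the newest one changed.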
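**The key to real-time SG filtering is incremental computation.** We only need to maintain a fixed-length buffer. When a new sample arrives, we drop the oldest one and evaluate the fit on the new window. It sounds simple, but there are a few pitfalls:
- **Boundary effects**: points near the edges of the data have neighbors on only one side, so the fit is less accurate there
- **Computational efficiency**: refitting a least-squares problem every time the window slides is still expensive
- **Memory management**: you need a data structure that supports fast insertion and removal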
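The memory-management point has a ready-made answer in the standard library. A `collections.deque` created with `maxlen` discards the oldest element automatically when a new one is appended, so it behaves like the fixed-length buffer described above. The minimal sketch below uses an arbitrary window length of 5:

```python
from collections import deque
import numpy as np

window = deque(maxlen=5)          # fixed-length window: append() evicts the oldest item
snapshots = []
for sample in range(1, 8):        # stream 1, 2, ..., 7
    window.append(sample)
    if len(window) == window.maxlen:
        snapshots.append(list(window))

print(snapshots)
# [[1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]]
print(np.asarray(window))         # [3 4 5 6 7]: ready for a dot product with coefficients
```

`append` and the automatic eviction are both O(1). Converting the deque to an array for the dot product costs O(w), which matches the cost of the dot product itself.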
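`numpy.lib.stride_tricks.sliding_window_view` addresses the efficiency and memory issues. It creates sliding-window views of an array without copying the data, which gives us an efficient basis for SG filtering. It does not remove boundary effects, which still have to be handled explicitly.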
> Note: `sliding_window_view` was introduced in NumPy 1.20.0. If you are on an older version, upgrade first. In trading systems, test version compatibility ahead of time.
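The sketch below shows what "a view without copying" means in practice. It builds the windows for a short array with arbitrary contents and shows that one element of the source appears in several overlapping windows. This aliasing is why the view is read-only by default (`writeable=False`):

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

data = np.arange(8.0)
windows = sliding_window_view(data, window_shape=3)

print(windows.shape)          # (6, 3): len(data) - 3 + 1 windows
print(windows[:2])            # first two windows: [0, 1, 2] and [1, 2, 3]
print(windows.flags.writeable)  # False: writing through the view is blocked

data[2] = 100.0               # changing the source...
print(windows[0, 2], windows[1, 1], windows[2, 0])  # ...shows up in all three windows
```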
Let's first look at the fundamental differences between conventional and real-time SG filtering:
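| Property | Conventional SG filtering (scipy) | Real-time SG filtering (this article) |
|------|------------------------|------------------------|
| Processing mode | Batch processing of a complete array | Streaming, updated sample by sample |
| Memory | Stores the full history | Only a fixed-length window |
| Computational cost | O(n·w) per call (n = data length, w = window length) | O(w) per update: one dot product with precomputed coefficients |
| Latency | Must wait for the complete data | Fixed delay of (w-1)/2 samples (half the window) |
| Typical use | Offline analysis, post-processing | Real-time monitoring, online trading, live patient monitoring |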
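As the table shows, real-time SG filtering has clear advantages in latency and memory. This makes it well suited to applications that must respond quickly.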
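The fixed delay in the table comes from evaluating the fit at the window center. If that delay is unacceptable, one option is to evaluate the same least-squares polynomial at the newest sample instead. This option is not part of the method in this article. It gives a zero-delay estimate at the cost of more noise (`scipy.signal.savgol_coeffs` exposes the same idea through its `pos` argument). The minimal NumPy-only sketch below uses `sg_coefficients_at`, a hypothetical helper name:

```python
import numpy as np

def sg_coefficients_at(window_length, polyorder, pos):
    """Hypothetical helper: SG dot-product coefficients that evaluate the
    least-squares polynomial at index `pos` of the window (0 = oldest)."""
    if window_length % 2 == 0 or polyorder >= window_length:
        raise ValueError("need an odd window_length greater than polyorder")
    x = np.arange(window_length) - pos          # put the evaluation point at x = 0
    X = np.vander(x, polyorder + 1, increasing=True)
    # Row 0 of pinv(X) maps window values to the fitted constant term a_0 = p(0)
    return np.linalg.pinv(X)[0]

w, k = 11, 3
center = sg_coefficients_at(w, k, pos=w // 2)   # delay of w // 2 samples
endpoint = sg_coefficients_at(w, k, pos=w - 1)  # no delay, noisier estimate

# For white noise, output variance = input variance * sum(c**2)
print("center noise gain:  ", np.sum(center ** 2))
print("endpoint noise gain:", np.sum(endpoint ** 2))
```

Both coefficient sets reproduce any cubic exactly, so they differ only in where they evaluate the fit and in how much noise they let through.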
## 2. Building a Real-Time SG Filter Engine from Scratch
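Implementing efficient real-time SG filtering requires a look at the underlying math. SG filtering solves a least-squares problem. For the points $(x_i, y_i)$ in the window, it fits a polynomial of degree $k$, $p(x) = a_0 + a_1x + a_2x^2 + \dots + a_kx^k$, and outputs its value at $x=0$ (the window center), which is simply $a_0$.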
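In matrix form this is $\mathbf{y} = \mathbf{X}\mathbf{a}$, where $\mathbf{X}$ is the Vandermonde matrix of the window positions. The least-squares solution is $\mathbf{a} = (\mathbf{X}^T\mathbf{X})^{-1}\mathbf{X}^T\mathbf{y}$. $\mathbf{X}$ depends only on the window length and the polynomial degree, not on the data. The key for real-time use is therefore to precompute the **convolution coefficients** once, so that each output is just a dot product.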
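The step from the least-squares solution to the convolution coefficients works as follows. The output is the constant term $a_0 = \mathbf{e}_0^T\mathbf{a}$, where $\mathbf{e}_0 = (1, 0, \dots, 0)^T$. It is therefore a fixed linear combination $\mathbf{c}^T\mathbf{y}$ of the window values:

$$
a_0 = \mathbf{e}_0^T(\mathbf{X}^T\mathbf{X})^{-1}\mathbf{X}^T\mathbf{y} = \mathbf{c}^T\mathbf{y},
\qquad
\mathbf{c} = \mathbf{X}(\mathbf{X}^T\mathbf{X})^{-1}\mathbf{e}_0 .
$$

Equivalently, $\mathbf{c}$ is the first row of the pseudo-inverse $\mathbf{X}^{+}$, and also the minimum-norm solution of $\mathbf{X}^T\mathbf{c} = \mathbf{e}_0$. For the $d$-th derivative at the center, replace $\mathbf{e}_0$ with $d!\,\mathbf{e}_d$. As a check, a 5-point window ($x = -2, \dots, 2$) with $k = 2$ gives $\mathbf{c} = \frac{1}{35}(-3, 12, 17, 12, -3)^T$.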
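```python
import math
import numpy as np
from scipy.linalg import lstsq

def compute_sg_coefficients(window_length, polyorder, deriv=0):
    """
    Precompute the Savitzky-Golay convolution coefficients.

    Parameters:
        window_length: window length, must be odd
        polyorder: polynomial order, must be less than window_length
        deriv: derivative order, 0 means plain smoothing

    Returns:
        coefficients: array of length window_length. np.dot(window, coefficients)
        gives the smoothed value (or per-sample derivative) at the window center,
        with the window ordered from oldest to newest sample.
    """
    if window_length % 2 == 0:
        raise ValueError("window_length must be odd")
    if polyorder >= window_length:
        raise ValueError("polyorder must be less than window_length")
    if deriv > polyorder:
        raise ValueError("deriv must not exceed polyorder")
    # x coordinates with the window center at 0
    half_window = window_length // 2
    x = np.arange(-half_window, half_window + 1)
    # Vandermonde matrix, shape (window_length, polyorder + 1)
    X = np.vander(x, polyorder + 1, increasing=True)
    # The fitted value (or d-th derivative) at x = 0 is d! * a_d, so select that
    # coefficient; derivatives are per sample (sample spacing = 1)
    target = np.zeros(polyorder + 1)
    target[deriv] = math.factorial(deriv)
    # The coefficients are the minimum-norm solution of X^T c = target,
    # i.e. c = pinv(X)^T @ target (note X.T, not X: target has polyorder + 1 entries)
    coefficients, _, _, _ = lstsq(X.T, target)
    return coefficients

# Example: window length 11, cubic polynomial
coeffs = compute_sg_coefficients(11, 3)
print(f"SG coefficients: {coeffs}")
```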
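The `deriv` parameter is easy to get wrong in two ways: the factorial factor and the units. The self-contained sketch below uses `np.linalg.pinv` instead of `scipy.linalg.lstsq`, an equivalent route to the same coefficients. `sg_coefficients` is a hypothetical name, and the sampling interval is an assumed value. The sketch reproduces the classic tabulated 5-point quadratic coefficients. It also shows that derivatives come out per sample, so you must divide them by `delta**deriv` to get physical units:

```python
import math
import numpy as np

def sg_coefficients(window_length, polyorder, deriv=0):
    """SG dot-product coefficients for the d-th derivative at the window center,
    in units of 'per sample' (sample spacing = 1)."""
    half = window_length // 2
    X = np.vander(np.arange(-half, half + 1), polyorder + 1, increasing=True)
    # d-th derivative of the fitted polynomial at x = 0 is d! * a_d
    return math.factorial(deriv) * np.linalg.pinv(X)[deriv]

smooth = sg_coefficients(5, 2)            # classic table: (-3, 12, 17, 12, -3) / 35
slope = sg_coefficients(5, 2, deriv=1)    # classic table: (-2, -1, 0, 1, 2) / 10
print(np.round(smooth * 35, 6), np.round(slope * 10, 6))

# Converting to physical units: divide a per-sample derivative by delta**deriv
delta = 0.01                               # assumed sampling interval in seconds (100 Hz)
t = np.arange(5) * delta
y = 3.0 * t + 1.0                          # a line with slope 3.0 per second
print(np.dot(y, slope) / delta)            # ≈ 3.0
```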
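With precomputed coefficients, real-time filtering becomes a simple convolution. There is one important detail. For a real-time stream we need only a single output per new sample, the value at the window center, rather than a smoothed version of the whole window. That means we can implement an **incrementally updated** filter. The center of the current window is the sample received `half_window` steps ago, so each output is delayed by that many samples.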
```python
import numpy as np

class RealTimeSGFilter:
    """Real-time Savitzky-Golay filter (one output per input once warmed up)"""

    def __init__(self, window_length=11, polyorder=3):
        """
        Initialize the real-time SG filter.

        Parameters:
            window_length: window length (odd)
            polyorder: polynomial order
        """
        self.window_length = window_length
        self.polyorder = polyorder
        self.half_window = window_length // 2
        # Precompute the convolution coefficients
        self.coefficients = compute_sg_coefficients(window_length, polyorder)
        # Ring buffer: buffer_index is the next write position, which is also
        # the position of the oldest sample once the buffer is full
        self.buffer = np.zeros(window_length)
        self.buffer_index = 0
        self.buffer_full = False

    def update(self, new_value):
        """
        Push a new sample and return a smoothed value if one is available.

        Parameters:
            new_value: the new data point

        Returns:
            The smoothed value for the sample received half_window steps ago,
            or None while the buffer is filling (the first window_length - 1 calls)
        """
        # Write into the ring buffer, overwriting the oldest sample
        self.buffer[self.buffer_index] = new_value
        self.buffer_index = (self.buffer_index + 1) % self.window_length
        # The buffer is full once the write position wraps around for the first time
        if not self.buffer_full and self.buffer_index == 0:
            self.buffer_full = True
        if not self.buffer_full:
            return None
        # Rotate the ring buffer into chronological order (oldest -> newest);
        # the window center is then the sample from half_window steps ago
        window = np.roll(self.buffer, -self.buffer_index)
        # Convolution at the center = one dot product
        return float(np.dot(window, self.coefficients))

    def reset(self):
        """Reset the filter state"""
        self.buffer.fill(0)
        self.buffer_index = 0
        self.buffer_full = False

# Usage example
filter_rt = RealTimeSGFilter(window_length=11, polyorder=3)
# Simulated real-time data stream
data_stream = np.random.randn(100) + np.sin(np.linspace(0, 4*np.pi, 100))
smoothed_values = []
for value in data_stream:
    smoothed = filter_rt.update(value)
    if smoothed is not None:
        smoothed_values.append(smoothed)
print(f"Raw length: {len(data_stream)}")
print(f"Smoothed length: {len(smoothed_values)}")  # 100 - 11 + 1 = 90
print(f"Filter delay: {filter_rt.half_window} samples")
```
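The key advantage of this implementation is **computational efficiency**. With precomputed coefficients, each new sample costs one dot product, O(w) for window length w. Without them, every window would need a fresh least-squares fit, on the order of O(w·k²) for polynomial order k. In high-frequency trading, this difference matters.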
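One way to gain confidence in a streaming implementation is to check it against a batch computation on the same data. The self-contained sketch below recomputes the center coefficients with `np.linalg.pinv`. For brevity it uses a `collections.deque(maxlen=w)` as the window instead of a ring buffer, so it is not the class above. `StreamingSG` is a hypothetical name. Each streaming output should equal the corresponding `np.convolve(..., mode="valid")` result:

```python
from collections import deque
import numpy as np

def center_coefficients(window_length, polyorder):
    """Center-point SG coefficients: first row of pinv of the Vandermonde matrix."""
    half = window_length // 2
    X = np.vander(np.arange(-half, half + 1), polyorder + 1, increasing=True)
    return np.linalg.pinv(X)[0]

class StreamingSG:
    """Hypothetical minimal streaming SG filter built on a bounded deque."""
    def __init__(self, window_length, polyorder):
        self.coeffs = center_coefficients(window_length, polyorder)
        self.window = deque(maxlen=window_length)  # the oldest sample drops out automatically

    def update(self, value):
        self.window.append(value)
        if len(self.window) < self.window.maxlen:
            return None                            # still warming up
        return float(np.dot(np.asarray(self.window), self.coeffs))

rng = np.random.default_rng(0)
data = np.sin(np.linspace(0, 6, 200)) + 0.2 * rng.standard_normal(200)
w, k = 11, 3
f = StreamingSG(w, k)
stream_out = np.array([y for y in (f.update(v) for v in data) if y is not None])

# Batch reference: np.convolve flips its kernel, so pass the coefficients reversed
batch_out = np.convolve(data, center_coefficients(w, k)[::-1], mode="valid")
print(stream_out.shape, np.max(np.abs(stream_out - batch_out)))
```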
## 3. Efficient Batch Processing with sliding_window_view
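The real-time filter above suits point-by-point processing, but in some scenarios we may want to process data in small batches. `numpy.lib.stride_tricks.sliding_window_view` offers a very efficient way to create sliding-window views without copying the data.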
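```python
import time
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from scipy.signal import savgol_filter

def batch_sg_filter(data, window_length, polyorder, step=1):
    """
    Batch SG filtering with sliding_window_view.

    Parameters:
        data: 1-D input array
        window_length: window length (odd)
        polyorder: polynomial order
        step: evaluate every `step`-th window (default 1 = every sample)

    Returns:
        smoothed: array with the same length as data; samples that were not
        evaluated (boundaries, or skipped by step) are NaN
    """
    data = np.asarray(data, dtype=float)
    if data.ndim != 1 or len(data) < window_length:
        raise ValueError("data must be 1-D and at least window_length long")
    # Precompute the SG coefficients
    coeffs = compute_sg_coefficients(window_length, polyorder)
    # Sliding-window view: a read-only view of data, no copy is made
    windows = sliding_window_view(data, window_length)[::step]
    # One vectorized matrix-vector product evaluates all window centers
    smoothed = windows @ coeffs
    # Place each result at the center index of its window; everything else stays NaN
    half = window_length // 2
    centers = half + np.arange(0, len(data) - window_length + 1, step)
    padded = np.full(len(data), np.nan)
    padded[centers] = smoothed
    return padded

# Performance comparison
np.random.seed(42)
test_data = np.random.randn(10000) + 0.1 * np.sin(np.linspace(0, 20*np.pi, 10000))

# Method 1: scipy's savgol_filter (conventional approach)
start = time.perf_counter()
result_scipy = savgol_filter(test_data, window_length=21, polyorder=3)
time_scipy = time.perf_counter() - start

# Method 2: our sliding_window_view implementation
start = time.perf_counter()
result_custom = batch_sg_filter(test_data, window_length=21, polyorder=3)
time_custom = time.perf_counter() - start

print(f"scipy:  {time_scipy:.4f} s")
print(f"custom: {time_custom:.4f} s")
# A single run is a noisy measurement (use timeit for real benchmarks); the ratio can be below 1
print(f"Time ratio (scipy / custom): {time_scipy/time_custom:.2f}")

# Check agreement away from the boundaries (half window = 10 samples at each end)
valid = slice(10, -10)
mse = np.mean((result_scipy[valid] - result_custom[valid])**2)
print(f"MSE: {mse:.3e} (should be close to 0)")
```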
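The trick behind `sliding_window_view` is **striding**. It creates a view over the array by manipulating strides instead of copying data, so even for large arrays the extra memory is tiny. For data of length N and window length W, materializing every window would mean creating N-W+1 windows of W elements each, O((N-W+1)×W) memory in total. `sliding_window_view` needs only the original data plus a view object, so the extra memory is O(1). The product with the coefficients still allocates an output of N-W+1 values. Any operation that copies the view, such as `np.ascontiguousarray`, materializes all the windows.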
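A quick way to check the O(1) claim follows. The view's `nbytes` reports the *logical* size of all the windows, while its strides and `np.shares_memory` show that nothing was copied. The array sizes are arbitrary:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

data = np.arange(100_000, dtype=np.float64)
windows = sliding_window_view(data, 101)

print(windows.shape)                    # (99900, 101)
print(windows.strides)                  # (8, 8): both axes step one float64
print(windows.nbytes / data.nbytes)     # logical size, roughly 101x the data
print(np.shares_memory(windows, data))  # True: no copy was made

# Materializing windows is what costs O((N-W+1)*W) memory
copied = np.ascontiguousarray(windows[:10])  # copy only a slice to keep this cheap
print(np.shares_memory(copied, data))   # False
```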
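When working with real financial data, I often use this approach on minute or second bars. Here is a more practical example that handles financial time series with pandas: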
```python
import pandas as pd
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def realtime_sg_for_financial(data_series, window_length=15, polyorder=3):
    """
    SG filtering for financial time series (centered, computed over the whole series).

    Parameters:
        data_series: pandas Series indexed by timestamp
        window_length: window length (must be odd)
        polyorder: polynomial order

    Returns:
        smoothed_series: smoothed Series with the same index as the input.
        The value at time t depends on window_length // 2 future samples,
        so in live trading it only becomes known that many bars later.
    """
    if len(data_series) < window_length:
        # Not enough data: return all NaN
        return pd.Series(np.nan, index=data_series.index)
    # Extract the values as float
    values = data_series.to_numpy(dtype=float)
    # Precompute the coefficients
    coeffs = compute_sg_coefficients(window_length, polyorder)
    # Sliding-window view over the values
    windows = sliding_window_view(values, window_length)
    # Smoothed value at the center of every window
    smoothed_center = windows @ coeffs
    # Full-length array, NaN at the boundaries
    half_window = window_length // 2
    smoothed_full = np.full(len(values), np.nan)
    smoothed_full[half_window:half_window+len(smoothed_center)] = smoothed_center
    return pd.Series(smoothed_full, index=data_series.index)

# Simulated price data
dates = pd.date_range('2024-01-01 09:30', periods=1000, freq='1min')
prices = 100 + np.cumsum(np.random.randn(1000) * 0.1)   # random walk
prices += 5 * np.sin(np.arange(1000) * 2 * np.pi / 60)  # intraday cycle (60-minute period)
price_series = pd.Series(prices, index=dates)

# Apply the SG filter
window_length = 31
smoothed_prices = realtime_sg_for_financial(price_series, window_length=window_length, polyorder=4)

# Avoid look-ahead bias: the centered estimate for bar t is only available at
# bar t + half_window, so shift it before comparing it with the live price
available = smoothed_prices.shift(window_length // 2)

# Trading signals (simple example: buy when the price crosses above the smoothed line)
buy_signals = (price_series > available) & (price_series.shift(1) <= available.shift(1))
sell_signals = (price_series < available) & (price_series.shift(1) >= available.shift(1))
print(f"Buy signals: {buy_signals.sum()}")
print(f"Sell signals: {sell_signals.sum()}")
```
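The advantages of this approach are **processing speed** and **memory efficiency**. In a live trading system I often monitor hundreds of instruments at once, each with its own high-frequency data stream. Using `sliding_window_view` keeps the filtering cost low and leaves more resources for risk management and order execution.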
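When many instruments share the same window length and polynomial order, they can all be filtered in one call. `sliding_window_view` accepts an `axis` argument, so a 2-D array of shape (instruments, samples) becomes a 3-D view, and a single matrix product with the coefficients handles every instrument. The minimal sketch below uses made-up array sizes and assumes every row is sampled on the same clock:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def center_coefficients(window_length, polyorder):
    half = window_length // 2
    X = np.vander(np.arange(-half, half + 1), polyorder + 1, increasing=True)
    return np.linalg.pinv(X)[0]

rng = np.random.default_rng(1)
n_symbols, n_samples, w = 300, 2_000, 31
prices = 100 + np.cumsum(0.1 * rng.standard_normal((n_symbols, n_samples)), axis=1)

coeffs = center_coefficients(w, 4)
windows = sliding_window_view(prices, w, axis=1)  # shape (300, 1970, 31), still a view
smoothed = windows @ coeffs                       # shape (300, 1970): all symbols at once

# Same result as filtering one symbol at a time
row0 = sliding_window_view(prices[0], w) @ coeffs
print(smoothed.shape, np.allclose(smoothed[0], row0))
```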
## 4. TCP Transport and Latency Optimization in Practice
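In distributed systems, real-time data often arrives over TCP connections. Biosignal acquisition devices, market data feeds and IoT sensor networks all need to move data from the acquisition side to the processing side in real time. SG filtering itself is computationally cheap, but network latency can become the bottleneck.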
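I ran into this problem in a real-time EEG monitoring project. The acquisition device sent data over Wi-Fi, and network jitter made packets arrive at uneven intervals, so applying the SG filter directly produced artifacts. The solution was to **implement a buffer queue with timestamp correction on the receiving side**.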
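A simpler relative of that idea is a fixed-delay jitter buffer. It holds each packet until a fixed playout delay has passed since the packet's capture timestamp. It then releases packets in timestamp order, so late or reordered arrivals are back in sequence before they reach the filter. The sketch below assumes the sender's timestamps are already on the receiver's clock. `JitterBuffer` and `playout_delay` are hypothetical names:

```python
import heapq
import itertools

class JitterBuffer:
    """Hypothetical fixed-delay jitter buffer: reorders packets by capture time."""
    def __init__(self, playout_delay):
        self.playout_delay = playout_delay   # seconds to wait before releasing a packet
        self._heap = []                      # min-heap keyed by capture timestamp
        self._seq = itertools.count()        # tie-breaker so equal timestamps never compare payloads
        self.last_released = float("-inf")

    def push(self, timestamp, value):
        if timestamp <= self.last_released:
            return False                     # arrived after its slot was played out: drop it
        heapq.heappush(self._heap, (timestamp, next(self._seq), value))
        return True

    def pop_ready(self, now):
        """Release, in timestamp order, every packet whose playout time has come."""
        ready = []
        while self._heap and self._heap[0][0] + self.playout_delay <= now:
            timestamp, _, value = heapq.heappop(self._heap)
            self.last_released = timestamp
            ready.append((timestamp, value))
        return ready

jb = JitterBuffer(playout_delay=0.05)
# Packets captured every 10 ms but arriving out of order
for ts, v in [(0.00, 1.0), (0.02, 3.0), (0.01, 2.0), (0.03, 4.0)]:
    jb.push(ts, v)
print(jb.pop_ready(now=0.065))   # packets with ts + 0.05 <= 0.065, in timestamp order
print(jb.push(0.005, 9.9))       # False: a later packet was already released
```

The playout delay trades latency for robustness. It should cover typical jitter but stay well below the delay budget of the application.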
```python
import json
import random
import socket
import threading
import time
from collections import deque

import numpy as np

class BufferedRealTimeSGFilter:
    """
    Buffered real-time SG filter that tolerates uneven network delays.
    """
    def __init__(self, window_length=11, polyorder=3, max_buffer_size=1000):
        self.window_length = window_length
        self.polyorder = polyorder
        self.half_window = window_length // 2
        # Precompute the coefficients
        self.coefficients = compute_sg_coefficients(window_length, polyorder)
        # Data buffers (values and sender timestamps)
        self.data_buffer = deque(maxlen=max_buffer_size)
        self.time_buffer = deque(maxlen=max_buffer_size)
        # Output buffer
        self.output_buffer = deque(maxlen=100)
        # Clock mapping: local_time ≈ clock_offset + (1 + clock_drift) * sender_time
        self.clock_offset = 0.0
        self.clock_drift = 0.0
        self.last_sync_time = None
        # Statistics
        self.stats = {
            'packets_received': 0,
            'packets_dropped': 0,
            'avg_latency': 0.0,
            'max_jitter': 0.0
        }

    def receive_data(self, data_packet):
        """
        Receive one packet containing a value and its acquisition timestamp.

        Parameters:
            data_packet: dict with 'value' and 'timestamp' keys
        """
        value = float(data_packet['value'])
        timestamp = float(data_packet['timestamp'])
        # After calibration, map the sender timestamp onto the local clock
        # and drop packets that arrive too late
        if self.last_sync_time is not None:
            expected_time = timestamp + self.clock_offset + self.clock_drift * timestamp
            latency = time.time() - expected_time
            if abs(latency) > 0.1:  # 100 ms threshold
                self.stats['packets_dropped'] += 1
                return
        # Append to the buffers
        self.data_buffer.append(value)
        self.time_buffer.append(timestamp)
        self.stats['packets_received'] += 1
        # Process once a full window is available
        if len(self.data_buffer) >= self.window_length:
            self._process_window()

    def _process_window(self):
        """Process the most recent full window"""
        recent_data = np.array(list(self.data_buffer)[-self.window_length:])
        recent_times = np.array(list(self.time_buffer)[-self.window_length:])
        # Check the timing regularity inside the window
        time_diffs = np.diff(recent_times)
        avg_interval = np.mean(time_diffs)
        max_jitter = np.max(np.abs(time_diffs - avg_interval))
        self.stats['max_jitter'] = max(max_jitter, self.stats['max_jitter'])
        # Exponential moving average of latency (meaningful only if the clocks are synchronized)
        self.stats['avg_latency'] = (self.stats['avg_latency'] * 0.9 +
                                     (time.time() - recent_times[-1]) * 0.1)
        # SG assumes uniform sampling: skip windows whose jitter exceeds 50% of the mean interval
        if max_jitter > avg_interval * 0.5:
            return
        # Apply the SG filter
        smoothed = float(np.dot(recent_data, self.coefficients))
        # The output belongs to the timestamp of the window center
        self.output_buffer.append({
            'timestamp': recent_times[self.half_window],
            'value': smoothed,
            'original': recent_data[self.half_window]
        })

    def get_latest_smoothed(self):
        """Return the most recent smoothed result"""
        if self.output_buffer:
            return self.output_buffer[-1]
        return None

    def calibrate_clock(self, reference_timestamps):
        """
        Estimate clock offset and drift.

        Parameters:
            reference_timestamps: local-clock times for the most recent
                len(reference_timestamps) packets in time_buffer (same order)
        """
        n = len(reference_timestamps)
        if n < 2 or len(self.time_buffer) < n:
            return
        sender_times = np.array(list(self.time_buffer)[-n:])
        local_times = np.asarray(reference_timestamps, dtype=float)
        # Linear fit: local = offset + slope * sender
        slope, offset = np.polyfit(sender_times, local_times, 1)
        self.clock_offset = offset
        self.clock_drift = slope - 1  # drift is the slope minus 1
        self.last_sync_time = time.time()

# TCP server example
class SGFilterTCPServer:
    """TCP server that receives data and applies real-time SG filtering"""
    def __init__(self, host='0.0.0.0', port=9999):
        self.host = host
        self.port = port
        # One filter per sensor: interleaving several streams in one filter would mix signals
        self.filters = {}
        # Client threads share self.filters, so a lock guards it
        self.lock = threading.Lock()
        self.running = False

    def _get_filter(self, sensor_id):
        if sensor_id not in self.filters:
            self.filters[sensor_id] = BufferedRealTimeSGFilter(window_length=15, polyorder=3)
        return self.filters[sensor_id]

    def start(self):
        """Start the server (blocks until stop() is called)"""
        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.server_socket.bind((self.host, self.port))
        self.server_socket.listen(5)
        self.server_socket.settimeout(1.0)  # wake up periodically so stop() takes effect
        print(f"Server listening on {self.host}:{self.port}")
        self.running = True
        # Background thread for periodic tasks
        self.process_thread = threading.Thread(target=self._process_loop, daemon=True)
        self.process_thread.start()
        # Accept connections
        while self.running:
            try:
                client_socket, addr = self.server_socket.accept()
            except socket.timeout:
                continue
            except OSError:
                break  # the socket was closed by stop()
            print(f"Connection from {addr}")
            # One handler thread per client
            client_thread = threading.Thread(
                target=self._handle_client, args=(client_socket, addr), daemon=True)
            client_thread.start()

    def _handle_client(self, client_socket, addr):
        """Handle one client connection"""
        buffer = b''
        try:
            while self.running:
                data = client_socket.recv(4096)
                if not data:
                    break
                buffer += data
                # TCP is a byte stream: split on newlines and keep any partial packet
                while b'\n' in buffer:
                    packet, buffer = buffer.split(b'\n', 1)
                    self._process_packet(packet.decode('utf-8').strip())
        except OSError as e:
            print(f"Error handling client {addr}: {e}")
        finally:
            client_socket.close()

    def _process_packet(self, packet_str):
        """Process one newline-delimited JSON packet"""
        try:
            data = json.loads(packet_str)
            sensor_id = data.get('sensor_id', 'default')
            with self.lock:
                filt = self._get_filter(sensor_id)
                filt.receive_data(data)
                latest = filt.get_latest_smoothed()
                stats = dict(filt.stats)
        except (ValueError, KeyError, TypeError, AttributeError) as e:
            # json.JSONDecodeError is a subclass of ValueError
            print(f"Could not process packet {packet_str!r}: {e}")
            return
        if latest:
            # Forward the result to other systems or storage here
            print(f"[{sensor_id}] time: {latest['timestamp']:.3f}, "
                  f"raw: {latest['original']:.4f}, "
                  f"smoothed: {latest['value']:.4f}")
        # Print statistics every 100 packets
        if stats['packets_received'] % 100 == 0:
            print(f"Stats: received {stats['packets_received']}, "
                  f"dropped {stats['packets_dropped']}, "
                  f"avg latency {stats['avg_latency']*1000:.1f} ms, "
                  f"max jitter {stats['max_jitter']*1000:.1f} ms")

    def _process_loop(self):
        """Background loop for periodic tasks"""
        while self.running:
            # e.g. persist results to a database here
            time.sleep(1)

    def stop(self):
        """Stop the server"""
        self.running = False
        if hasattr(self, 'server_socket'):
            self.server_socket.close()

# Client example
def send_sensor_data(host='localhost', port=9999, duration=60, sample_interval=0.01):
    """Simulate a sensor sending data"""
    sock = socket.create_connection((host, port))
    start_time = time.time()
    packet_count = 0
    try:
        while time.time() - start_time < duration:
            # Sample on a fixed schedule so that capture timestamps are uniform
            next_sample = start_time + packet_count * sample_interval
            time.sleep(max(0.0, next_sample - time.time()))
            t = time.time() - start_time
            # Simulated signal: sine wave plus noise
            value = 10 * np.sin(2 * np.pi * 0.5 * t) + random.gauss(0, 1)
            packet = {'value': value, 'timestamp': time.time(), 'sensor_id': 'sensor_001'}
            # Random send delay simulates network jitter (the capture time is already recorded)
            time.sleep(random.uniform(0.0, 0.005))
            sock.sendall((json.dumps(packet) + '\n').encode('utf-8'))
            packet_count += 1
            if packet_count % 50 == 0:
                print(f"Sent {packet_count} packets")
    finally:
        sock.close()
        print(f"Sent {packet_count} packets in total")

# Usage (in practice, run the server and the client in separate processes)
# Server: server = SGFilterTCPServer(); server.start()
# Client: send_sensor_data()
```
This implementation has several key optimizations:
1. **Timestamp correction**: compensates for network delay and clock differences
2. **Jitter tolerance**: windows with excessive timing jitter are skipped, so individual delayed packets do not distort the output
3. **Statistics monitoring**: tracks data quality in real time for debugging and tuning
4. **Thread safety**: each sensor gets its own filter, and a lock guards the shared filter table, so several clients can be served concurrently
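The clock calibration above comes down to a straight-line fit between sender timestamps and receiver times. The self-contained sketch below uses synthetic numbers: an assumed 50 ppm drift, a 12 ms offset and a little timing noise. It shows that `np.polyfit` recovers both parameters. In a real system the reference times would come from a synchronization protocol such as NTP or PTP, or from round-trip measurements, none of which the sketch models:

```python
import numpy as np

rng = np.random.default_rng(7)
true_offset, true_drift = 0.012, 50e-6           # 12 ms offset, 50 ppm drift (assumed)
sender = np.arange(0.0, 600.0, 0.5)              # 10 minutes of sender timestamps
local = true_offset + (1 + true_drift) * sender + rng.normal(0, 1e-4, sender.size)

slope, offset = np.polyfit(sender, local, 1)     # local ≈ offset + slope * sender
drift = slope - 1
print(f"offset ≈ {offset*1e3:.2f} ms, drift ≈ {drift*1e6:.1f} ppm")

def to_local(ts):
    """Map a sender timestamp onto the receiver clock using the fitted line."""
    return offset + (1 + drift) * ts
```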
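In trading systems I also add a **priority queue** and an **urgent data channel**. Market snapshot data such as order book updates gets the highest priority, while historical data can be processed later. This can be done by adding several buffer queues to `BufferedRealTimeSGFilter`.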
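One way to express this with the standard library is a `heapq` keyed by (priority, arrival sequence). Urgent messages jump the queue, while messages of equal priority keep their arrival order. The priority levels and message kinds below are illustrative assumptions and are not part of the filter class:

```python
import heapq
import itertools

PRIORITY = {"orderbook": 0, "trade": 1, "history": 2}   # lower number = more urgent (assumed levels)

class PriorityInbox:
    """Hypothetical inbox: pops the most urgent message first, FIFO within a level."""
    def __init__(self):
        self._heap = []
        self._seq = itertools.count()

    def put(self, kind, payload):
        heapq.heappush(self._heap, (PRIORITY[kind], next(self._seq), kind, payload))

    def get(self):
        _, _, kind, payload = heapq.heappop(self._heap)
        return kind, payload

    def __len__(self):
        return len(self._heap)

inbox = PriorityInbox()
inbox.put("history", "bar-1")
inbox.put("trade", "t-1")
inbox.put("orderbook", "ob-1")
inbox.put("trade", "t-2")
print([inbox.get() for _ in range(len(inbox))])
# [('orderbook', 'ob-1'), ('trade', 't-1'), ('trade', 't-2'), ('history', 'bar-1')]
```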
## 5. Real-Time Visual Debugging in Jupyter Notebook
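Visual debugging is indispensable when developing a real-time signal processing system. Jupyter Notebook provides an interactive environment, but static charts cannot show a live data stream. Below is a complete setup for **real-time visual debugging** in Jupyter.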
First, we need a chart that can update in real time. `matplotlib`'s animation support can do this, but using `FuncAnimation` directly in Jupyter can run into performance problems. I prefer to combine `ipywidgets` with `matplotlib`. In the version below, each frame is rendered off-screen to a PNG and pushed into an `ipywidgets.Image` widget. This is safer than calling `display()` from a background thread.
```python
import io
import time
from threading import Thread, Lock

import numpy as np
import ipywidgets as widgets
from matplotlib.figure import Figure

class RealTimeSGVisualizer:
    """
    Real-time SG filter visualization tool for Jupyter Notebook.
    Frames are rendered off-screen and pushed into an ipywidgets.Image,
    so the background thread never calls pyplot or display() directly.
    """
    def __init__(self, window_length=21, polyorder=3, buffer_size=500, dt=0.01):
        self.window_length = window_length
        self.polyorder = polyorder
        self.buffer_size = buffer_size
        self.dt = dt  # sampling interval of the simulated source (seconds)
        # Filter
        self.filter = RealTimeSGFilter(window_length, polyorder)
        # Data buffers (the newest sample is at the end)
        self.raw_data = np.zeros(buffer_size)
        self.smoothed_data = np.full(buffer_size, np.nan)
        self.time_stamps = np.arange(buffer_size)
        # Real locks: data_lock guards buffers and filter, draw_lock serializes rendering
        self.data_lock = Lock()
        self.draw_lock = Lock()
        # Figure built without pyplot, so Jupyter does not display it on its own
        self.fig = Figure(figsize=(12, 8))
        self.ax1, self.ax2 = self.fig.subplots(2, 1)
        self.fig.suptitle('Real-time SG filter monitor', fontsize=14, fontweight='bold')
        # Time-domain lines
        self.raw_line, = self.ax1.plot([], [], 'b-', alpha=0.7, linewidth=1, label='Raw')
        self.smooth_line, = self.ax1.plot([], [], 'r-', linewidth=2, label='SG filtered')
        self.error_fill = None  # residual band, redrawn every frame
        # Spectrum lines (second subplot)
        self.spectrum_raw_line, = self.ax2.plot([], [], 'b-', alpha=0.5, linewidth=1, label='Raw')
        self.spectrum_line, = self.ax2.plot([], [], 'g-', linewidth=1.5, label='Filtered')
        # Axis properties
        self.ax1.set_xlabel('Sample')
        self.ax1.set_ylabel('Amplitude')
        self.ax1.set_title('Time domain')
        self.ax1.legend(loc='upper right')
        self.ax1.grid(True, alpha=0.3)
        self.ax2.set_xlabel('Frequency (Hz)')
        self.ax2.set_ylabel('Power')
        self.ax2.set_title('Spectrum')
        self.ax2.legend(loc='upper right')
        self.ax2.grid(True, alpha=0.3)
        # Widgets: rendered frame and statistics text
        self.image = widgets.Image(format='png')
        self.stats_text = widgets.HTML(value="<b>Statistics:</b><br>Updates: 0")
        self.control_panel = self._create_control_panel()
        # Thread control and statistics
        self.running = False
        self.stats = {'update_count': 0, 'avg_update_time': 0.0, 'data_rate': 0.0}

    def _create_control_panel(self):
        """Build the control panel"""
        # Parameter sliders
        window_slider = widgets.IntSlider(value=self.window_length, min=5, max=51, step=2,
                                          description='Window:', continuous_update=False)
        polyorder_slider = widgets.IntSlider(value=self.polyorder, min=1, max=10, step=1,
                                             description='Poly order:', continuous_update=False)
        # Buttons
        start_button = widgets.Button(description='Start', button_style='success')
        stop_button = widgets.Button(description='Stop', button_style='danger')
        reset_button = widgets.Button(description='Reset', button_style='info')

        def rebuild_filter():
            # polyorder must stay below window_length (the slider may show a larger value)
            self.polyorder = min(self.polyorder, self.window_length - 1)
            with self.data_lock:
                self.filter = RealTimeSGFilter(self.window_length, self.polyorder)

        # Callbacks
        def on_window_change(change):
            self.window_length = change['new']
            rebuild_filter()

        def on_polyorder_change(change):
            self.polyorder = change['new']
            rebuild_filter()

        window_slider.observe(on_window_change, names='value')
        polyorder_slider.observe(on_polyorder_change, names='value')
        start_button.on_click(lambda b: self.start())
        stop_button.on_click(lambda b: self.stop())
        reset_button.on_click(lambda b: self.reset())
        # Layout: controls, statistics, live plot
        return widgets.VBox([
            widgets.HBox([window_slider, polyorder_slider]),
            widgets.HBox([start_button, stop_button, reset_button]),
            self.stats_text,
            self.image,
        ])

    def start(self):
        """Start real-time updates"""
        if not self.running:
            self.running = True
            # Data generation thread (simulates a real data source)
            self.data_thread = Thread(target=self._generate_data, daemon=True)
            self.data_thread.start()
            # Plot update thread
            self.update_thread = Thread(target=self._update_loop, daemon=True)
            self.update_thread.start()

    def stop(self):
        """Stop real-time updates"""
        self.running = False

    def reset(self):
        """Reset the data and the filter"""
        with self.data_lock:
            self.raw_data.fill(0)
            self.smoothed_data.fill(np.nan)
            self.filter.reset()
        self.stats = {'update_count': 0, 'avg_update_time': 0.0, 'data_rate': 0.0}
        self._update_plot()

    def _generate_data(self):
        """Generate simulated data (replace with a real data source in practice)"""
        t = 0.0
        while self.running:
            # Signal: multi-tone sine wave plus noise
            signal = (2.0 * np.sin(2 * np.pi * 1.0 * t) +    # 1 Hz fundamental
                      0.5 * np.sin(2 * np.pi * 5.0 * t) +    # 5 Hz component
                      0.2 * np.sin(2 * np.pi * 15.0 * t) +   # 15 Hz component
                      0.1 * np.random.randn())               # Gaussian noise
            # Occasional impulse interference (1% probability)
            if np.random.random() < 0.01:
                signal += 3.0 * np.random.randn()
            with self.data_lock:
                # Shift the buffers left by one sample and append the new one
                self.raw_data[:-1] = self.raw_data[1:]
                self.raw_data[-1] = signal
                self.smoothed_data[:-1] = self.smoothed_data[1:]
                self.smoothed_data[-1] = np.nan
                smoothed = self.filter.update(signal)
                if smoothed is not None:
                    # The output belongs to the sample half_window steps back;
                    # store it there so raw and smoothed stay aligned in the plot
                    self.smoothed_data[-1 - self.filter.half_window] = smoothed
            t += self.dt
            time.sleep(self.dt)  # simulated sampling interval

    def _update_plot(self):
        """Redraw the figure and push the frame into the Image widget"""
        with self.data_lock:
            raw = self.raw_data.copy()
            smooth = self.smoothed_data.copy()
            half_window = self.filter.half_window
        valid_mask = ~np.isnan(smooth)
        valid_indices = np.where(valid_mask)[0]
        with self.draw_lock:
            # Time-domain plot
            self.raw_line.set_data(self.time_stamps, raw)
            self.smooth_line.set_data(self.time_stamps[valid_mask], smooth[valid_mask])
            # Residual band: remove the previous one explicitly, because
            # Axes.collections can no longer be cleared in recent Matplotlib
            if self.error_fill is not None:
                self.error_fill.remove()
                self.error_fill = None
            if len(valid_indices) > 0:
                self.error_fill = self.ax1.fill_between(
                    self.time_stamps[valid_mask], raw[valid_mask], smooth[valid_mask],
                    alpha=0.3, color='gray')
            # Spectrum (needs enough points for a meaningful FFT)
            if len(valid_indices) >= 64:
                freq = np.fft.rfftfreq(len(valid_indices), d=self.dt)
                power_raw = np.abs(np.fft.rfft(raw[valid_indices]))**2
                power_smooth = np.abs(np.fft.rfft(smooth[valid_indices]))**2
                pos = freq > 0  # drop the DC bin
                self.spectrum_raw_line.set_data(freq[pos], power_raw[pos])
                self.spectrum_line.set_data(freq[pos], power_smooth[pos])
                self.ax2.set_xlim(0, 20)  # show 0-20 Hz
                self.ax2.set_ylim(0, max(power_raw[pos].max(), power_smooth[pos].max()) * 1.1)
            # Rescale the time-domain axes
            self.ax1.relim()
            self.ax1.autoscale_view()
            # Render off-screen to PNG and push the bytes into the widget
            buf = io.BytesIO()
            self.fig.savefig(buf, format='png')
            self.image.value = buf.getvalue()
        # Update the statistics
        self.stats['update_count'] += 1
        self.stats_text.value = (
            "<b>Statistics:</b><br>"
            f"Updates: {self.stats['update_count']}<br>"
            f"Data points: {len(valid_indices)}<br>"
            f"Window length: {self.window_length}<br>"
            f"Polynomial order: {self.polyorder}<br>"
            f"Delay: {half_window} samples<br>"
            f"Avg render time: {self.stats['avg_update_time']*1000:.1f} ms")

    def _update_loop(self):
        """Update loop"""
        update_interval = 0.1  # redraw every 100 ms
        while self.running:
            start_time = time.time()
            self._update_plot()
            update_time = time.time() - start_time
            self.stats['avg_update_time'] = (
                self.stats['avg_update_time'] * 0.9 + update_time * 0.1)
            # Wait for the next update cycle
            time.sleep(max(0, update_interval - update_time))

# Usage in Jupyter
# visualizer = RealTimeSGVisualizer(window_length=21, polyorder=3)
# display(visualizer.control_panel)   # shows the controls and the live plot
# visualizer.start()                  # start the live view
```
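This visualization tool provides several useful features:
1. **Live signal display**: shows the raw and the SG-filtered signal together
2. **Residual visualization**: a shaded band shows the difference between the raw and the filtered signal
3. **Spectrum analysis**: shows the frequency content of the signal, which helps you understand what the filter does
4. **Interactive control**: lets you adjust window length and polynomial order on the fly
5. **Statistics**: shows processing status and performance metrics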
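In biosignal processing projects I often use this tool to tune filter parameters. For ECG, for example, you want to preserve the QRS complex, which drives heart-rate detection and whose energy lies roughly in the 5-15 Hz band, while suppressing muscle (EMG) noise at roughly 30-300 Hz. Adjusting the SG parameters in real time shows the effect immediately and helps you find a good parameter combination.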
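The spectrum panel shows the filter's effect indirectly. You can also compute the filter's frequency response directly from its coefficients: for coefficients $c_k$, the gain at frequency $f$ is $|\sum_k c_k e^{-i 2\pi f k / f_s}|$. The sketch below evaluates this gain for the visualizer's default settings (window 21, cubic) at an assumed $f_s$ = 100 Hz, which matches the 10 ms interval of the simulated source. It checks the three tones of the simulated signal. With these settings, the 1 Hz tone passes almost unchanged, the 5 Hz tone is partly attenuated, and the 15 Hz tone is strongly attenuated:

```python
import numpy as np

def sg_center_coefficients(window_length, polyorder):
    half = window_length // 2
    X = np.vander(np.arange(-half, half + 1), polyorder + 1, increasing=True)
    return np.linalg.pinv(X)[0]

def sg_gain(coeffs, freqs_hz, fs):
    """Magnitude response |sum_k c_k exp(-i 2 pi f k / fs)| of an FIR kernel."""
    k = np.arange(len(coeffs)) - len(coeffs) // 2      # centered tap positions
    phase = -2j * np.pi * np.outer(np.asarray(freqs_hz) / fs, k)
    return np.abs(np.exp(phase) @ coeffs)

fs = 100.0                                  # assumed: 10 ms sampling, as in the simulated source
coeffs = sg_center_coefficients(21, 3)
for f, g in zip([1.0, 5.0, 15.0], sg_gain(coeffs, [1.0, 5.0, 15.0], fs)):
    print(f"{f:5.1f} Hz: gain {g:.3f}")
```

For a real ECG recording, plug in the device's actual sampling rate and check the gain across the QRS band before committing to a window length.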
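For financial data, I also add **technical indicator overlays**. For example, I overlay moving averages and Bollinger Bands on the SG-filtered price chart to help judge whether the filter is effective.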
```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def add_technical_indicators(visualizer):
    """Overlay technical indicators on the visualizer's time-domain plot (financial data).
    The overlay is a static snapshot; calling it again adds a new set of lines."""
    def calculate_ema(data, span=20):
        """Exponential moving average"""
        alpha = 2 / (span + 1)
        ema = np.zeros_like(data)
        ema[0] = data[0]
        for i in range(1, len(data)):
            ema[i] = alpha * data[i] + (1 - alpha) * ema[i-1]
        return ema

    def calculate_bollinger_bands(data, window=20, num_std=2):
        """Bollinger Bands: rolling mean +/- num_std rolling standard deviations"""
        sma_full = np.full(len(data), np.nan)
        upper_full = np.full(len(data), np.nan)
        lower_full = np.full(len(data), np.nan)
        if len(data) < window:
            return sma_full, upper_full, lower_full
        windows = sliding_window_view(data, window)  # one row per rolling window
        sma = windows.mean(axis=1)
        std = windows.std(axis=1)
        # The first full window ends at index window - 1
        sma_full[window-1:] = sma
        upper_full[window-1:] = sma + num_std * std
        lower_full[window-1:] = sma - num_std * std
        return sma_full, upper_full, lower_full

    # Snapshot the data (the data thread may still be writing)
    smoothed = visualizer.smoothed_data.copy()
    valid_mask = ~np.isnan(smoothed)
    bb_window = 20
    if np.sum(valid_mask) > bb_window:
        price_data = smoothed[valid_mask]
        indices = np.where(valid_mask)[0]
        # EMAs
        ema20 = calculate_ema(price_data, 20)
        ema50 = calculate_ema(price_data, 50)
        # Bollinger Bands
        sma, upper, lower = calculate_bollinger_bands(price_data, bb_window, 2)
        # Add to the chart
        visualizer.ax1.plot(indices, ema20, 'g--', alpha=0.7, linewidth=1, label='EMA20')
        visualizer.ax1.plot(indices, ema50, 'm--', alpha=0.7, linewidth=1, label='EMA50')
        # Bollinger band fill, starting at the first full window
        start = bb_window - 1
        visualizer.ax1.fill_between(indices[start:], lower[start:], upper[start:],
                                    alpha=0.2, color='orange', label='Bollinger Bands')
        visualizer.ax1.legend(loc='upper left', fontsize=8)
```
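This kind of real-time visualization is especially useful in strategy development. You see immediately how a parameter change affects signal quality and can iterate quickly toward a good configuration. I often use it when developing new quantitative factors, and for that purpose I find it much faster than the traditional backtest-then-optimize loop.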
## 6. Advanced Applications: Multi-Scale SG Filtering and Adaptive Parameters
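Basic SG filtering uses a fixed window length and polynomial order, but in practice signal characteristics change over time. Market volatility differs greatly between the open, midday and the close, and biosignals also look different in different physiological states. This calls for SG filtering with **adaptive parameters**.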
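I developed an adaptive SG filtering algorithm that adjusts its parameters dynamically based on the local characteristics of the signal. The core idea is to monitor the signal's **local signal-to-noise ratio** and **frequency content**, and to adjust the filter parameters accordingly.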
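The two measurements can be sketched in isolation. Local SNR is the ratio of signal variance to residual variance in decibels, and the dominant frequency is the largest non-DC bin of the real FFT, in cycles per sample. `local_snr_db` and `dominant_frequency` are hypothetical helper names. This SNR definition (total variance over residual variance) is a rough proxy rather than a true SNR:

```python
import numpy as np

def local_snr_db(raw, smoothed):
    """Rough SNR proxy: var(raw) / var(raw - smoothed), in dB."""
    raw, smoothed = np.asarray(raw, float), np.asarray(smoothed, float)
    noise_power = np.var(raw - smoothed)
    return np.inf if noise_power == 0 else 10 * np.log10(np.var(raw) / noise_power)

def dominant_frequency(x):
    """Largest non-DC peak of the real FFT, in cycles per sample."""
    x = np.asarray(x, float)
    power = np.abs(np.fft.rfft(x - x.mean())) ** 2
    freqs = np.fft.rfftfreq(len(x))
    return freqs[1 + np.argmax(power[1:])]          # skip the DC bin

n = 128
t = np.arange(n)
clean = np.sin(2 * np.pi * 5 * t / n)               # exactly 5 cycles per 128 samples
noisy = clean + 0.1 * np.random.default_rng(3).standard_normal(n)
print(dominant_frequency(noisy), 5 / n)             # both 0.0390625
print(f"{local_snr_db(noisy, clean):.1f} dB")
```

With a buffer of N samples, the frequency resolution is 1/N cycles per sample. Thresholds on the dominant frequency therefore only make sense if they are well above that resolution.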
```python
class AdaptiveSGFilter:
"""自适应参数SG滤波器"""
def __init__(self,
min_window=5,
max_window=51,
min_polyorder=1,
max_polyorder=5,
adaptation_rate=0.1):
"""
初始化自适应SG滤波器
参数:
min_window: 最小窗口长度(奇数)
max_window: 最大窗口长度(奇数)
min_polyorder: 最小多项式阶数
max_polyorder: 最大多项式阶数
adaptation_rate: 参数适应速率(0-1)
"""
self.min_window = min_window if min_window % 2 else min_window + 1
self.max_window = max_window if max_window % 2 else max_window - 1
self.min_polyorder = min_polyorder
self.max_polyorder = max_polyorder
self.adaptation_rate = adaptation_rate
# 当前参数
self.window_length = (self.min_window + self.max_window) // 2
self.polyorder = (self.min_polyorder + self.max_polyorder) // 2
# 初始化滤波器
self.filter = RealTimeSGFilter(self.window_length, self.polyorder)
# 自适应参数
self.snr_history = [] # 信噪比历史
self.freq_history = [] # 主导频率历史
self.param_history = [] # 参数历史
# 统计缓冲区
self.signal_buffer = []
self.residual_buffer = []
self.buffer_size = 100
def update(self, new_value):
"""
更新滤波器并自适应调整参数
参数:
new_value: 新的数据点
返回:
smoothed_value: 平滑后的值
"""
# 使用当前参数滤波
smoothed = self.filter.update(new_value)
if smoothed is not None:
# 计算残差(噪声估计)
residual = new_value - smoothed
# 更新缓冲区
self.signal_buffer.append(new_value)
self.residual_buffer.append(residual)
if len(self.signal_buffer) > self.buffer_size:
self.signal_buffer.pop(0)
self.residual_buffer.pop(0)
# 当缓冲区足够大时,评估信号特性并调整参数
if len(self.signal_buffer) >= 50:
self._adapt_parameters()
return smoothed
def _adapt_parameters(self):
"""根据信号特性自适应调整参数"""
if len(self.signal_buffer) < 20:
return
signal_array = np.array(self.signal_buffer)
residual_array = np.array(self.residual_buffer)
# 计算局部信噪比
signal_power = np.var(signal_array)
noise_power = np.var(residual_array)
if noise_power > 0:
snr = 10 * np.log10(signal_power / noise_power)
else:
snr = 100 # 无穷大信噪比
self.snr_history.append(snr)
if len(self.snr_history) > 50:
self.snr_history.pop(0)
# 计算信号的主导频率
if len(signal_array) >= 32:
fft_result = np.fft.fft(signal_array - np.mean(signal_array))
freqs = np.fft.fftfreq(len(signal_array))
power = np.abs(fft_result)**2
# 找到主导频率(排除直流分量)
dominant_idx = np.argmax(power[1:len(power)//2]) + 1
dominant_freq = freqs[dominant_idx]
self.freq_history.append(abs(dominant_freq))
if len(self.freq_history) > 50:
self.freq_history.pop(0)
# 基于信号特性调整参数
self._adjust_by_snr(snr)
self._adjust_by_frequency()
# 记录参数变化
self.param_history.append({
'window': self.window_length,
'polyorder': self.polyorder,
'snr': snr,
'timestamp': time.time()
})
def _adjust_by_snr(self, snr):
"""根据信噪比调整参数"""
# 信噪比低 -> 需要更强平滑 -> 增大窗口,降低阶数
# 信噪比高 -> 可以保留更多细节 -> 减小窗口,增加阶数
target_window = self.window_length
target_polyorder = self.polyorder
if snr < 10: # 低信噪比
target_window = min(self.max_window,
int(self.window_length * (1 + self.adaptation_rate)))
target_polyorder = max(self.min_polyorder,
int(self.polyorder * (1 - self.adaptation_rate)))
elif snr > 30: # 高信噪比
target_window = max(self.min_window,
int(self.window_length * (1 - self.adaptation_rate)))
target_polyorder = min(self.max_polyorder,
int(self.polyorder * (1 + self.adaptation_rate)))
# 确保窗口为奇数
if target_window % 2 == 0:
target_window += 1 if target_window < self.window_length else -1
# 确保polyorder < window_length
target_polyorder = min(target_polyorder, target_window - 1)
# 更新参数(如果需要)
if (target_window != self.window_length or
target_polyorder != self.polyorder):
self.window_length = target_window
self.polyorder = target_polyorder
self.filter = RealTimeSGFilter(self.window_length, self.polyorder)
def _adjust_by_frequency(self):
"""根据频率特性调整参数"""
if len(self.freq_history) < 5:
return
avg_freq = np.mean(self.freq_history[-5:])
# 高频信号 -> 需要较小窗口以保留细节
# 低频信号 -> 可以使用较大窗口进行平滑
if avg_freq > 0.1: # 相对高频
target_window = max(self.min_window,
min(15, self.window_length))
elif avg_freq < 0.01: # 相对低频
target_window = min(self.max_window,
max(31, self.window_length))
# 确保窗口为奇数
if target_window % 2 == 0:
target_window += 1
# 更新参数(如果需要)
if target_window != self.window_length:
self.window_length = target_window
self.polyorder = min(self.polyorder, self.window_length - 1)
self.filter = RealTimeSGFilter(self.window_length, self.polyorder)
def get_current_params(self):
"""获取当前参数"""
return {
'window_length': self.window_length,
'polyorder': self.polyorder,
'avg_snr': np.mean(self.snr_history) if self.snr_history else 0,
'avg_freq': np.mean(self.freq_history) if self.freq_history else 0
}
# 测试自适应滤波器
def test_adaptive_filter():
"""测试自适应SG滤波器在不同信号条件下的表现"""
# 生成测试信号:变信噪比和变频率
np.random.seed(42)
n_samples = 1000
t = np.linspace(0, 10, n_samples)
# 变频率信号
freq = 0.5 + 0.3 * np.sin(2 * np.pi * 0.1 * t) # 0.2-0.8Hz变化
signal = np.sin(2 * np.pi * freq * t)
# 变信噪比噪声
noise_level = 0.1 + 0.09 * np.sin(2 * np.pi * 0.05 * t) # 0.01-0.19变化
noise = noise_level[:, None] * np.random.randn(n_samples)
# 添加脉冲噪声
impulse_noise = np.zeros(n_samples)
impulse_indices = np.random.choice(n_samples, size=20, replace=False)
impulse_noise[impulse_indices] = np.random.randn(20) * 2
noisy_signal = signal + noise.flatten() + impulse_noise
# 应用自适应滤波
adaptive_filter = AdaptiveSGFilter(
min_window=5,
max_window=51,
min_polyorder=1,
max_polyorder=5,
adaptation_rate=0.05
)
# 应用固定参数滤波作为对比
fixed_filter1 = RealTimeSGFilter(window_length=11, polyorder=3) # 适中参数
fixed_filter2 = RealTimeSGFilter(window_length=5, polyorder=2) # 小窗口
fixed_filter3 = RealTimeSGFilter(window_length=31, polyorder=4) # 大窗口
# 处理信号
adaptive_output = []
fixed_output1 = []
fixed_output2 = []
fixed_output3 = []
params_history = []
for i, value in enumerate(noisy_signal):
# 自适应滤波
smoothed_adaptive = adaptive_filter.update(value)
if smoothed_adaptive is not None:
adaptive_output.append(smoothed_adaptive)
if len(adaptive_output) % 50 == 0:
params_history.append(adaptive_filter.get_current_params())
# 固定参数滤波
smoothed_fixed1 = fixed_filter1.update(value)
smoothed_fixed2 = fixed_filter2.update(value)
smoothed_fixed3 = fixed_filter3.update(value)
if smoothed_fixed1 is not None:
fixed_output1.append(smoothed_fixed1)
fixed_output2.append(smoothed_fixed2)
fixed_output3.append(smoothed_fixed3)
# 计算性能指标
def calculate_metrics(original, filtered):
"""计算滤波性能指标"""
# 对齐长度
min_len = min(len(original), len(filtered))
original = original[-min_len:]
filtered = filtered[-min_len:]
# 均方误差
mse = np.mean((original - filtered) ** 2)
# 信噪比改善
original_noise = original - signal[-min_len:]
filtered_noise = filtered - signal[-min_len:]
snr_original = 10 * np.log10(np.var(signal[-min_len:]) / np.var(original_noise))
snr_filtered = 10 * np.log10(np.var(signal[-min_len:]) / np.var(filtered_noise))
snr_improvement = snr_filtered - snr_original
# 计算相关系数
correlation = np.corrcoef(original, filtered)[0, 1]
return {
'MSE': mse,
'SNR_improvement': snr_improvement,
'Correlation': correlation
}
# 评估各滤波器性能
metrics_adaptive = calculate_metrics(signal, adaptive_output)
metrics_fixed1 = calculate_metrics(signal, fixed_output1)
metrics_fixed2 = calculate_metrics(signal, fixed_output2)
metrics_fixed3 = calculate_metrics(signal, fixed_output3)
# 打印结果
print("自适应SG滤波器性能:")
for key, value in metrics_adaptive.items():
print(f" {key}: {value:.4f}")
print("\n固定参数滤波器性能对比:")
results = {
'适中参数(11,3)': metrics_fixed1,
'小窗口(5,2)': metrics_fixed2,
'大窗口(31,4)': metrics_fixed3
}
for name, metrics in results.items():
print(f"\n{name}:")
for key, value in metrics.items():
print(f" {key}: {value:.4f}")
# 绘制参数变化图
if params_history:
windows = [p['window_length'] for p in params_history]
polyorders = [p['polyorder'] for p in params_history]
snrs = [p['avg_snr'] for p in params_history]
fig, axes = plt.subplots(3, 1, figsize=(12, 8))
axes[0].plot(windows, 'b-', linewidth=2)
axes[0].set_ylabel('窗口长度')
axes[0].grid(True, alpha=0.3)
axes[0].set_title('自适应SG滤波器参数变化')
axes[1].plot(polyorders, 'r-', linewidth=2)
axes[1].set_ylabel('多项式阶数')
axes[1].grid(True, alpha=0.3)
axes[2].plot(snrs, 'g-', linewidth=2)
axes[2].set_ylabel('平均信噪比(dB)')
axes[2].set_xlabel('更新次数')
axes[2].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
return adaptive_output, fixed_output1, fixed_output2, fixed_output3
# 运行测试
adaptive_result, fixed1, fixed2, fixed3 = test_adaptive_filter()
```
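`calculate_metrics` aligns the two series by their tails. A real-time SG filter, however, emits the smoothed value for sample `t - half_window` only after sample `t` has arrived. Tail alignment therefore compares each output with the wrong input sample.

The sketch below shifts the reference by the filter delay before comparing. It assumes a fixed-parameter filter that starts emitting once a full window is buffered, so output `k` smooths input sample `k + half_window`. The helpers `align_streaming_output` and `streaming_mse` are hypothetical names, not part of the code above.

```python
import numpy as np


def align_streaming_output(reference, streaming_output, half_window):
    """Pair each streaming output with the input sample it actually smooths.

    Hypothetical helper. Assumes output k was emitted after input sample
    k + 2 * half_window arrived, i.e. it smooths input sample k + half_window.
    """
    reference = np.asarray(reference, dtype=float)
    out = np.asarray(streaming_output, dtype=float)
    matched = reference[half_window:half_window + len(out)]
    n = min(len(matched), len(out))
    return matched[:n], out[:n]


def streaming_mse(reference, streaming_output, half_window):
    """MSE between a delayed streaming output and the correctly shifted reference."""
    ref, out = align_streaming_output(reference, streaming_output, half_window)
    return float(np.mean((ref - out) ** 2))


# A linear ramp is reproduced exactly by a centered 5-point average (half_window = 2)
x = np.arange(20, dtype=float)
out = np.convolve(x, np.ones(5), mode="valid") / 5   # out[k] smooths x[k + 2]
print(streaming_mse(x, out, half_window=2))          # 0.0
print(float(np.mean((x[-len(out):] - out) ** 2)))    # 4.0: tail alignment is off by half_window
```

For the adaptive filter the delay changes with the window length, so the shift would have to follow `params_history` rather than a single `half_window`.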
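The adaptive SG filter performs well in practice, especially on non-stationary signals. In financial data, the market shows different volatility characteristics in different sessions (such as the open, midday, and close). In biosignals, signal characteristics likewise differ across physiological states (such as rest, exercise, and sleep). A fixed-parameter filter either over-smooths and loses detail, or under-smooths and leaves residual noise. An adaptive filter adjusts to the signal's characteristics as they change and strikes a better balance.
I applied this approach in a real-time ECG monitoring project. At rest, the heart rate was relatively stable and the signal-to-noise ratio relatively high, so the filter automatically chose a smaller window and a higher polynomial order to preserve more detail. During exercise the signal was much noisier, so the filter automatically enlarged the window and lowered the order to smooth more aggressively. This adaptivity noticeably improved heart-rate detection accuracy.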
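The rule described above can be sketched as a standalone heuristic: a quiet signal gets a smaller window and higher order, a noisy signal a larger window and lower order. This is not the adaptive filter tested earlier. Three parts are illustrative assumptions that would need tuning on real data:

- **Noise estimator:** the median absolute second difference, which presumes roughly white Gaussian noise on a slowly varying signal.
- **SNR thresholds:** the 20 dB and 10 dB cut-offs.
- **Parameter pairs:** the `(window_length, polyorder)` values returned for each band.

The function names are hypothetical.

```python
import numpy as np


def estimate_noise_sigma(x):
    """Robust noise std estimate from second differences.

    For white noise with std sigma, e[i+1] - 2*e[i] + e[i-1] has std sqrt(6)*sigma,
    and for a Gaussian the median absolute value is about 0.6745 * std.
    A slowly varying underlying signal contributes little to second differences.
    """
    d2 = np.diff(np.asarray(x, dtype=float), n=2)
    return np.median(np.abs(d2)) / (0.6745 * np.sqrt(6.0))


def choose_sg_params(x, high_snr_db=20.0, low_snr_db=10.0):
    """Pick (window_length, polyorder) from an estimated SNR (illustrative thresholds)."""
    x = np.asarray(x, dtype=float)
    noise_var = max(estimate_noise_sigma(x) ** 2, 1e-12)  # avoid division by zero
    signal_var = max(np.var(x) - noise_var, 1e-12)        # variance attributed to the signal
    snr_db = 10.0 * np.log10(signal_var / noise_var)
    if snr_db >= high_snr_db:
        return 7, 3    # clean: short window, higher order keeps detail
    if snr_db >= low_snr_db:
        return 15, 3   # moderate noise
    return 31, 2       # noisy: long window, lower order smooths harder


rng = np.random.default_rng(0)
t = np.linspace(0.0, 1.0, 1000, endpoint=False)
clean = np.sin(2 * np.pi * 5 * t)
quiet = choose_sg_params(clean + 0.02 * rng.standard_normal(t.size))
noisy = choose_sg_params(clean + 0.5 * rng.standard_normal(t.size))
print(quiet, noisy)
```

With these assumed thresholds the lightly corrupted sine lands in the short-window band and the heavily corrupted one in the long-window band. In a streaming setting, the estimate would be recomputed on a trailing block of samples rather than on the whole series.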
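## 7. Performance Optimization and Production Deployment
When deploying a real-time SG filter to production, performance optimization is essential. In a financial trading system, a few milliseconds of latency can make the difference between profit and loss. In medical monitoring equipment, real-time behavior bears directly on patient safety.
**Computational optimization** is the first aspect to consider. The real-time SG filter we built earlier already does a fixed O(w) amount of work per incoming sample (w being the window length), but there is still room to make it faster: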
```python
import time
from collections import deque

import numpy as np
from numba import jit


@jit(nopython=True, cache=True)
def sg_filter_numba(data, coefficients):
    """
    Numba-accelerated SG filtering (direct convolution)

    Parameters:
        data: input data array
        coefficients: precomputed SG coefficients

    Returns:
        smoothed: the smoothed data (NaN at the boundaries)
    """
    n = len(data)
    w = len(coefficients)
    half_w = w // 2
    smoothed = np.full(n, np.nan)
    for i in range(half_w, n - half_w):
        # Numba compiles this inner loop to machine code; no manual unrolling needed
        result = 0.0
        for j in range(w):
            result += data[i - half_w + j] * coefficients[j]
        smoothed[i] = result
    return smoothed


class OptimizedRealTimeSGFilter:
    """Performance-optimized real-time SG filter"""
    def __init__(self, window_length=11, polyorder=3, use_numba=True):
        self.window_length = window_length
        self.polyorder = polyorder
        self.half_window = window_length // 2
        self.use_numba = use_numba
        # Precompute the coefficients (compute_sg_coefficients is defined earlier in the article)
        self.coefficients = np.asarray(compute_sg_coefficients(window_length, polyorder),
                                       dtype=np.float64)
        # Circular buffer; buffer_idx is the next write position,
        # which is also the oldest sample once the buffer is full
        self.buffer = np.zeros(window_length)
        self.buffer_idx = 0
        self.buffer_full = False
        # Performance statistics: deque keeps only the most recent max_history timings in O(1)
        self.max_history = 1000
        self.processing_times = deque(maxlen=self.max_history)

    def _ordered_window(self):
        """Return the buffer contents in chronological order (oldest sample first)"""
        return np.concatenate((self.buffer[self.buffer_idx:], self.buffer[:self.buffer_idx]))

    def update_optimized(self, new_value):
        """Push one sample; return the smoothed value of the sample half_window
        steps back, or None while the buffer is still filling"""
        start_time = time.perf_counter()
        # Write into the circular buffer
        self.buffer[self.buffer_idx] = new_value
        self.buffer_idx = (self.buffer_idx + 1) % self.window_length
        if not self.buffer_full and self.buffer_idx == 0:
            self.buffer_full = True
        # Compute an output only once a full window is available
        result = None
        if self.buffer_full:
            window = self._ordered_window()
            if self.use_numba:
                # Numba version: filter the single window and take its center point
                result = sg_filter_numba(window, self.coefficients)[self.half_window]
            else:
                # NumPy version: a single dot product
                result = float(np.dot(window, self.coefficients))
        # Record the processing time
        self.processing_times.append(time.perf_counter() - start_time)
        return result

    def get_performance_stats(self):
        """Return processing-time statistics in microseconds"""
        if not self.processing_times:
            return None
        times = np.array(self.processing_times)
        return {
            'avg_time_us': np.mean(times) * 1e6,
            'max_time_us': np.max(times) * 1e6,
            'min_time_us': np.min(times) * 1e6,
            'std_time_us': np.std(times) * 1e6,
            'percentile_95_us': np.percentile(times, 95) * 1e6,
            'percentile_99_us': np.percentile(times, 99) * 1e6
        }


# Performance comparison test
def performance_comparison():
    """Compare the performance of the different implementations"""
    # Generate test data
    np.random.seed(42)
    test_data = np.random.randn(100000)  # 100,000 samples
    # Implementations under test (RealTimeSGFilter is the class built earlier in the article)
    implementations = [
        ("Pure Python loop", RealTimeSGFilter(51, 3)),
        ("Numba-accelerated", OptimizedRealTimeSGFilter(51, 3, use_numba=True)),
        ("NumPy dot product", OptimizedRealTimeSGFilter(51, 3, use_numba=False))
    ]
    results = {}
    for name, filter_obj in implementations:
        print(f"\nTesting {name}...")
        # Assumption: a filter without update_optimized() exposes update(new_value)
        update = getattr(filter_obj, 'update_optimized', None) or filter_obj.update
        # Warm-up (keeps Numba's JIT compilation time out of the measurement)
        for i in range(100):
            update(test_data[i])
        # Timed run
        start_time = time.perf_counter()
        outputs = []
        for value in test_data:
            result = update(value)
            if result is not None:
                outputs.append(result)
        end_time = time.perf_counter()
        # Compute statistics
        total_time = end_time - start_time
        throughput = len(test_data) / total_time
        if hasattr(filter_obj, 'get_performance_stats'):
            stats = filter_obj.get_performance_stats()
        else:
            stats = None
        results[name] = {
            'total_time': total_time,
            'throughput': throughput,
            'stats': stats,
            'output_length': len(outputs)
        }
        print(f"  Total time: {total_time:.3f} s")
        print(f"  Throughput: {throughput:.0f} samples/s")
        if stats:
            print(f"  Mean processing time: {stats['avg_time_us']:.2f} μs")
            print(f"  95th percentile: {stats['percentile_95_us']:.2f} μs")
            print(f"  99th percentile: {stats['percentile_99_us']:.2f} μs")
    # Summary table
    print("\n" + "=" * 78)
    print("Performance summary:")
    print("=" * 78)
    print(f"{'Implementation':<20} {'Throughput (pts/s)':<20} {'Mean latency (μs)':<18} {'P95 latency (μs)':<18}")
    print("-" * 78)
    for name, result in results.items():
        throughput = f"{result['throughput']:,.0f}"
        if result['stats']:
            avg_latency = f"{result['stats']['avg_time_us']:.1f}"
            p95_latency = f"{result['stats']['percentile_95_us']:.1f}"
        else:
            avg_latency = "N/A"
            p95_latency = "N/A"
        print(f"{name:<20} {throughput:<20} {avg_latency:<18} {p95_latency:<18}")
    return results


# Run the performance test
performance_results = performance_comparison()
```
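When samples arrive in bursts, for example a TCP packet carrying many readings, per-sample Python calls dominate the cost. A chunked alternative uses `sliding_window_view` (NumPy ≥ 1.20) to smooth a whole chunk with one matrix-vector product. The class `ChunkedSGFilter` below is a hypothetical sketch that carries the last `window_length - 1` samples between chunks.

For the coefficients it swaps in SciPy's `savgol_coeffs(..., use="dot")` instead of the article's `compute_sg_coefficients`. This ordering matches a chronologically ordered window, so the result can be checked against `savgol_filter` at interior points.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from scipy.signal import savgol_coeffs, savgol_filter


class ChunkedSGFilter:
    """Hypothetical chunked streaming SG smoother (centered, fixed parameters)."""

    def __init__(self, window_length=11, polyorder=3):
        if window_length % 2 == 0 or polyorder >= window_length:
            raise ValueError("window_length must be odd and greater than polyorder")
        self.window_length = window_length
        # use="dot" orders coefficients for a dot product with an oldest-first window
        self.coeffs = savgol_coeffs(window_length, polyorder, use="dot")
        self.tail = np.empty(0)  # last window_length - 1 samples of previous chunks

    def process(self, chunk):
        """Return one smoothed value per window that became complete in this chunk.

        Each value smooths the center of its window, so results lag the input
        by window_length // 2 samples, as in the per-sample filters above.
        """
        data = np.concatenate((self.tail, np.asarray(chunk, dtype=float)))
        keep = self.window_length - 1
        # Copy so the carried-over tail does not keep the whole chunk alive
        self.tail = data[-keep:].copy() if keep else np.empty(0)
        if len(data) < self.window_length:
            return np.empty(0)
        windows = sliding_window_view(data, self.window_length)  # strided view, no copy
        return windows @ self.coeffs


rng = np.random.default_rng(1)
x = np.sin(np.linspace(0, 6 * np.pi, 500)) + 0.2 * rng.standard_normal(500)
filt = ChunkedSGFilter(window_length=11, polyorder=3)
# Feed the signal in 37 irregular chunks, as a network stream might deliver it
streamed = np.concatenate([filt.process(c) for c in np.array_split(x, 37)])
batch = savgol_filter(x, 11, 3)
half = 11 // 2
print(len(streamed), np.allclose(streamed, batch[half:len(x) - half]))
```

Feeding N samples yields N - window_length + 1 outputs in total, however the stream is split. Only the boundary points, which `savgol_filter` handles with its `mode` argument, are missing.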
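Besides computation, **memory management** also matters. In embedded systems and other resource-constrained environments, the data structures need careful design: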
```python
class MemoryEfficientSGFilter:
    """Memory-efficient SG filter implementation"""
    def __init__(self, window_length=11, polyorder=3):
        self.window_length = window_length
        self.polyorder = polyorder
        self.half_window = window_length // 2
        # Precompute the coefficients (single precision to save memory)
        self.coefficients = np.asarray(compute_sg_coefficients(window_length, polyorder))
        self.coefficients = self.coefficients.astype(np.float32)
        # Fixed-size array used as a circular buffer;
        # buffer_idx is the next write position (= oldest sample once full)
        self.buffer = np.zeros(window_length, dtype=np.float32)
        self.buffer_idx = 0
        # Bitwise AND can replace modulo when the length is a power of two.
        # Note: a centered SG window has odd length, so in practice the modulo branch is taken.
        self.buffer_mask = window_length - 1
        if (window_length & (window_length - 1)) != 0:
            # window_length is not a power of two: use modulo
            self._get_index = lambda idx: idx % window_length
        else:
            # window_length is a power of two: use bitwise AND
            self._get_index = lambda idx: idx & self.buffer_mask
        # State flags
        self.initialized = False
        self.samples_received = 0

    def update_efficient(self, new_value):
        """Update function optimized for both memory and computation"""
        # Write into the circular buffer
        self.buffer[self.buffer_idx] = new_value
        self.buffer_idx = self._get_index(self.buffer_idx + 1)
        self.samples_received += 1
        # Mark as initialized once a full window has been received
        if not self.initialized and self.samples_received >= self.window_length:
            self.initialized = True
        # Once initialized, compute the output
        if self.initialized:
            # Convolution over the window in chronological order:
            # buffer_idx now points at the oldest sample
            result = 0.0
            coeffs = self.coefficients
            buf = self.buffer
            get = self._get_index
            b = self.buffer_idx
            if self.window_length == 11:
                # Hand-unrolled path specialized for window length 11
                result = (buf[get(b)] * coeffs[0] +
                          buf[get(b + 1)] * coeffs[1] +
                          buf[get(b + 2)] * coeffs[2] +
                          buf[get(b + 3)] * coeffs[3] +
                          buf[get(b + 4)] * coeffs[4] +
                          buf[get(b + 5)] * coeffs[5] +
                          buf[get(b + 6)] * coeffs[6] +
                          buf[get(b + 7)] * coeffs[7] +
                          buf[get(b + 8)] * coeffs[8] +
                          buf[get(b + 9)] * coeffs[9] +
                          buf[get(b + 10)] * coeffs[10])
            else:
                # Generic implementation
                for i in range(self.window_length):
                    result += buf[get(b + i)] * coeffs[i]
            return result
        return None

    def get_memory_usage(self):
        """Return memory usage (bytes)"""
        buffer_size = self.buffer.nbytes
        coeffs_size = self.coefficients.nbytes
        total_size = buffer_size + coeffs_size + 64  # plus a rough estimate of object overhead
        return {
            'buffer': buffer_size,
            'coefficients': coeffs_size,
            'estimated_total': total_size,
            'window_length': self.window_length,
            'dtype': self.buffer.dtype
        }
# Memory usage comparison
def memory_usage_comparison():
    """Compare memory usage across implementations"""
    implementations = [
        ("Standard implementation", RealTimeSGFilter(51, 3)),
        ("Memory-optimized", MemoryEfficientSGFilter(51, 3)),
        ("NumPy implementation", OptimizedRealTimeSGFilter(51, 3))
    ]
    print("Memory usage comparison:")
    print("=" * 50)
    for name, filter_obj in implementations:
        if hasattr(filter_obj, 'get_memory_usage'):
            usage = filter_obj.get_memory_usage()
            print(f"\n{name}:")
            print(f"  Buffer: {usage['buffer']:,} bytes")
            print(f"  Coefficients: {usage['coefficients']:,} bytes")
            print(f"  Estimated total: {usage['estimated_total']:,} bytes")
            if 'dtype' in usage:
                print(f"  Data type: {usage['dtype']}")
        else:
            # Estimate the memory use of implementations without get_memory_usage()
            import sys
            size = sys.getsizeof(filter_obj)
            print(f"\n{name}:")