使用内置的 STFT/ISTFT 接口
这种方法利用 torch.stft(内部采用 rfft)和 torch.istft 完成变换,同时借助加窗(例如 Hann 窗)保证帧内加窗并采用重叠相加(常用 50% 重叠)实现完美重构。窗口长度可以灵活设置,例如 64 或 32。
这种方式利用了 PyTorch 内置的 STFT 与 ISTFT 函数,它们内部使用了 rfft/irfft,同时支持加窗并且能够保证重构出的信号长度与输入一致。下面分别给出代码示例及详细注释:
import torch
def process_signal_stft(signal, window_length):
"""
对输入信号进行短时傅里叶变换和逆变换,实现:
1. 对信号分帧,每帧加窗(本例使用 Hann 窗)
2. 对每帧采用 rfft 变换到频域
3. 再采用 istft 将频域转换回时域(重构信号)
要求:输入输出信号长度一致且信号完全重构(不发生改变)。
参数:
signal: 1D 的 torch.Tensor 时域信号,长度为 N
window_length: 加窗长度,例如 64 或 32(通常应小于等于信号长度)
返回:
recovered_signal: 重构后的时域信号,长度与 signal 相同
"""
# 这里采用 50% 重叠,hop_length = window_length // 2
hop_length = window_length // 2
# 创建 Hann 窗
window = torch.hann_window(window_length, dtype=signal.dtype, device=signal.device)
# 使用 torch.stft 进行短时傅里叶变换
# 参数说明:n_fft 设置为 window_length,此时返回结果中频域数据采用的是复数形式 (return_complex=True)
stft_data = torch.stft(signal, n_fft=window_length, hop_length=hop_length, window=window, return_complex=True)
# 使用 torch.istft 进行逆短时傅里叶变换,length 指定重构后信号长度和输入一致
recovered_signal = torch.istft(stft_data, n_fft=window_length, hop_length=hop_length, window=window, length=signal.size(0))
return recovered_signal
# 示例代码
if __name__ == "__main__":
L = 256
# 生成随机时域信号
signal = torch.randn(L)
window_length = 64 # 可以改为 32 等其他值
reconstructed_signal = process_signal_stft(signal, window_length)
# 检查重构精度
print("是否完全重构:", torch.allclose(signal, reconstructed_signal, atol=1e-6))
# 绘制原始信号和重构信号
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 6))
plt.plot(signal.cpu().numpy(), label='Original Signal')
plt.plot(reconstructed_signal.cpu().numpy(), label='Reconstructed Signal')
plt.legend()
plt.show()
# plt.savefig('reconstructed_signal.png')
# plt.close()
说明:
STFT 与 ISTFT 内部都会使用 rfft 和 irfft;同时使用 Hann 窗以及 50% 重叠满足常见的完美重构条件(注意:Hann 窗在 50% 重叠(hop=win_length/2)下的重叠求和恒等于常数,因此在 ISTFT 时自动补偿了窗函数的影响)。
输出信号长度通过参数 length=signal.size(0)
保证与原始信号一致。
解决方法:在原始信号前就 zero-pad
确保 STFT 的帧数足够重构出你想要的长度:
def zero_pad_signal(signal, target_length):
pad_length = target_length - signal.size(0)
if pad_length > 0:
signal = torch.cat([signal, torch.zeros(pad_length, device=signal.device)])
return signal
使用:
signal = torch.randn(N)
output_length = 2 * N
padded_signal = zero_pad_signal(signal, output_length)
reconstructed = process_signal_stft(padded_signal, window_length, output_length)
完整代码:
import torch
def process_signal_stft(signal, window_length, output_length=None):
"""
对输入信号进行短时傅里叶变换和逆变换,可设定输出时域长度。
参数:
signal: 1D torch.Tensor 时域信号,长度为 N
window_length: 窗口大小,例如 64、32
output_length: 指定输出信号长度,例如 N 的 2 倍(如果为 None,则默认为输入长度)
返回:
recovered_signal: 通过 STFT+ISTFT 重建的时域信号,长度为 output_length
"""
hop_length = window_length // 2
window = torch.hann_window(window_length, dtype=signal.dtype, device=signal.device)
def zero_pad_signal(signal, target_length):
pad_length = target_length - signal.size(0)
if pad_length > 0:
signal = torch.cat([signal, torch.zeros(pad_length, device=signal.device)])
return signal
# 如果没有指定输出长度,则默认与输入一致
if output_length is not None:
# output_length = signal.size(0)
signal = zero_pad_signal(signal, output_length)
# STFT:返回频域张量 (freq_bins, frames)
stft_data = torch.stft(signal, n_fft=window_length, hop_length=hop_length, window=window, return_complex=True)
# ISTFT:指定 output_length
recovered_signal = torch.istft(stft_data, n_fft=window_length, hop_length=hop_length,
window=window, length=output_length)
return recovered_signal
# 示例代码
if __name__ == "__main__":
L = 256
signal = torch.randn(L)
window_length = 64
output_length = 2 * L # 输出长度为原始的两倍
reconstructed_signal = process_signal_stft(signal, window_length, output_length)
print("原始长度:", signal.size(0))
print("输出长度:", reconstructed_signal.size(0))
print("输出前半部分是否接近原始信号:", torch.allclose(signal, reconstructed_signal[:L], atol=1e-6))
# 绘制信号
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 4))
plt.plot(signal.cpu().numpy(), label='Original Signal')
plt.plot(reconstructed_signal.cpu().numpy(), label='Reconstructed Signal')
plt.legend()
plt.show()
# plt.savefig('signal.png')
# plt.close()