机器学习笔记:时域和频域变换

加窗操作

使用内置的 STFT/ISTFT 接口

这种方法利用 torch.stft(内部采用 rfft)和 torch.istft 完成变换,同时借助加窗(例如 Hann 窗)保证帧内加窗并采用重叠相加(常用 50% 重叠)实现完美重构。窗口长度可以灵活设置,例如 64 或 32。
这种方式利用了 PyTorch 内置的 STFT 与 ISTFT 函数,它们内部使用了 rfft/irfft,同时支持加窗并且能够保证重构出的信号长度与输入一致。下面分别给出代码示例及详细注释:

import torch


def process_signal_stft(signal, window_length):
    """
    对输入信号进行短时傅里叶变换和逆变换,实现:
      1. 对信号分帧,每帧加窗(本例使用 Hann 窗)
      2. 对每帧采用 rfft 变换到频域
      3. 再采用 istft 将频域转换回时域(重构信号)
    要求:输入输出信号长度一致且信号完全重构(不发生改变)。

    参数:
      signal: 1D 的 torch.Tensor 时域信号,长度为 N
      window_length: 加窗长度,例如 64 或 32(通常应小于等于信号长度)

    返回:
      recovered_signal: 重构后的时域信号,长度与 signal 相同
    """

    # 这里采用 50% 重叠,hop_length = window_length // 2
    hop_length = window_length // 2
    # 创建 Hann 窗
    window = torch.hann_window(window_length, dtype=signal.dtype, device=signal.device)
    # 使用 torch.stft 进行短时傅里叶变换
    # 参数说明:n_fft 设置为 window_length,此时返回结果中频域数据采用的是复数形式 (return_complex=True)
    stft_data = torch.stft(signal, n_fft=window_length, hop_length=hop_length, window=window, return_complex=True)
    # 使用 torch.istft 进行逆短时傅里叶变换,length 指定重构后信号长度和输入一致
    recovered_signal = torch.istft(stft_data, n_fft=window_length, hop_length=hop_length, window=window, length=signal.size(0))
    return recovered_signal


# 示例代码
if __name__ == "__main__":
    L = 256
    # 生成随机时域信号
    signal = torch.randn(L)
    window_length = 64  # 可以改为 32 等其他值
    reconstructed_signal = process_signal_stft(signal, window_length)
    # 检查重构精度
    print("是否完全重构:", torch.allclose(signal, reconstructed_signal, atol=1e-6))

    # 绘制原始信号和重构信号
    import matplotlib.pyplot as plt
    plt.figure(figsize=(12, 6))
    plt.plot(signal.cpu().numpy(), label='Original Signal')
    plt.plot(reconstructed_signal.cpu().numpy(), label='Reconstructed Signal')
    plt.legend()
    plt.show()
    # plt.savefig('reconstructed_signal.png')
    # plt.close()



说明:

  • STFT 与 ISTFT 内部都会使用 rfft 和 irfft;同时使用 Hann 窗以及 50% 重叠满足常见的完美重构条件(注意:Hann 窗在 50% 重叠(hop=win_length/2)下的重叠求和恒等于常数,因此在 ISTFT 时自动补偿了窗函数的影响)。

  • 输出信号长度通过参数 length=signal.size(0) 保证与原始信号一致。

变更输出信号长度

解决方法:在原始信号前就 zero-pad

确保 STFT 的帧数足够重构出你想要的长度:

def zero_pad_signal(signal, target_length):
    pad_length = target_length - signal.size(0)
    if pad_length > 0:
        signal = torch.cat([signal, torch.zeros(pad_length, device=signal.device)])
    return signal

使用:

signal = torch.randn(N)
output_length = 2 * N
padded_signal = zero_pad_signal(signal, output_length)
reconstructed = process_signal_stft(padded_signal, window_length, output_length)

完整代码:

import torch


def process_signal_stft(signal, window_length, output_length=None):
    """
    对输入信号进行短时傅里叶变换和逆变换,可设定输出时域长度。

    参数:
      signal: 1D torch.Tensor 时域信号,长度为 N
      window_length: 窗口大小,例如 64、32
      output_length: 指定输出信号长度,例如 N 的 2 倍(如果为 None,则默认为输入长度)

    返回:
      recovered_signal: 通过 STFT+ISTFT 重建的时域信号,长度为 output_length
    """
    hop_length = window_length // 2
    window = torch.hann_window(window_length, dtype=signal.dtype, device=signal.device)

    def zero_pad_signal(signal, target_length):
        pad_length = target_length - signal.size(0)
        if pad_length > 0:
            signal = torch.cat([signal, torch.zeros(pad_length, device=signal.device)])
        return signal

    # 如果没有指定输出长度,则默认与输入一致
    if output_length is not None:
        # output_length = signal.size(0)
        signal = zero_pad_signal(signal, output_length)

    # STFT:返回频域张量 (freq_bins, frames)
    stft_data = torch.stft(signal, n_fft=window_length, hop_length=hop_length, window=window, return_complex=True)

    # ISTFT:指定 output_length
    recovered_signal = torch.istft(stft_data, n_fft=window_length, hop_length=hop_length,
                                   window=window, length=output_length)
    return recovered_signal


# 示例代码
if __name__ == "__main__":
    L = 256
    signal = torch.randn(L)
    window_length = 64
    output_length = 2 * L  # 输出长度为原始的两倍
    reconstructed_signal = process_signal_stft(signal, window_length, output_length)

    print("原始长度:", signal.size(0))
    print("输出长度:", reconstructed_signal.size(0))
    print("输出前半部分是否接近原始信号:", torch.allclose(signal, reconstructed_signal[:L], atol=1e-6))

    # 绘制信号
    import matplotlib.pyplot as plt
    plt.figure(figsize=(10, 4))
    plt.plot(signal.cpu().numpy(), label='Original Signal')
    plt.plot(reconstructed_signal.cpu().numpy(), label='Reconstructed Signal')
    plt.legend()
    plt.show()
    # plt.savefig('signal.png')
    # plt.close()

你可能感兴趣的:(机器学习笔记,机器学习,笔记,人工智能)