手写BatchNorm与LayerNorm:从原理到实现

手写BatchNorm与LayerNorm:从原理到实现

大家好,今天我们将手写实现Batch Normalization(BatchNorm,批归一化)Layer Normalization(LayerNorm,层归一化)。通过手写实现,我们可以更深入地理解这两种归一化技术的原理和细节。本文将从原理、代码实现以及对比分析等方面,带大家彻底搞懂BatchNorm和LayerNorm!


1. Batch Normalization(BatchNorm)

1.1 BatchNorm的原理

BatchNorm的核心思想是对每个特征通道在**一个批次(batch)**内进行归一化。具体步骤如下:

  1. 计算批次中每个特征通道的均值 (\mu_B) 和方差 (\sigma_B^2)。
  2. 使用均值和方差对数据进行归一化:
    [
    \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
    ]
  3. 对归一化后的数据进行缩放和平移:
    [
    y_i = \gamma \hat{x}_i + \beta
    ]
    其中,(\gamma) 和 (\beta) 是可学习的参数。
1.2 BatchNorm的手写实现

以下是BatchNorm的Python实现:

import torch

class BatchNorm:
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # 可学习参数
        self.gamma = torch.ones(num_features)
        self.beta = torch.zeros(num_features)

        # 运行时的均值和方差
        self.running_mean = torch.zeros(num_features)
        self.running_var = torch.ones(num_features)

    def forward(self, x, training=True):
        if x.dim() != 4:
            raise ValueError("输入数据必须是4D张量:[batch_size, channels, height, width]")

        batch_size, channels, height, width = x.shape

        if training:
            # 计算均值和方差
            mean = x.mean(dim=(0, 2, 3), keepdim=True)  # [1, channels, 1, 1]
            var = x.var(dim=(0, 2, 3), keepdim=True)    # [1, channels, 1, 1]

            # 更新运行时的均值和方差
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.squeeze()
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.squeeze()
        else:
            mean = self.running_mean.view(1, channels, 1, 1)
            var = self.running_var.view(1, channels, 1, 1)

        # 归一化
        x_hat = (x - mean) / torch.sqrt(var + self.eps)

        # 缩放和平移
        y = self.gamma.view(1, channels, 1, 1) * x_hat + self.beta.view(1, channels, 1, 1)

        return y
1.3 BatchNorm的使用示例
# 定义BatchNorm层
batch_norm = BatchNorm(num_features=64)

# 输入数据
input_data = torch.randn(32, 64, 128, 128)  # [batch_size, channels, height, width]

# 前向传播
output = batch_norm.forward(input_data, training=True)

2. Layer Normalization(LayerNorm)

2.1 LayerNorm的原理

LayerNorm的核心思想是对每个样本所有特征通道内进行归一化。具体步骤如下:

  1. 计算每个样本所有特征通道的均值 (\mu_L) 和方差 (\sigma_L^2)。
  2. 使用均值和方差对数据进行归一化:
    [
    \hat{x}_i = \frac{x_i - \mu_L}{\sqrt{\sigma_L^2 + \epsilon}}
    ]
  3. 对归一化后的数据进行缩放和平移:
    [
    y_i = \gamma \hat{x}_i + \beta
    ]
    其中,(\gamma) 和 (\beta) 是可学习的参数。
2.2 LayerNorm的手写实现

以下是LayerNorm的Python实现:

class LayerNorm:
    def __init__(self, normalized_shape, eps=1e-5):
        self.normalized_shape = normalized_shape
        self.eps = eps

        # 可学习参数
        self.gamma = torch.ones(normalized_shape)
        self.beta = torch.zeros(normalized_shape)

    def forward(self, x):
        if x.dim() != 4:
            raise ValueError("输入数据必须是4D张量:[batch_size, channels, height, width]")

        batch_size, channels, height, width = x.shape

        # 计算均值和方差
        mean = x.mean(dim=(1, 2, 3), keepdim=True)  # [batch_size, 1, 1, 1]
        var = x.var(dim=(1, 2, 3), keepdim=True)    # [batch_size, 1, 1, 1]

        # 归一化
        x_hat = (x - mean) / torch.sqrt(var + self.eps)

        # 缩放和平移
        y = self.gamma.view(1, channels, 1, 1) * x_hat + self.beta.view(1, channels, 1, 1)

        return y
2.3 LayerNorm的使用示例
# 定义LayerNorm层
layer_norm = LayerNorm(normalized_shape=[64, 128, 128])

# 输入数据
input_data = torch.randn(32, 64, 128, 128)  # [batch_size, channels, height, width]

# 前向传播
output = layer_norm.forward(input_data)

3. BatchNorm与LayerNorm的对比

特性 BatchNorm LayerNorm
归一化维度 对每个特征通道在一个批次内归一化 对每个样本在所有特征通道内归一化
依赖批次大小
适用场景 CNN RNN、Transformer
计算开销 较小 较大
正则化效果

4. 总结

  • BatchNorm:对每个特征通道在一个批次内归一化,适合CNN和大批次训练。
  • LayerNorm:对每个样本在所有特征通道内归一化,适合RNN和Transformer。
  • 手写实现:通过手写实现BatchNorm和LayerNorm,我们可以更深入地理解它们的原理和细节。

希望这篇文章和代码能帮助你更好地理解BatchNorm和LayerNorm的原理和实现!


最后,如果你觉得这篇文章对你有帮助,别忘了点赞、收藏、关注三连哦!如果有任何问题,欢迎在评论区留言,我会第一时间回复!我们下期再见!

你可能感兴趣的:(自动驾驶,人工智能,计算机视觉,python)