【Torch】nn.Dropout算法详解

1. 定义

nn.Dropout 是 PyTorch 中用于 防止神经网络过拟合 的正则化层。其核心思想是在训练阶段随机“丢弃”(置零)部分神经元的输出,以减少网络对特定神经元的过度依赖;在推理阶段则保持所有神经元输出不变。

2. 输入与输出

  • 输入(Input)
    • 任意形状的浮点张量(如 torch.float32torch.float64 等),常见于全连接层或卷积层的激活输出。
  • 输出(Output)
    • 与输入张量形状、dtype 完全相同。
    • 训练模式下:随机将部分元素置为 0,并对剩余元素按比例放大。
    • 评估模式下model.eval()):直接输出输入,等价于恒等映射。

3. 底层原理

  1. 随机掩码生成
    对输入张量 x x x的每个元素,按给定概率 p p p独立地生成掩码 m m m
    m i ∼ B e r n o u l l i ( 1 − p ) m_i \sim \mathrm{Bernoulli}(1 - p) miBernoulli(1p)
  2. 丢弃与缩放
    训练时输出
    y i = m i 1 − p    x i y_i = \frac{m_i}{1 - p}\; x_i yi=1pmixi
    • m i = 0 m_i=0 mi=0,则对应 y i = 0 y_i=0 yi=0(丢弃);
    • m i = 1 m_i=1 mi=1,则对应 y i = x i 1 − p y_i=\frac{x_i}{1-p} yi=1pxi(为保持期望不变,对未丢弃元素按 1 / ( 1 − p ) 1/(1-p) 1/(1p)缩放)。

模式切换

  • model.train() 下启用丢弃与缩放;
  • model.eval() 下跳过丢弃与缩放,直接 y=x

4. 构造函数参数详解

参数 类型 & 默认 说明
p float, 默认 0.5 丢弃概率,表示每个神经元被置零的概率。取值区间 ([0,1))。
inplace bool, 默认 False 是否对输入张量进行就地修改。
True,可节省显存,但可能影响梯度检查和反向传播可读性。

5. 使用示例

import torch
import torch.nn as nn

# 1. 定义 Dropout 层
drop = nn.Dropout(p=0.3)    # 丢弃概率 30%

# 2. 构造示例输入
x = torch.arange(12, dtype=torch.float32).reshape(3,4)
print("原始输入:\n", x)

# 3. 训练模式下(随机丢弃 + 缩放)
drop.train()
y_train = drop(x)
print("\n训练模式下输出(部分元素被置零,其余放大):\n", y_train)

# 4. 评估模式下(恒等映射)
drop.eval()
y_eval = drop(x)
print("\n评估模式下输出(与输入相同):\n", y_eval)

输出示例(训练模式):

训练模式下输出(部分元素被置零,其余放大):
 tensor([[0.0000, 1.4286, 2.8571, 4.2857],
         [5.7143, 0.0000, 8.5714, 11.4286],
         [0.0000, 14.2857, 0.0000, 17.1429]])
  • 可见,约 30% 元素置为 0,其余元素按 1 / ( 1 − 0.3 ) ≈ 1.4286 1/(1-0.3)\approx1.4286 1/(10.3)1.4286放大。

6. 注意事项

  1. 模式切换
    • 务必在训练阶段调用 model.train(),在推理阶段调用 model.eval();否则 Dropout 行为与预期不符。
  2. 丢弃概率选择
    • 常用值:全连接层一般取 0.3–0.5,卷积层或 Transformer 等更浅层可取 0.1–0.2;需通过验证集调优。
  3. 就地修改
    • inplace=True 可节省内存,但会破坏原始输入,且与某些梯度检查工具(如 torch.autograd.gradcheck)兼容性差。
  4. 与 BatchNorm 等层的配合
    • 在某些架构(如 ResNet)中,卷积后先做 BatchNorm、激活,再做 Dropout;而对于自注意力模型(Transformer),常在子层输出后或前馈网络中使用。
  5. 序列模型
    • 对 RNN 类模型,建议使用 nn.Dropout 结合 nn.RNN 中的 dropout 参数,或使用专门的 LockedDropout(固定同一时间步的掩码)以保证序列一致性。
  6. 变体
    • PyTorch 还提供 Dropout2d(对 Channel 维进行丢弃)、Dropout3dFeatureAlphaDropout(用于 SELU)等,针对不同任务场景可酌情选择。

你可能感兴趣的:(深度学习,算法)