torch.nn
概述与设计哲学torch.nn
是 PyTorch 提供的一个模块化工具集,专为构建和训练神经网络设计。其核心设计理念包括:
nn.Module
基类,用户可以像搭积木一样组合各种层、激活函数和损失函数,构建复杂模型。torch.nn
的核心组件包括:
nn.Linear
、nn.Conv2d
等。nn.ReLU
、nn.Softmax
等。nn.MSELoss
、nn.CrossEntropyLoss
等。nn.Sequential
、nn.ModuleList
等。nn.Dropout
、nn.BatchNorm2d
等。接下来,将逐一深入讲解这些组件,辅以数学公式、代码示例和实际应用。
nn.Module
:神经网络的核心基类nn.Module
是所有神经网络模块的基类,无论是简单的全连接层还是复杂的 Transformer 模型,都继承自它。其核心功能包括:
torch.nn.Parameter
(如权重和偏置),并通过 parameters()
或 named_parameters()
方法访问。forward
方法实现数据流。.to(device)
将模型和参数移动到指定设备(如 GPU)。.train()
和 .eval()
切换模式,影响 Dropout 和 BatchNorm 的行为。数学背景:
nn.Module
负责管理 W W W 和 b b b,并在反向传播时计算梯度 ∂ L ∂ W \frac{\partial L}{\partial W} ∂W∂L 和 ∂ L ∂ b \frac{\partial L}{\partial b} ∂b∂L。当定义一个继承 nn.Module
的类时,PyTorch 会:
nn.Linear
)和参数(如 nn.Parameter
)。model.parameters()
时,递归收集所有可学习的参数。model(x)
时,自动执行 forward
方法。代码示例:实现一个简单的两层神经网络
import torch
import torch.nn as nn
class TwoLayerNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(TwoLayerNet, self).__init__()
self.layer1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.layer2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = self.layer1(x)
x = self.relu(x)
x = self.layer2(x)
return x
# 实例化模型
model = TwoLayerNet(input_size=10, hidden_size=20, output_size=2)
x = torch.randn(32, 10) # 批次大小32,输入10维
output = model(x) # 输出形状为 (32, 2)
print(model)
输出:
TwoLayerNet(
(layer1): Linear(in_features=10, out_features=20, bias=True)
(relu): ReLU()
(layer2): Linear(in_features=20, out_features=2, bias=True)
)
nn.Module
自动管理参数。例如,nn.Linear
的权重和偏置是 nn.Parameter
类型的张量,PyTorch 会跟踪它们并在优化时更新。
查看参数:
for name, param in model.named_parameters():
print(f"Parameter: {
name}, Shape: {
param.shape}")
输出:
Parameter: layer1.weight, Shape: torch.Size([20, 10])
Parameter: layer1.bias, Shape: torch.Size([20])
Parameter: layer2.weight, Shape: torch.Size([2, 20])
Parameter: layer2.bias, Shape: torch.Size([2])
通过继承 nn.Module
,可以实现自定义逻辑。例如,添加一个残差连接:
class ResidualBlock(nn.Module):
def __init__(self, dim):
super(ResidualBlock, self).__init__()
self.fc1 = nn.Linear(dim, dim)
self.fc2 = nn.Linear(dim, dim)
self.relu = nn.ReLU()
def forward(self, x):
residual = x
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
out += residual # 残差连接
return out
以下是 torch.nn
中常见的层,包含数学原理、参数说明、代码示例和应用场景。
nn.Linear
数学原理:
参数:
in_features
:输入维度。out_features
:输出维度。bias
:是否包含偏置(默认 True)。代码示例:
linear = nn.Linear(10, 5)
x = torch.randn(32, 10) # 批次大小32
y = linear(x) # 输出形状 (32, 5)
print(y.shape)
应用场景:
注意事项:
in_features
。nn.Conv2d
数学原理: