PyTorch 的 torch.nn 模块学习

torch.nn 是 PyTorch 中专门用于构建和训练神经网络的模块。它的整体架构分为几个主要部分,每部分的原理、要点和使用场景如下:

1. nn.Module

  • 原理和要点nn.Module 是所有神经网络组件的基类。任何神经网络模型都应该继承 nn.Module,并实现其 forward 方法。
  • 使用场景:用于定义和管理神经网络模型,包括层、损失函数和自定义的前向传播逻辑。
  • 主要 API 和使用场景
    __init__: 初始化模型参数。
    forward: 定义前向传播逻辑。
    parameters: 返回模型的所有参数。
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(10, 1)
    
    def forward(self, x):
        return self.linear(x)

model = MyModel()
print(model)

2. Layers(层)

  • 原理和要点:层是神经网络的基本构建块,包括全连接层、卷积层、池化层等。每种层执行特定类型的操作,并包含可学习的参数。
  • 使用场景:用于构建神经网络的各个组成部分,如特征提取、降维等。
2.1 nn.Linear(全连接层)
linear = nn.Linear(10, 5)
input = torch.randn(1, 10)
output = linear(input)
print(output)
2.2 nn.Conv2d(二维卷积层)
conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
input = torch.randn(1, 1, 5, 5)
output = conv(input)
print(output)
2.3 nn.MaxPool2d(二维最大池化层)
maxpool = nn.MaxPool2d(kernel_size=2)
input = torch.randn(1, 1, 4, 4)
output = maxpool(input)
print(output)

3. Loss Functions(损失函数)

  • 原理和要点:损失函数用于衡量模型预测与真实值之间的差异,指导模型优化过程。
  • 使用场景:用于计算训练过程中需要最小化的误差。
3.1 nn.MSELoss(均方误差损失)
mse_loss = nn.MSELoss()
input = torch.randn(3, 5)
target = torch.randn(3, 5)
loss = mse_loss(input, target)
print(loss)
3.2 nn.CrossEntropyLoss(交叉熵损失)

你可能感兴趣的:(PyTorch 的 torch.nn 模块学习)