Python打卡DAY36

DAY36:复习日

恩师@浙大疏锦行

在PyTorch中,nn.Model 是所有神经网络模块的基类,为构建和训练神经网络提供了丰富的方法,如下:

1. 模型构建与参数管理

  • __init__ 方法
    • 功能:用于初始化神经网络模块的参数和子模块。在自定义网络时,通常会重写此方法来定义网络的结构。
    • 细节解释:在 __init__ 方法中,可以定义各种层,如卷积层、全连接层等。这些层会被自动注册为子模块,方便后续管理。
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)

net = SimpleNet()

parameters 方法

  • 功能:返回一个包含模型所有可学习参数的迭代器。在训练模型时,通常将其作为优化器的输入。
  • 细节解释:该方法会递归地遍历模型的所有子模块,并返回它们的参数。
import torch.optim as optim

optimizer = optim.SGD(net.parameters(), lr=0.01)

2. 前向传播

  • forward 方法
    • 功能:定义模型的前向传播过程,即输入数据如何通过网络得到输出。在自定义网络时,必须重写此方法。
    • 细节解释:前向传播是神经网络计算输出的核心步骤,通过定义 forward 方法,可以灵活地组合各种层和操作。
import torch

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

net = SimpleNet()
input_tensor = torch.randn(1, 10)
output = net(input_tensor)

3. 模型状态管理

  • train 方法
    • 功能:将模型设置为训练模式。在训练过程中,某些层(如 DropoutBatchNorm)会有不同的行为。
    • 细节解释:调用 train 方法会将模型及其所有子模块的 training 属性设置为 True
net.train()

eval 方法

  • 功能:将模型设置为评估模式。在评估或推理阶段,通常需要将模型设置为该模式。
  • 细节解释:调用 eval 方法会将模型及其所有子模块的 training 属性设置为 False
net.eval()
with torch.no_grad():
    output = net(input_tensor)

4. 模型保存与加载

  • state_dict 方法
    • 功能:返回一个包含模型所有可学习参数的字典。可以将其保存到文件中,以便后续加载。
    • 细节解释state_dict 只包含模型的参数,不包含模型的结构。
torch.save(net.state_dict(), 'model.pth')

 load_state_dict 方法

  • 功能:从一个 state_dict 中加载模型的参数。通常用于恢复之前训练好的模型。
  • 细节解释:加载时,需要确保 state_dict 的键与模型的参数名匹配。
net.load_state_dict(torch.load('model.pth'))

你可能感兴趣的:(Python打卡DAY36)