嗨,大家好,我是小森( ﹡ˆoˆ﹡ )! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。
易编橙:一个帮助编程小伙伴少走弯路的终身成长社群!
上一部分我们自己通过torch的方法完成反向传播和参数更新,在Pytorch中预设了一些更加灵活简单的对象,让我们来构造模型、定义损失,优化损失等;那么接下来,我们一起来了解一下其中常用的API!
nn.Module
nn.Module
是 PyTorch 框架中用于构建所有神经网络模型的基类。在 PyTorch 中,几乎所有的神经网络模块(如层、卷积层、池化层、全连接层等)都继承自 nn.Module
。这个类提供了构建复杂网络所需的基本功能,如参数管理、模块嵌套、模型的前向传播等。
当我们自定义网络的时候,有两个方法需要特别注意:
__init__
需要调用super
方法,继承父类的属性和方法
farward
方法必须实现,用来定义我们的网络的向前计算的过程
用前面的y = wx+b
的模型举例如下:
from torch import nn
class Lr(nn.Module):
def __init__(self):
super(Lr, self).__init__() # 继承父类init的参数
self.linear = nn.Linear(1, 1)
def forward(self, x):
out = self.linear(x)
return out
nn.Linear
为torch预定义好的线性模型,也被称为全链接层,传入的参数为输入的