Pytorch实战4——ResNet网络对MNIST数据集分类

目录

一、ResNet(Residual Network) 残差网络

二、代码编程实战对mnist数据集分类


 

一、ResNet(Residual Network) 残差网络

       主要用于解决梯度消失问题 ,H(x)=F(x)+x,F表示残差部分,残差部分与上一层的输出相加就构成了下一层的输入,这整体结构也就被称为残差块(Residual Block)。

Pytorch实战4——ResNet网络对MNIST数据集分类_第1张图片

 上图左是普通的网络,在这个网络中,训练过程中可能会出现梯度消失的情况(梯度=0),权重参数W = W - 入g, 当梯度g趋于0时,W未发生改变。上图右是ResNet残差网络。

 上图右是ResNet残差网络,x为输入值,F(X)是经过第一层线性变化并激活后的输出,该图表示在残差网络中,第二层进行线性变化之后激活之前,F(x)加入了这一层输入值X,然后再进行激活后输出。在第二层输出值激活前加入X,这条路径称作shortcut连接。

ResNet最终更新某一个节点的参数时,由于H(x)=F(x)+x,求导后,不管括号内右边部分的F(x)求导有多小,始终有x求导后的1存在,不会发生梯度消失的情况。

二、代码编程实战对mnist数据集分类

代码基于之前实战2的卷积神经网络编程,只需要把实战2的代码中设计模型部分替换成ResNet网络的即可,其它部分代码不变,替换的模型如下

#  Residual Network(ResNet)网络模型
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.channels = channels
        self.conv1 = nn.Conv2d(channels, channels,
        kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels,
        kernel_size=3, padding=1)
    def forward(self, x):
        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        return F.relu(x + y) #两次卷积后的输出y,加上两次卷积前的输入x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.mp = nn.MaxPool2d(2)
        self.rblock1 = ResidualBlock(16) #ResNet网络不改变输入输出维度
        self.rblock2 = ResidualBlock(32)
        self.fc = nn.Linear(512, 10)
    def forward(self, x):
        in_size = x.size(0)
        x = self.mp(F.relu(self.conv1(x)))
        x = self.rblock1(x)
        x = self.mp(F.relu(self.conv2(x)))
        x = self.rblock2(x)
        x = x.view(in_size, -1)
        x = self.fc(x)
        return x
     
model = Net()

运行结果:

[1,   300] loss: 0.526
[1,   600] loss: 0.152
[1,   900] loss: 0.110
Accuracy on test set: 97 %
[2,   300] loss: 0.090
[2,   600] loss: 0.080
[2,   900] loss: 0.072
Accuracy on test set: 98 %

使用残差网络,在只训练了2轮的情况下,准确率达到了98%,错误率降低了三分之一。

你可能感兴趣的:(Pytorch,pytorch,分类,深度学习)