第二十九周:文献阅读笔记(DenseNet)+ pytorch学习

第二十九周:文献阅读笔记(DenseNet)+ pytorch学习

  • 摘要
  • Abstract
  • 1、DenseNet文献阅读
    • 1.1 文献摘要
    • 1.2 文献引言
    • 1.3 DenseNets 网络
      • 1.3.1 残差网络
      • 1.3.2 密集连接
      • 1.3.3 实施细节
    • 1.4 实验
      • 1.4.1 数据集
        • 1.4.1.1 CIFAR
        • 1.4.1.2 SVHN
      • 1.4.2 模型训练
      • 1.4.3 CIFAR 和 SVHN 的分类结果
      • 1.4.4 ImageNet 上的分类结果
    • 1.5 总结
    • 1.6 DenseNet 的创新点
  • 2. DenseNet 模型代码实现(pytorch)

摘要

本周阅读了由黄高、刘壮和Laurens van der Maaten在他们2017年的论文 Densely Connected Convolutional Networks 这篇文献,DenseNet的关键思想是以前馈方式密集连接各层。在传统CNN中,信息顺序地从一层传递到下一层,每一层仅从前一层接收输入。而在DenseNet中,每一层与所有其他层都密集连接。这种密集连接有助于促进特征的重复使用,鼓励特征的传播,并在训练期间实现更好的梯度流。本文将详细介绍DenseNet网络架构。

Abstract

This week read the paper Densely Connected Convolutional Networks by Gao Huang, Zhuang Liu and Laurens van der Maaten in their 2017 paper The key idea of DenseNet is to densely connect layers in a feed-forward fashion. In a traditional CNN, information is passed sequentially from one layer to the next, with each layer receiving input only from the previous layer. In contrast, in DenseNet, each layer is densely connected to all other layers. This dense connectivity helps promote feature reuse, encourages feature propagation, and enables better gradient flow during training. In this paper, we will describe the DenseNet network architecture in detail.

1、DenseNet文献阅读

文献来源:Densely Connected Convolutional Networks

1.1 文献摘要

最近的研究表明,如果卷积网络在靠近输入的层和靠近输出的层之间包含较短的连接,则可以训练更深的网络,正如我们之前学习的ResNet残差网络。在本文中,作者将介绍密集卷积网络(DenseNet),它以前馈方式将每一层与其他每一层连接起来。传统的 L 层卷积网络有 L 个连接(每层与其后续层之间有一个连接),而我们的网络有 L(L+1) 2 个直接连接。对于每一层,所有先前层的特征图用作输入,并且其自己的特征图用作所有后续层的输入。 DenseNets 有几个引人注目的优点:它们缓解梯度消失问题,加强特征传播,鼓励特征重用,并大幅减少参数数量。我们在四个高度竞争的对象识别基准任务(CIFAR-10、CIFAR-100、SVHN 和 ImageNet)上评估了我们提出的架构。 DenseNets 比大多数最先进的网络获得了显着的改进,同时需要更少的计算来实现高性能。

1.2 文献引言

卷积神经网络(CNN)已成为视觉对象识别的主要机器学习方法。尽管它们最初是在 20 多年前引入的,但直到最近计算机硬件和网络结构的改进才使得真正的深度 CNN 的训练成为可能,就我们之前所学习的VGG、LeNet网络架构。最初的 LeNet5 由 5 层组成,VGG 有 19 层 ,直到ResNet才突破了100 层的障碍。

随着 CNN 变得越来越深,一个新的研究问题出现了:当有关输入或梯度的信息经过许多层时,当它到达网络的末端(或开始)时,它可能会消失并“被冲走”,也就是我们之前了解过的 梯度消失 问题。 ResNets 通过残差连接将信号从一层输入到下一层。随机深度 通过在训练期间随机丢弃层来缩短 ResNet,以允许更好的信息和梯度流。 FractalNets 将多个并行层序列与不同数量的卷积块重复组合以获得大的标称深度,同时在网络中保持许多短路径。尽管这些不同的方法在网络拓扑和训练过程上有所不同,但它们都有一个关键特征:它们创建从早期层到后面层的连接。

在本文中,作者提出了一种架构,将这种想法提炼为简单的连接模式:为了确保网络中各层之间的最大信息流,作者将所有层(具有匹配的特征图大小)直接相互连接。为了保持前馈性质,每个层从所有前面的层获取额外的输入,并将其自己的特征图传递到所有后续层。图 1 示意性地说明了这种布局。
第二十九周:文献阅读笔记(DenseNet)+ pytorch学习_第1张图片

与 ResNet 相比,作者将特征传递到层之前从不通过求和来组合特征;相反,通过串联连接特征来组合它们。因此,第 3 层具有输入,由所有前面的卷积块的特征图组成。它自己的特征图被传递到所有 L − l L− l Ll 后续层。这在 L 层网络中引入了 L ( L + 1 ) / 2 L(L+1) / 2 L(L+1)/2 个连接,而不是像传统架构中那样只有 L L L 个连接。由于其密集的连接模式,我们将我们的方法称为密集卷积网络(DenseNet)。

这种密集连接模式相比传统卷积网络需要更少的参数,因为不需要重新学习冗余特征图。传统的前馈架构可以被视为具有状态的算法,该状态从一层传递到另一层。每层从其前一层读取状态并写入后续层。它改变状态,但也传递需要保留的信息。 ResNets 通过加性恒等变换使信息保存变得明确,ResNets 的最新变化表明,许多层贡献很小,实际上可以在训练过程中随机丢弃,这使得 ResNets 的状态类似于(展开的)循环神经网络 ,但 ResNets 的参数数量要大得多,因为每一层都有自己的权重。作者提出的 DenseNet 架构明确区分添加到网络的信息和保留的信息。 DenseNet 层非常窄(例如,每层 12 个过滤器),仅将一小部分特征图添加到网络的“集体知识”中。

除了更好的参数效率之外,DenseNet 的一大优势是改善了整个网络的信息流和梯度,这使得它们易于训练。每层都可以直接访问损失函数和原始输入信号的梯度,从而产生 隐式深度监督。这有助于训练更深层的网络架构。此外,我们还观察到密集连接具有正则化效果,可以有效的降低过拟合。

隐式深度监督是一种深度学习中的方法,它利用任务内部的结构或其他先验信息来辅助模型的训练,而不是仅仅依赖于显式的标签或人工标注的数据。这种方法旨在通过隐含的方式引导模型学习更加有用和泛化的表示。在传统监督学习中,我们通常需要提供带有明确标签的训练数据,以让模型学习任务的显式目标。而在隐式深度监督中,可以利用任务内部的信息结构,比如数据的自相似性、任务之间的关系、数据的分布特征等,来指导模型的学习过程。

作者在 CIFAR-10、CIFAR-100、SVHN 和 ImageNet)数据集上评估 DenseNet,与精度相当的现有算法相比,DenseNet 需要的参数要少得多。

1.3 DenseNets 网络

考虑通过卷积网络的单个图像 x 0 x_0 x0。该网络包含 L 层,每层都实现非线性变换 H l ( ⋅ ) H_{l} (·) Hl(),其中对层进行索引。 H l ( ⋅ ) H_{l} (·) Hl() 可以是批归一化 (BN) 、修正线性单元 (ReLU) 、池化或卷积 (Conv) 等运算的复合函数。

1.3.1 残差网络

传统的卷积前馈网络将第 l l l 层的输出连接为第 ( l + 1 ) (l+1) (l+1) 层的输入,从而产生以下层转换: x l = H l ( x l − 1 ) x_{l} = H_{l}(x_{l-1} ) xl=Hl(xl1)。 ResNets 添加了一个跳跃连接,可以使用恒等函数绕过非线性变换: x l = H l ( x l − 1 ) + x l − 1 x_l=H_l(x_{l-1})+x_{l-1} xl=Hl(xl1)+xl1ResNets 的一个优点是梯度可以直接通过恒等函数从后面的层流到前面的层。然而,恒等函数和 H l H_l Hl 的输出是通过求和结合起来的,这可能会阻碍网络中的信息流动。

1.3.2 密集连接

为了进一步改善层之间的信息流,作者提出了一种不同的连接模式:作者引入从任何层到所有后续层的直接连接。因此,第 3 层接收所有前面层的特征图 x 0 , . . . , x l − 1 x_0,..., x_{l -1} x0,...,xl1 作为输入: x l = H l ( [ x 0 , x 1 , x 2 , . . . , x l − 1 ] ) x_l=H_l([x_0,x_1,x_2,...,x_{l-1}]) xl=Hl([x0,x1,x2,...,xl1])其中 [ x 0 , x 1 , . . . , x l − 1 ] [x_0, x_1,..., x_{l-1}] [x0,x1,...,xl1] 指的是第 0 , . . . , l − 1 0,..., l-1 0,...,l1 层中生成的特征图的串联。由于其密集的连接性,作者将该网络架构称为密集卷积网络(DenseNet)。为了便于实现,作者将等式中 H l ( ⋅ ) H_l (·) Hl() 的多个输入连接起来。作者将 H l ( ⋅ ) H_l (·) Hl() 定义为三个连续操作的复合函数:批量归一化 (BN),后跟修正线性单元 (ReLU) 和 3 × 3 卷积 (Conv )。

等式中使用的串联运算。当特征图的大小发生变化时, x l = H l ( [ x 0 , x 1 , x 2 , . . . , x l − 1 ] ) x_l=H_l([x_0,x_1,x_2,...,x_{l-1}]) xl=Hl([x0,x1,x2,...,xl1])不可行。然而,卷积网络的一个重要部分是改变特征图大小的下采样层。为了便于在该架构中进行下采样,作者将网络划分为多个密集连接的密集块,此处同残差网络一样,在残差网络中也是将网络划分为多个残差块进行连接;如下图所示。作者将块之间的层称为过渡层,它执行卷积和池化。作者实验中使用的过渡层由批量归一化层和 1×1 卷积层组成,后跟 2×2 平均池化层

1x1 卷积层通常用于减少特征图的维度(通道数),从而降低计算复杂性。在 DenseNet 中,1x1 卷积层的主要作用是在密集块内进行特征图的降维,减少通道数,以降低计算成本。同时,1x1 卷积也可以引入非线性,帮助模型更好地学习复杂的特征。
池化层(通常是平均池化或最大池化)用于降低特征图的空间维度,从而减少计算负担和提取更为显著的特征。在 DenseNet 中,池化层的作用是在密集块中引入空间下采样,从而减少每个密集块的特征图的尺寸,帮助网络适应不同尺度的特征。此外,降低特征图的维度还有助于减轻内存压力。

在这里插入图片描述
如果每个函数 H l H_l Hl 产生 k k k 个特征图,则第 l l l 层有 k 0 + k × ( l − 1 ) k_0 +k×( l-1) k0+k×(l1) 个输入的特征图,其中 k 0 k_0 k0 是输入层中的通道数。 DenseNet 和现有网络架构之间的一个重要区别是 DenseNet 可以具有非常窄的层,例如 k = 12 k = 12 k=12,超参数 k k k 称为网络的增长率。

尽管每一层仅产生 k k k 个输出特征图,但它通常具有更多的输入。 可以在每个 3×3 卷积之前引入 1×1 卷积作为中间层,以减少输入特征图的数量,从而提高计算效率。作者发现这种设计对于 DenseNet 特别有效,将具有这样一个中间层的网络称为 H l H_l Hl B N − R e L U − C o n v ( 1 × 1 ) − B N − R e L U − C o n v ( 3 × 3 ) BN-ReLU-Conv(1×1)-BN-ReLU-Conv(3×3) BNReLUConv(1×1)BNReLUConv(3×3)

为了进一步提高模型的紧凑性,可以减少过渡层的特征图数量。如果一个密集块包含 m m m 个特征图,我们让下面的过渡层生成 θ m θm θm 个输出特征图,其中 0 < θ ≤ 1 0 <θ ≤1 0<θ1 称为压缩因子。当 θ = 1 θ =1 θ=1 时,跨过渡层的特征图数量保持不变。我们将 θ < 1 θ<1 θ<1 的 DenseNet 称为 DenseNet-C,并在实验中设置 θ = 0.5 θ =0.5 θ=0.5。当同时使用 θ< 1 的瓶颈层和过渡层时,模型称为 DenseNet-BC。

1.3.3 实施细节

在除 ImageNet 之外的所有数据集上,作者实验中使用的 DenseNet 具有三个密集块,每个密集块具有相同数量的层。在进入第一个密集块之前,对输入图像执行 16 个(或 DenseNet-BC 增长率的两倍)输出通道的卷积。对于内核大小为 3×3 的卷积层,输入的每一侧都用一个像素进行零填充,以保持特征图大小固定。作者使用 1×1 卷积,然后使用 2×2 平均池化作为两个连续密集块之间的过渡层。在最后一个密集块的末尾,执行全局平均池化,然后附加一个 softmax 分类器。三个密集块中的特征图大小分别为 32×32、16×16 和 8×8。

在 ImageNet 上的实验中,作者在 224×224 输入图像上使用具有 4 个密集块的 DenseNet-BC 结构。初始卷积层包含 2k 个大小为 7×7、步幅为 2 的卷积;所有其他层中的特征图数量也取决于 k 的设置。作者在 ImageNet 上使用的确切网络配置如下表所示:
第二十九周:文献阅读笔记(DenseNet)+ pytorch学习_第2张图片

1.4 实验

1.4.1 数据集

1.4.1.1 CIFAR

两个 CIFAR 数据集 由 32×32 像素的彩色自然图像组成。 CIFAR-10 (C10) 由来自 10 个类别的图像组成,CIFAR-100 (C100) 由来自 100 个类别的图像组成。训练集和测试集分别包含 50,000 张和 10,000 张图像,我们提供 5,000 张训练图像作为验证集。作者采用的数据增强方案,正是广泛用于这两个数据集的标准数据增强方案(镜像/移位)。作者通过数据集名称末尾的“+”标记来表示此数据增强方案(例如,C10+)。对于预处理,作者使用通道均值和标准差对数据进行标准化。对于最终运行,作者使用全部 50,000 个训练图像,并在训练结束时报告最终测试错误。

1.4.1.2 SVHN

街景门牌号 (SVHN) 数据集 包含 32×32 彩色数字图像。训练集中有 73257 张图像,测试集中有 26032 张图像,还有 531131 张图像用于额外训练。作者使用所有训练数据而不进行任何数据增强,并从训练集中分割出包含 6,000 张图像的验证集。作者在训练期间选择验证误差最低的模型并报告测试误差。作者按照将像素值除以255,这样它们就在[0, 1]范围内。

1.4.2 模型训练

所有网络均使用随机梯度下降(SGD)进行训练。在 CIFAR 和 SVHN 上,作者分别使用批量大小 64 进行 300 和 40 个 epoch 的训练。初始学习率设置为0.1,在训练epoch总数的50%和75%时除以10。在 ImageNet 上,作者训练模型 90 个 epoch,批量大小为 256。学习率最初设置为 0.1,并在第 30 和 60 个 epoch 降低 10 倍。由于 GPU 内存限制,我们最大的模型(DenseNet-161) )使用小批量大小 128 进行训练。为了补偿较小的批量大小,我们将该模型训练 100 个时期,并在第 90 个时期将学习率除以 10。作者使用 1 0 − 4 10^{−4} 104 的权重衰减和 0.9 的 Nesterov 动量。对于没有数据增强的三个数据集,即C10、C100和SVHN,作者在每个卷积层(第一个除外)之后添加一个dropout层,并将dropout率设置为0.2。对于每个任务和模型设置,测试错误仅评估一次。

1.4.3 CIFAR 和 SVHN 的分类结果

第二十九周:文献阅读笔记(DenseNet)+ pytorch学习_第3张图片

最明显的趋势可能源自表 2 的底行,该表显示 L = 190 且 k = 40 的 DenseNet-BC 在所有 CIFAR 数据集上始终优于现有的最先进技术。其在 C10+ 上的错误率为 3.46%,在 C100+ 上的错误率为 17.18%,明显低于宽 ResNet 架构所实现的错误率

1.4.4 ImageNet 上的分类结果

为了能和ResNet网络进行有效对比,作者将 ResNet 模型替换为 DenseNet-BC 网络,并保持所有实验设置与 ResNet 中使用的设置完全相同,由于 GPU 内存限制,我们最大的 DenseNet 模型是使用 128 的小批量进行训练的;我们训练这个模型 100 个 epoch,在 90 个 epoch 之后第三次学习率下降,以补偿较小的批量大小。
第二十九周:文献阅读笔记(DenseNet)+ pytorch学习_第4张图片

图中显示的结果表明,DenseNet 的性能与最先进的 ResNet 相当,同时需要显着更少的参数和计算来实现可比较的性能。

1.5 总结

作者在该文献中提出了一种新的卷积网络架构——密集卷积网络(DenseNet)。它引入了具有相同特征图大小的任意两层之间的直接连接。同时作者证明了 DenseNet 可以自然地扩展到数百层,同时没有表现出优化困难。在实验中,DenseNet 往往会随着参数数量的增加而不断提高准确性,而不会出现任何性能下降或过度拟合的迹象。此外,DenseNet 需要更少的参数和更少的计算来实现最先进的性能。在遵循简单的连接规则的同时,DenseNet 自然地融合了恒等映射、深度监督和多样化深度的特性。

1.6 DenseNet 的创新点

  1. 密集连接(Dense Connections): DenseNet 引入了密集连接的概念,将每一层的输出与所有先前层的输出连接在一起。这种密集连接的设计增加了信息流动的路径,使得模型更容易共享特征并更好地利用网络中的参数。相比于传统的网络结构,密集连接减轻了梯度消失问题,有助于更深层次的网络训练。
  2. 减缓信息丢失: 由于每一层的输出都被直接传递到后续层,信息在网络中更为充分地传递。这有助于减缓信息丢失,使网络更好地保留和利用先前层的特征,从而提高模型的表示能力。
  3. 参数和计算效率: 由于密集连接的存在,每一层只需关注前一层的输出,减少了网络中每一层所需学习的参数数量,提高了参数效率。此外,密集连接还允许特征重用,减少了冗余计算,提高了计算效率。
  4. 尺度适应性: DenseNet 的密集连接有助于网络适应不同尺度的特征。通过将每一层的输出与所有先前层的输出连接,网络可以更好地捕捉各种尺度的特征,有助于处理复杂的图像结构。
  5. 消除网络深度的问题: 传统深度网络在网络加深时可能会面临梯度消失或梯度爆炸等问题。由于密集连接的存在,DenseNet 在更深的网络结构下仍然能够有效地训练,克服了一些传统深度网络的问题。

2. DenseNet 模型代码实现(pytorch)

import re
from typing import Any, List, Tuple
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor


class _DenseLayer(nn.Module):
    def __init__(self,
                 input_c: int,
                 growth_rate: int,
                 bn_size: int,
                 drop_rate: float,
                 memory_efficient: bool = False):
        super(_DenseLayer, self).__init__()

        self.add_module("norm1", nn.BatchNorm2d(input_c))
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.add_module("conv1", nn.Conv2d(in_channels=input_c,
                                           out_channels=bn_size * growth_rate,
                                           kernel_size=1,
                                           stride=1,
                                           bias=False))
        self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.add_module("conv2", nn.Conv2d(bn_size * growth_rate,
                                           growth_rate,
                                           kernel_size=3,
                                           stride=1,
                                           padding=1,
                                           bias=False))
        self.drop_rate = drop_rate
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        concat_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concat_features)))
        return bottleneck_output

    @staticmethod
    def any_requires_grad(inputs: List[Tensor]) -> bool:
        for tensor in inputs:
            if tensor.requires_grad:
                return True

        return False

    @torch.jit.unused
    def call_checkpoint_bottleneck(self, inputs: List[Tensor]) -> Tensor:
        def closure(*inp):
            return self.bn_function(inp)

        return cp.checkpoint(closure, *inputs)

    def forward(self, inputs: Tensor) -> Tensor:
        if isinstance(inputs, Tensor):
            prev_features = [inputs]
        else:
            prev_features = inputs

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("memory efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features,
                                     p=self.drop_rate,
                                     training=self.training)

        return new_features


class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(self,
                 num_layers: int,
                 input_c: int,
                 bn_size: int,
                 growth_rate: int,
                 drop_rate: float,
                 memory_efficient: bool = False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(input_c + i * growth_rate,
                                growth_rate=growth_rate,
                                bn_size=bn_size,
                                drop_rate=drop_rate,
                                memory_efficient=memory_efficient)
            self.add_module("denselayer%d" % (i + 1), layer)

    def forward(self, init_features: Tensor) -> Tensor:
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)


class _Transition(nn.Sequential):
    def __init__(self,
                 input_c: int,
                 output_c: int):
        super(_Transition, self).__init__()
        self.add_module("norm", nn.BatchNorm2d(input_c))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", nn.Conv2d(input_c,
                                          output_c,
                                          kernel_size=1,
                                          stride=1,
                                          bias=False))
        self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    """
    Densenet-BC model class for imagenet

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient
    """

    def __init__(self,
                 growth_rate: int = 32,
                 block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
                 num_init_features: int = 64,
                 bn_size: int = 4,
                 drop_rate: float = 0,
                 num_classes: int = 1000,
                 memory_efficient: bool = False):
        super(DenseNet, self).__init__()

        # first conv+bn+relu+pool
        self.features = nn.Sequential(OrderedDict([
            ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ("norm0", nn.BatchNorm2d(num_init_features)),
            ("relu0", nn.ReLU(inplace=True)),
            ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # each dense block
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers,
                                input_c=num_features,
                                bn_size=bn_size,
                                growth_rate=growth_rate,
                                drop_rate=drop_rate,
                                memory_efficient=memory_efficient)
            self.features.add_module("denseblock%d" % (i + 1), block)
            num_features = num_features + num_layers * growth_rate

            if i != len(block_config) - 1:
                trans = _Transition(input_c=num_features,
                                    output_c=num_features // 2)
                self.features.add_module("transition%d" % (i + 1), trans)
                num_features = num_features // 2

        # finnal batch norm
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))

        # fc layer
        self.classifier = nn.Linear(num_features, num_classes)

        # init weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out


def densenet121(**kwargs: Any) -> DenseNet:
    # Top-1 error: 25.35%
    # 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth'
    return DenseNet(growth_rate=32,
                    block_config=(6, 12, 24, 16),
                    num_init_features=64,
                    **kwargs)


def densenet169(**kwargs: Any) -> DenseNet:
    # Top-1 error: 24.00%
    # 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth'
    return DenseNet(growth_rate=32,
                    block_config=(6, 12, 32, 32),
                    num_init_features=64,
                    **kwargs)


def densenet201(**kwargs: Any) -> DenseNet:
    # Top-1 error: 22.80%
    # 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth'
    return DenseNet(growth_rate=32,
                    block_config=(6, 12, 48, 32),
                    num_init_features=64,
                    **kwargs)


def densenet161(**kwargs: Any) -> DenseNet:
    # Top-1 error: 22.35%
    # 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth'
    return DenseNet(growth_rate=48,
                    block_config=(6, 12, 36, 24),
                    num_init_features=96,
                    **kwargs)


def load_state_dict(model: nn.Module, weights_path: str) -> None:
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    state_dict = torch.load(weights_path)

    num_classes = model.classifier.out_features
    load_fc = num_classes == 1000

    for key in list(state_dict.keys()):
        if load_fc is False:
            if "classifier" in key:
                del state_dict[key]

        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict, strict=load_fc)
    print("successfully load pretrain-weights.")

你可能感兴趣的:(笔记,pytorch,学习)