机器学习大白话杂谈(3)——Pytorch模型加载问题。

一、模型加载大致分类

Pytorch模型的加载方式支持两种:一种是使用Pytorch自带的模型,另外一种是程序员自己编写的模型。我们慢慢介绍:

一、pytorch官方自带的模型加载

       官方自带的模型封装在torchvision.models类中,在使用时,需要先导出

import torchvision.models as models

       在这个类中,截至pytorch 0.4.1版本的时候,其主要提供的模型有以下几种:(要是几种模型不太清楚的,可以看下我之前的文章:传送门链接)

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
  • Inception v3

       在这我们使用ResNet18做一个classification为例,来简单介绍Pytorch加载自带的模型

resnet18 = models.resnet18()  #加载模型

        当然,模型中的权值开始全都是随机的,原文:You can construct a model with random weights  by calling its constructor。

    同时Pytorch也提供预训练模式,调用的方法也很简单,只需要在加载模型中,加入以下参数

resnet18 = models.resnet18(pretrained=True) #加载模型,并设为预训练模式

        值得注意的是,当你将模型设置为预训练模式的时候,模型输入的图片最好是经过相同方式归一化的3通道数据,即(3*H*W)H与W至少是224,图像必须被加载到(0,1)的范围内,然后再归一化到mean = [0.485, 0.456, 0.406]std=                [0.229,0.224,0.225]。可以使用pytorch自带的归一化函数

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

        Pytorch加载模型后,模型有两个状态,model.train()or   model.eval()

        当你在训练模型的时候,将模型置为model.train()

        当你在测试模型的时候,将模型置为 model.eval(),设置为此种模式的时候,模型内部的权值不会改变.

 

二、自带模型汇总

Alexnet

torchvision.models.alexnet(pretrained=False, **kwargs)

VGG

torchvision.models.vgg11(pretrained=False, **kwargs)
torchvision.models.vgg11_bn(pretrained=False, **kwargs)
torchvision.models.vgg13(pretrained=False, **kwargs)
torchvision.models.vgg13_bn(pretrained=False, **kwargs)
torchvision.models.vgg16(pretrained=False, **kwargs)
torchvision.models.vgg16_bn(pretrained=False, **kwargs)
torchvision.models.vgg19(pretrained=False, **kwargs)
torchvision.models.vgg19_bn(pretrained=False, **kwargs)

ResNet

torchvision.models.resnet18(pretrained=False, **kwargs
torchvision.models.resnet34(pretrained=False, **kwargs)
torchvision.models.resnet50(pretrained=False, **kwargs)
torchvision.models.resnet101(pretrained=False, **kwargs)
torchvision.models.resnet152(pretrained=False, **kwargs)

SqueezeNet

torchvision.models.squeezenet1_0(pretrained=False, **kwargs)
torchvision.models.squeezenet1_1(pretrained=False, **kwargs)

DenseNet

torchvision.models.densenet121(pretrained=False, **kwargs)
torchvision.models.densenet169(pretrained=False, **kwargs)
torchvision.models.densenet161(pretrained=False, **kwargs)
torchvision.models.densenet201(pretrained=False, **kwargs)

Inception v3

torchvision.models.inception_v3(pretrained=False, **kwargs)

下一章打算写模型编写问题,有啥问题可以留言。

你可能感兴趣的:(深度学习)