Pytorch模型的加载方式支持两种:一种是使用Pytorch自带的模型,另外一种是程序员自己编写的模型。我们慢慢介绍:
一、pytorch官方自带的模型加载
官方自带的模型封装在torchvision.models类中,在使用时,需要先导出
import torchvision.models as models
在这个类中,截至pytorch 0.4.1版本的时候,其主要提供的模型有以下几种:(要是几种模型不太清楚的,可以看下我之前的文章:传送门链接)
在这我们使用ResNet18做一个classification为例,来简单介绍Pytorch加载自带的模型
resnet18 = models.resnet18() #加载模型
当然,模型中的权值开始全都是随机的,原文:You can construct a model with random weights by calling its constructor。
同时Pytorch也提供预训练模式,调用的方法也很简单,只需要在加载模型中,加入以下参数
resnet18 = models.resnet18(pretrained=True) #加载模型,并设为预训练模式
值得注意的是,当你将模型设置为预训练模式的时候,模型输入的图片最好是经过相同方式归一化的3通道数据,即(3*H*W)H与W至少是224,图像必须被加载到(0,1)的范围内,然后再归一化到mean = [0.485, 0.456, 0.406]且std= [0.229,0.224,0.225]。可以使用pytorch自带的归一化函数
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
Pytorch加载模型后,模型有两个状态,model.train()
or model.eval()
当你在训练模型的时候,将模型置为model.train()
当你在测试模型的时候,将模型置为 model.eval(),
设置为此种模式的时候,模型内部的权值不会改变.
Alexnet
torchvision.models.alexnet(pretrained=False, **kwargs)
VGG
torchvision.models.vgg11(pretrained=False, **kwargs)
torchvision.models.vgg11_bn(pretrained=False, **kwargs)
torchvision.models.vgg13(pretrained=False, **kwargs)
torchvision.models.vgg13_bn(pretrained=False, **kwargs)
torchvision.models.vgg16(pretrained=False, **kwargs)
torchvision.models.vgg16_bn(pretrained=False, **kwargs)
torchvision.models.vgg19(pretrained=False, **kwargs)
torchvision.models.vgg19_bn(pretrained=False, **kwargs)
ResNet
torchvision.models.resnet18(pretrained=False, **kwargs
torchvision.models.resnet34(pretrained=False, **kwargs)
torchvision.models.resnet50(pretrained=False, **kwargs)
torchvision.models.resnet101(pretrained=False, **kwargs)
torchvision.models.resnet152(pretrained=False, **kwargs)
SqueezeNet
torchvision.models.squeezenet1_0(pretrained=False, **kwargs)
torchvision.models.squeezenet1_1(pretrained=False, **kwargs)
DenseNet
torchvision.models.densenet121(pretrained=False, **kwargs)
torchvision.models.densenet169(pretrained=False, **kwargs)
torchvision.models.densenet161(pretrained=False, **kwargs)
torchvision.models.densenet201(pretrained=False, **kwargs)
Inception v3
torchvision.models.inception_v3(pretrained=False, **kwargs)
下一章打算写模型编写问题,有啥问题可以留言。