PyTorch 之 torchvision.models 使用

PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:

torchvision.datasets

torchvision.models

torchvision.transforms

官网介绍   源码

-----------------------------------------------------------------------------------------------------------------

本文介绍 torchvision.models 如何使用。以 vgg16为例子

1) 导入预训练模型:

import torchvision
model = torchvision.models.vgg16(pretrained=True)

2) 只导入网络结构,不导入参数:

model = torchvision.models.vgg16(pretrained=False) #主要是这里改为False

3) 由于 pretrained 参数默认是 False,所以 2) 等价于:

model = torchvision.models.vgg16()

 

你可能感兴趣的:(Ubuntu+python)