这段代码的主要目的是加载和修改 VGG16 模型,并打印模型的结构。以下是代码的详细解释:
Python复制
import torchvision
import torch
torchvision
是 PyTorch 的计算机视觉扩展库,提供了常用的预训练模型和数据集。
torch
是 PyTorch 核心库,用于构建和训练深度学习模型。
Python复制
vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true
:加载带有预训练权重的 VGG16 模型。预训练权重通常是在大规模数据集(如 ImageNet)上训练得到的。
vgg16_false
:加载不带预训练权重的 VGG16 模型,初始权重是随机初始化的。
Python复制
print(vgg16_true)
这行代码会输出 VGG16 模型的结构,包括卷积层(features
)和全连接层(classifier
)的详细信息。
Python复制
train_data = torchvision.datasets.CIFAR10("../data", train=True, transform=torchvision.transforms.ToTensor(), download=True)
torchvision.datasets.CIFAR10
:加载 CIFAR-10 数据集。
"../data"
:数据集存储的路径。
train=True
:加载训练集。
transform=torchvision.transforms.ToTensor()
:将图像数据转换为 PyTorch 张量。
download=True
:如果数据集不存在,会自动下载。
Python复制
vgg16_true.classifier.add_module('add_linear', torch.nn.Linear(1000, 10))
vgg16_true.classifier
:VGG16 模型的全连接层部分。
add_module('add_linear', torch.nn.Linear(1000, 10))
:向分类器部分添加一个新的线性层。
'add_linear'
:新模块的名称。
torch.nn.Linear(1000, 10)
:定义一个线性层,输入维度为 1000(对应原分类器的输出维度),输出维度为 10(对应 CIFAR-10 的 10 个类别)。
Python复制
print(vgg16_true)
这行代码会输出修改后的 VGG16 模型的结构,可以看到新增的 add_linear
模块。
代码的主要目的是加载预训练的 VGG16 模型,并修改其分类器部分以适配 CIFAR-10 数据集。
修改后的模型可以直接用于 CIFAR-10 的分类任务,而无需重新训练整个 VGG16 模型。
这段代码展示了两种不同的模型加载方式,并打印了模型的结构。以下是代码的详细解释:
Python复制
model = torch.load("vgg16_method1.pth")
使用 torch.load
加载保存的模型文件 "vgg16_method1.pth"
。
这种方式加载的模型是完整的模型对象,包括模型结构和权重。
通常,这种保存方式是通过 torch.save(model, path)
保存的,保存了整个模型。
Python复制
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
先实例化一个未预训练的 VGG16 模型。
使用 load_state_dict
方法加载保存的模型权重文件 "vgg16_method2.pth"
。
这种方式加载的模型是通过 torch.save(model.state_dict(), path)
保存的模型权重。
Python复制
print(model)
print(vgg16)
分别打印两种加载方式加载的模型结构。
预期输出是两个模型的详细结构,包括各层的参数和连接方式。
保存方式不同:
torch.save(model, path)
保存的是完整的模型对象(包括结构和权重)。
torch.save(model.state_dict(), path)
保存的是模型的权重字典。
加载方式不同:
方式1:直接加载完整的模型对象。
方式2:需要先实例化模型,再加载权重。
适用场景:
如果需要迁移学习或微调模型,通常使用方式2(加载权重),因为这样可以灵活修改模型结构。
如果模型结构固定,可以直接使用方式1加载整个模型。
print(model)
:输出通过方式1加载的完整模型结构。
print(vgg16)
:输出通过方式2加载的模型结构,预期与方式1的结构一致(假设权重一致)。
方式1 适用于快速加载和使用完整的模型对象。
方式2 适用于需要灵活修改模型结构或进行迁移学习的场景。
这段代码展示了如何使用 PyTorch 的 torch.save — PyTorch 2.6 documentation 函数以两种不同的方式保存模型。以下是代码的详细解释:
导入必要的库
Python复制
import torch
import torchvision
torch
: PyTorch 核心库,用于构建和训练深度学习模型。
torchvision
: PyTorch 的计算机视觉扩展库,提供了常用的预训练模型。
加载未预训练的 VGG16 模型
Python复制
vgg16 = torchvision.models.vgg16(pretrained=False)
torchvision.models.vgg16
: 加载 VGG16 模型。
pretrained=False
: 表示不加载预训练权重,模型的权重是随机初始化的。
保存方式 1:模型结构 + 参数
Python复制
torch.save(vgg16, "vgg16_method1.pth")
torch.save(vgg16, "vgg16_method1.pth")
: 直接保存整个模型对象。
包括模型的结构(例如,层的设计)和权重(参数)。
这种方式适合保存完整的模型,便于直接加载和使用。
保存方式 2:只保存模型参数(权重)
Python复制
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
vgg16.state_dict()
: 获得模型的权重字典(仅保存参数)。
torch.save(...)
: 保存权重字典到文件。
这种方式更适合迁移学习或模型微调时使用,因为可以灵活地将权重加载到不同的模型结构中。
特性 | 方式 1 | 方式 2 |
---|---|---|
保存内容 | 模型结构 + 参数 | 仅参数(权重) |
文件大小 | 较大(保存结构信息) | 较小(仅保存参数) |
适用场景 | 快速导出和加载完整模型 | 迁移学习、模型微调 |
局限性 | 模型结构依赖原环境 | 加载时需预先定义模型结构 |
方式 1:
Python复制
model = torch.load("vgg16_method1.pth")
直接加载整个模型对象。
方式 2:
Python复制
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
需要先定义模型结构,再加载权重。
方式 1:适用于快速保存和分享完整的模型,不涉及模型结构调整。
方式 2:适用于迁移学习,可以在保留部分预训练权重的同时调整模型结构。
方式 1 是直接保存模型的完整状态。
方式 2 仅保存模型的权重,便于灵活应用。
根据实际需求(是否需要调整模型结构)选择合适的保存方式。
torch.load以及torch.save的讲解
torch.load
和 torch.save
是 PyTorch 中用于保存和加载模型、张量和其他对象的重要函数。它们在深度学习项目中经常用于保存训练好的模型权重、模型结构、训练状态等。以下是它们的详细原理和使用方法:
torch.save
函数原理torch.save
用于将 Python 对象(如模型、张量、字典等)保存到文件中。它利用 Python 的序列化机制将对象的状态信息转化为可存储的字节流格式,并写入到文件中。
功能:保存 PyTorch 对象(如模型、张量、字典等)到文件中。
原理:
使用 PyTorch 的自定义序列化机制,将对象的状态信息(如张量的数据、模型的权重、优化器的状态等)序列化为字节流。
支持多种存储后端,包括 PyTorch 自己的二进制格式、pickle
序列化、zip 文件等。
保存的对象可以是张量、模型的整个对象、或者模型的权重(state_dict
)。
保存整个模型:
Python复制
torch.save(model, path)
保存模型的结构和权重。
优点:保存的内容完整,可以直接加载整个模型。
缺点:模型结构依赖于保存时的模型定义,如果模型定义发生变化(如代码改动),加载时可能出错。
保存模型权重(state_dict
):
Python复制
torch.save(model.state_dict(), path)
仅保存模型的权重(参数)。
这样可以在加载时更灵活地定义模型结构(例如,用于迁移学习或模型迁移)。
优点:模型结构和权重分离,更适合迁移学习和模型微调。
缺点:需要手动加载到模型结构中。
以下是一个保存模型的示例:
Python复制
import torch
import torchvision
# 创建一个未预训练的 VGG16 模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存整个模型
torch.save(vgg16, 'vgg16_model.pth')
# 保存模型权重
torch.save(vgg16.state_dict(), 'vgg16_weights.pth')
torch.load
函数原理torch.load
用于从文件中加载之前保存的 PyTorch 对象。它通过解析文件中的字节流,恢复对象的状态信息。
功能:从文件中加载 PyTorch 对象。
原理:
使用与 torch.save
相同的序列化机制,解析文件中的字节流,将其反序列化为 Python 对象。
自动检测文件的存储后端(如 pickle
、zip 文件等)并选择合适的解析方式。
加载整个模型:
Python复制
model = torch.load(path)
直接加载整个模型对象,包括模型结构和权重。
注意:需要确保在当前环境中定义了相同的模型类(如 torchvision.models.vgg16
),否则会报错。
加载模型权重(state_dict
):
Python复制
model = torchvision.models.vgg16(pretrained=False)
model.load_state_dict(torch.load(path))
需要先定义模型结构,然后加载权重到模型中。
这样可以更灵活地加载权重到不同的模型结构中(例如,迁移学习)。
以下是一个加载模型的示例:
Python复制
import torch
import torchvision
# 加载整个模型
model = torch.load('vgg16_model.pth')
# 加载模型权重
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('vgg16_weights.pth'))
文件路径和权限:
确保保存和加载文件时路径正确,且有权限读写文件。
兼容性:
如果使用 torch.save
保存整个模型,加载时需要保证 PyTorch 版本和模型定义兼容。
使用 state_dict
保存和加载可以更好地兼容不同版本的 PyTorch 和模型定义。
使用 GPU 和 CPU 的转换:
如果模型在 GPU 上保存,但在 CPU 上加载(反之亦然),可以在加载时指定设备:
Python复制
# 将模型加载到 CPU
model = torch.load(path, map_location=torch.device('cpu'))
# 将模型加载到 GPU
model = torch.load(path, map_location=torch.device('cuda:0'))
分布式训练的保存和加载:
在分布式训练中,state_dict
会包含 module.
前缀。加载时需要移除或保留前缀,以匹配模型结构。
torch.save
和 torch.load
是 PyTorch 中用于保存和加载模型的核心函数。
保存和加载的方式取决于应用场景(如迁移学习、微调、快速加载等)。
推荐使用 state_dict
方式保存和加载模型权重,以获得更好的灵活性和兼容性。