pytorch-模型的加载,迁移学习,保存

这段代码的主要目的是加载和修改 VGG16 模型,并打印模型的结构。以下是代码的详细解释:

1. 导入必要的库

Python复制

import torchvision
import torch
  • torchvision 是 PyTorch 的计算机视觉扩展库,提供了常用的预训练模型和数据集。

  • torch 是 PyTorch 核心库,用于构建和训练深度学习模型。

2. 加载预训练和未预训练的 VGG16 模型

Python复制

vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16_false = torchvision.models.vgg16(pretrained=False)
  • vgg16_true:加载带有预训练权重的 VGG16 模型。预训练权重通常是在大规模数据集(如 ImageNet)上训练得到的。

  • vgg16_false:加载不带预训练权重的 VGG16 模型,初始权重是随机初始化的。

3. 打印预训练的 VGG16 模型结构

Python复制

print(vgg16_true)
  • 这行代码会输出 VGG16 模型的结构,包括卷积层(features)和全连接层(classifier)的详细信息。

4. 加载 CIFAR-10 数据集

Python复制

train_data = torchvision.datasets.CIFAR10("../data", train=True, transform=torchvision.transforms.ToTensor(), download=True)
  • torchvision.datasets.CIFAR10:加载 CIFAR-10 数据集。

    • "../data":数据集存储的路径。

    • train=True:加载训练集。

    • transform=torchvision.transforms.ToTensor():将图像数据转换为 PyTorch 张量。

    • download=True:如果数据集不存在,会自动下载。

5. 修改 VGG16 模型的分类器部分

Python复制

vgg16_true.classifier.add_module('add_linear', torch.nn.Linear(1000, 10))
  • vgg16_true.classifier:VGG16 模型的全连接层部分。

  • add_module('add_linear', torch.nn.Linear(1000, 10)):向分类器部分添加一个新的线性层。

    • 'add_linear':新模块的名称。

    • torch.nn.Linear(1000, 10):定义一个线性层,输入维度为 1000(对应原分类器的输出维度),输出维度为 10(对应 CIFAR-10 的 10 个类别)。

6. 再次打印修改后的 VGG16 模型结构

Python复制

print(vgg16_true)
  • 这行代码会输出修改后的 VGG16 模型的结构,可以看到新增的 add_linear 模块。

总结

  • 代码的主要目的是加载预训练的 VGG16 模型,并修改其分类器部分以适配 CIFAR-10 数据集。

  • 修改后的模型可以直接用于 CIFAR-10 的分类任务,而无需重新训练整个 VGG16 模型。

这段代码展示了两种不同的模型加载方式,并打印了模型的结构。以下是代码的详细解释:

1. 加载模型的方式1

Python复制

model = torch.load("vgg16_method1.pth")
  • 使用 torch.load 加载保存的模型文件 "vgg16_method1.pth"

  • 这种方式加载的模型是完整的模型对象,包括模型结构和权重。

  • 通常,这种保存方式是通过 torch.save(model, path) 保存的,保存了整个模型。

2. 加载模型的方式2

Python复制

vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
  • 先实例化一个未预训练的 VGG16 模型。

  • 使用 load_state_dict 方法加载保存的模型权重文件 "vgg16_method2.pth"

  • 这种方式加载的模型是通过 torch.save(model.state_dict(), path) 保存的模型权重。

3. 打印模型结构

Python复制

print(model)
print(vgg16)
  • 分别打印两种加载方式加载的模型结构。

  • 预期输出是两个模型的详细结构,包括各层的参数和连接方式。

区别与注意事项

  1. 保存方式不同

    • torch.save(model, path) 保存的是完整的模型对象(包括结构和权重)。

    • torch.save(model.state_dict(), path) 保存的是模型的权重字典。

  2. 加载方式不同

    • 方式1:直接加载完整的模型对象。

    • 方式2:需要先实例化模型,再加载权重。

  3. 适用场景

    • 如果需要迁移学习或微调模型,通常使用方式2(加载权重),因为这样可以灵活修改模型结构。

    • 如果模型结构固定,可以直接使用方式1加载整个模型。

运行结果示例

  • print(model):输出通过方式1加载的完整模型结构。

  • print(vgg16):输出通过方式2加载的模型结构,预期与方式1的结构一致(假设权重一致)。

总结

  • 方式1 适用于快速加载和使用完整的模型对象。

  • 方式2 适用于需要灵活修改模型结构或进行迁移学习的场景。

这段代码展示了如何使用 PyTorch 的 torch.save — PyTorch 2.6 documentation 函数以两种不同的方式保存模型。以下是代码的详细解释:

主线:代码逻辑

  1. 导入必要的库

    Python复制

    import torch
    import torchvision
    • torch: PyTorch 核心库,用于构建和训练深度学习模型。

    • torchvision: PyTorch 的计算机视觉扩展库,提供了常用的预训练模型。

  2. 加载未预训练的 VGG16 模型

    Python复制

    vgg16 = torchvision.models.vgg16(pretrained=False)
    • torchvision.models.vgg16: 加载 VGG16 模型。

    • pretrained=False: 表示不加载预训练权重,模型的权重是随机初始化的。

  3. 保存方式 1:模型结构 + 参数

    Python复制

    torch.save(vgg16, "vgg16_method1.pth")
    • torch.save(vgg16, "vgg16_method1.pth"): 直接保存整个模型对象。

      • 包括模型的结构(例如,层的设计)和权重(参数)。

      • 这种方式适合保存完整的模型,便于直接加载和使用。

  4. 保存方式 2:只保存模型参数(权重)

    Python复制

    torch.save(vgg16.state_dict(), "vgg16_method2.pth")
    • vgg16.state_dict(): 获得模型的权重字典(仅保存参数)。

    • torch.save(...): 保存权重字典到文件。

    • 这种方式更适合迁移学习或模型微调时使用,因为可以灵活地将权重加载到不同的模型结构中。

补充知识:两种保存方式的区别

特性 方式 1 方式 2
保存内容 模型结构 + 参数 仅参数(权重)
文件大小 较大(保存结构信息) 较小(仅保存参数)
适用场景 快速导出和加载完整模型 迁移学习、模型微调
局限性 模型结构依赖原环境 加载时需预先定义模型结构

载入模型

  • 方式 1

    Python复制

    model = torch.load("vgg16_method1.pth")
    • 直接加载整个模型对象。

  • 方式 2

    Python复制

    vgg16 = torchvision.models.vgg16(pretrained=False)
    vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
    • 需要先定义模型结构,再加载权重。

实际应用场景

  • 方式 1:适用于快速保存和分享完整的模型,不涉及模型结构调整。

  • 方式 2:适用于迁移学习,可以在保留部分预训练权重的同时调整模型结构。

总结

  • 方式 1 是直接保存模型的完整状态。

  • 方式 2 仅保存模型的权重,便于灵活应用。

  • 根据实际需求(是否需要调整模型结构)选择合适的保存方式。

torch.load以及torch.save的讲解

torch.loadtorch.save 是 PyTorch 中用于保存和加载模型、张量和其他对象的重要函数。它们在深度学习项目中经常用于保存训练好的模型权重、模型结构、训练状态等。以下是它们的详细原理和使用方法:

一、torch.save 函数原理

torch.save 用于将 Python 对象(如模型、张量、字典等)保存到文件中。它利用 Python 的序列化机制将对象的状态信息转化为可存储的字节流格式,并写入到文件中。

1. 功能与原理
  • 功能:保存 PyTorch 对象(如模型、张量、字典等)到文件中。

  • 原理

    • 使用 PyTorch 的自定义序列化机制,将对象的状态信息(如张量的数据、模型的权重、优化器的状态等)序列化为字节流。

    • 支持多种存储后端,包括 PyTorch 自己的二进制格式、pickle 序列化、zip 文件等。

    • 保存的对象可以是张量、模型的整个对象、或者模型的权重(state_dict)。

2. 常用的保存方式
  • 保存整个模型

    Python复制

    torch.save(model, path)
    • 保存模型的结构和权重。

    • 优点:保存的内容完整,可以直接加载整个模型。

    • 缺点:模型结构依赖于保存时的模型定义,如果模型定义发生变化(如代码改动),加载时可能出错。

  • 保存模型权重(state_dict

    Python复制

    torch.save(model.state_dict(), path)
    • 仅保存模型的权重(参数)。

    • 这样可以在加载时更灵活地定义模型结构(例如,用于迁移学习或模型迁移)。

    • 优点:模型结构和权重分离,更适合迁移学习和模型微调。

    • 缺点:需要手动加载到模型结构中。

3. 示例

以下是一个保存模型的示例:

Python复制

import torch
import torchvision

# 创建一个未预训练的 VGG16 模型
vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存整个模型
torch.save(vgg16, 'vgg16_model.pth')

# 保存模型权重
torch.save(vgg16.state_dict(), 'vgg16_weights.pth')

二、torch.load 函数原理

torch.load 用于从文件中加载之前保存的 PyTorch 对象。它通过解析文件中的字节流,恢复对象的状态信息。

1. 功能与原理
  • 功能:从文件中加载 PyTorch 对象。

  • 原理

    • 使用与 torch.save 相同的序列化机制,解析文件中的字节流,将其反序列化为 Python 对象。

    • 自动检测文件的存储后端(如 pickle、zip 文件等)并选择合适的解析方式。

2. 常用的加载方式
  • 加载整个模型

    Python复制

    model = torch.load(path)
    • 直接加载整个模型对象,包括模型结构和权重。

    • 注意:需要确保在当前环境中定义了相同的模型类(如 torchvision.models.vgg16),否则会报错。

  • 加载模型权重(state_dict

    Python复制

    model = torchvision.models.vgg16(pretrained=False)
    model.load_state_dict(torch.load(path))
    • 需要先定义模型结构,然后加载权重到模型中。

    • 这样可以更灵活地加载权重到不同的模型结构中(例如,迁移学习)。

3. 示例

以下是一个加载模型的示例:

Python复制

import torch
import torchvision

# 加载整个模型
model = torch.load('vgg16_model.pth')

# 加载模型权重
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('vgg16_weights.pth'))

三、注意事项

  1. 文件路径和权限

    • 确保保存和加载文件时路径正确,且有权限读写文件。

  2. 兼容性

    • 如果使用 torch.save 保存整个模型,加载时需要保证 PyTorch 版本和模型定义兼容。

    • 使用 state_dict 保存和加载可以更好地兼容不同版本的 PyTorch 和模型定义。

  3. 使用 GPU 和 CPU 的转换

    • 如果模型在 GPU 上保存,但在 CPU 上加载(反之亦然),可以在加载时指定设备:

      Python复制

      # 将模型加载到 CPU
      model = torch.load(path, map_location=torch.device('cpu'))
      
      # 将模型加载到 GPU
      model = torch.load(path, map_location=torch.device('cuda:0'))
  4. 分布式训练的保存和加载

    • 在分布式训练中,state_dict 会包含 module. 前缀。加载时需要移除或保留前缀,以匹配模型结构。

四、总结

  • torch.savetorch.load 是 PyTorch 中用于保存和加载模型的核心函数。

  • 保存和加载的方式取决于应用场景(如迁移学习、微调、快速加载等)。

  • 推荐使用 state_dict 方式保存和加载模型权重,以获得更好的灵活性和兼容性。

你可能感兴趣的:(pytorch入门,pytorch,迁移学习,人工智能)