基于Pytorch框架构建VGG-19模型

Pytorch

  • 一、训练模型
    • 1.导入资源包
    • 2.定义数据预处理
    • 3.读取数据
  • 二、定义VGG19模型
    • 1.定义自定义的 VGG19 模型
    • 运行结果:
  • 四、验证模型
    • 1. 定义验证过程
    • 2.用于训练模型并应用学习率调整策略的循环
    • 运行结果:
    • 3.保存模型的状态字典
  • 三、训练模型
    • 1. 定义训练函数
  • 五、创建 CustomVGG19 模型实例
    • 1. 导入资源包
    • 2.定义数据预处理
    • 4.创建 CustomVGG19 模型实例
    • 5.定义预测函数
    • 6.定义了一个可视化函数
    • 运行结果:

一、训练模型

1.导入资源包

import torch.utils.data: 导入了PyTorch的数据工具模块,这个模块提供了用于数据加载和处理的工具,如Dataset和DataLoader。
from torchvision import models: 导入了PyTorch的预训练模型模块,这个模块提供了多种预训练的型,如ResNet、VGG、AlexNet等,可以直接用于迁移学习或特征提取。

from sched import scheduler
import torch.optim as optim
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import  os
from torchvision import models

2.定义数据预处理

这些预处理操作的目的是为了增强模型的泛化能力,并确保模型在训练和验证时输入数据的格式一致。通过这些操作,模型能够接受不同尺寸、角度和方向的图像,从而提高其在实际应用中的表现。同时,归一化处理有助于稳定训练过程,加速模型收敛。,这些预处理操作的目的是为了增强模型的泛化能力,并确保模型在训练和验证时输入数据的格式一致。通过这些操作,模型能够接受不同尺寸、角度和方向的图像,从而提高其在实际应用中的表现。同时,归一化处理有助于稳定训练过程,加速模型收敛。

# 定义数据预处理
transform = {
   
'train': transforms.Compose([
transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
transforms.RandomRotation(degrees=15),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(size=256),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
}

3.读取数据

读取和准备图像数据集,以便用于训练和验证深度学习模型,这段代码设置了数据加载器,它们将在训练和验证过程中提供经过预处理的图像数据。这些数据加载器是PyTorch中用于批量加载数据并使其易于迭代的重要工具。

# 读取数据
dataset = './dataset'
train_directory = os.path.join(dataset, 'train')
valid_directory = os.path.join(dataset, 'val')

batch_size = 32
num_classes = 2  # 修改为您的分类数

data = {
   
'train': datasets.ImageFolder(root=train_directory, transform=transform['train']),
'val': datasets.ImageFolder(root=valid_directory, transform=transform['val'])
}

train_loader = DataLoader(data['train'], batch_size=batch_size, shuffle=True

你可能感兴趣的:(pytorch,人工智能,python,深度学习,机器学习,神经网络,cnn)