深度学习笔记12-mnist手写数字识别(Pytorch)

  •   本文为365天深度学习训练营中的学习记录博客
  • 原作者:K同学啊

目录

一、前期准备

1.导入数据

2.数据可视化

二、构建简单的CNN网络

三、训练模型

1.设置超参数

2.编写训练函数

3.编写测试函数

4.正式训练

 ​编辑 四、结果可视化

 ​编辑 

五、总结

1.Pytorch自带数据库:torchvision.dataset

 2.Pytorch自带数据加载器:torch.utils.data.DataLoader

3.Pytorch中进行参数更新时所需流程:


一、前期准备

1.导入数据

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
train_ds=torchvision.datasets.MNIST('data',train=True,
                                    download=True,
                                    transform=torchvision.transform.ToTensor())
train_ds=torchvision.datasets.MNIST('data',train=False,
                                    download=True,
                                    transform=torchvision.transform.ToTensor())
batch_size=32
train_dl=torch.utils.data.DataLoader(train_ds,batch_size=batch_size,shuffle=True)
test_dl=toch.utils.data.DataLoader(test_ds,batch_size=batch_size)
img,labels=next(iter(train_dl))#创建一个迭代器,用于遍历训练数据加载器中的数据,并从迭代器中取出第一个批次的数据
imgs.shape

2.数据可视化

import numpy as np
 # 指定图片大小,图像大小为20宽、5高的绘图(单位为英寸inch)
plt.figure(figsize=(20, 5)) 
for i,imgs in enumerate(imgs[:20])
    # 维度缩减
    npimg = np.squeeze(imgs.numpy()) #squeeze()函数的功能是从矩阵shape中,去掉维度为1的。
    # 将整个figure分成2行10列,绘制第i+1个子图。
    plt.subplot(2, 10, i+1)
    plt.imshow(npimg, cmap=plt.cm.binary)
    plt.axis('off')

二、构建简单的CNN网络

import torch.nn.functional as F
num_classes = 10  # 图片的类别数
class Model(nn.Module):
     def __init__(self):
        super().__init__()
         # 特征提取网络
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)  # 第一层卷积,卷积核大小为3*3
        self.pool1 = nn.MaxPool2d(2)                  # 设置池化层,池化核大小为2*2
        self.conv2 = nn.Conv2d(3

你可能感兴趣的:(深度学习,深度学习,笔记,pytorch)