目录
一、前期准备
1.导入数据
2.数据可视化
二、构建简单的CNN网络
三、训练模型
1.设置超参数
2.编写训练函数
3.编写测试函数
4.正式训练
编辑 四、结果可视化
编辑
五、总结
1.Pytorch自带数据库:torchvision.dataset
2.Pytorch自带数据加载器:torch.utils.data.DataLoader
3.Pytorch中进行参数更新时所需流程:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
train_ds=torchvision.datasets.MNIST('data',train=True,
download=True,
transform=torchvision.transform.ToTensor())
train_ds=torchvision.datasets.MNIST('data',train=False,
download=True,
transform=torchvision.transform.ToTensor())
batch_size=32
train_dl=torch.utils.data.DataLoader(train_ds,batch_size=batch_size,shuffle=True)
test_dl=toch.utils.data.DataLoader(test_ds,batch_size=batch_size)
img,labels=next(iter(train_dl))#创建一个迭代器,用于遍历训练数据加载器中的数据,并从迭代器中取出第一个批次的数据
imgs.shape
import numpy as np
# 指定图片大小,图像大小为20宽、5高的绘图(单位为英寸inch)
plt.figure(figsize=(20, 5))
for i,imgs in enumerate(imgs[:20])
# 维度缩减
npimg = np.squeeze(imgs.numpy()) #squeeze()函数的功能是从矩阵shape中,去掉维度为1的。
# 将整个figure分成2行10列,绘制第i+1个子图。
plt.subplot(2, 10, i+1)
plt.imshow(npimg, cmap=plt.cm.binary)
plt.axis('off')
import torch.nn.functional as F
num_classes = 10 # 图片的类别数
class Model(nn.Module):
def __init__(self):
super().__init__()
# 特征提取网络
self.conv1 = nn.Conv2d(1, 32, kernel_size=3) # 第一层卷积,卷积核大小为3*3
self.pool1 = nn.MaxPool2d(2) # 设置池化层,池化核大小为2*2
self.conv2 = nn.Conv2d(3