1.dataset --> DataLoader构建batch数据
- 既然是构建batch数据,故需要batch_size参数
- 1.是个类
- 2.定义__init__()函数
- 1.继承父类
- 2.定义用到的层 self.XXX = XXX 格式
- 3.forward(self,x)函数中调用__init__()函数中定义好的层
- 这个x指的是输入的数据
- 0.训练需要实例化定义好的模型
- 1.获取到加载好的数据
- 2.定义好的模型中传入数据,获得结果
- 3.用得到的结果和真实值比较,可以得到损失值和准确率
- 4.反向传播,优化器进行梯度更新
# 导包
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# ---- Hyperparameters ----
input_size = 28    # MNIST images are 28 x 28 pixels
num_classes = 10   # digit labels 0-9
num_epochs = 3     # full passes over the training set
batch_size = 64    # images processed together in one optimization step

# ---- Datasets: MNIST, with pixels converted to tensors in [0, 1] ----
train_dataset = datasets.MNIST(root='./data',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
test_dataset = datasets.MNIST(root='./data',
                              train=False,
                              transform=transforms.ToTensor())

# ---- Batch iterators ----
# Shuffle training batches every epoch; keep the test order deterministic.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size)
class CNN(nn.Module):
    """Two-stage convolutional classifier for 28x28 grayscale digits.

    Shapes: (N,1,28,28) -> conv1 -> (N,16,14,14) -> conv2 -> (N,32,7,7)
    -> flatten -> linear -> (N,10) raw class logits.
    """

    def __init__(self):
        super(CNN, self).__init__()
        # Stage 1: 1 input channel -> 16 feature maps; padding=2 keeps the
        # 28x28 size through the 5x5 conv, then 2x2 max-pool halves it.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16,
                      kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Stage 2: 16 -> 32 feature maps, spatial size halved again to 7x7.
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Fully connected head over the flattened 32*7*7 features.
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        """Return logits of shape (N, 10) for input images x of shape (N, 1, 28, 28)."""
        features = self.conv2(self.conv1(x))
        flat = features.view(features.size(0), -1)  # (N, 32*7*7)
        return self.out(flat)
def accuracy(predictions, labels):
    """Return (correct_count, batch_size) for one batch.

    `predictions` holds per-class scores of shape (N, num_classes); the
    correct count comes back as a tensor so callers can keep accumulating
    and convert to a number at the end.
    """
    guesses = predictions.data.max(dim=1).indices
    correct = guesses.eq(labels.data.view_as(guesses)).sum()
    return correct, len(labels)
# Instantiate the model defined above.
net = CNN()
# Loss: cross-entropy over the 10 raw logits (applies log-softmax internally).
criterion = nn.CrossEntropyLoss()
# Optimizer: Adam with learning rate 1e-3.
# NOTE(review): the original comment claimed SGD, but this is Adam.
optimizer = optim.Adam(net.parameters(),lr=0.001)
# ---- Training loop ----
# Every 100 batches: evaluate on the whole test set and print running stats.
for epoch in range(num_epochs):
    train_right = []  # one (correct, total) tuple per training batch this epoch
    for batch_idx, (data, target) in enumerate(train_loader):
        net.train()
        output = net(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # the optimizer (not the loss) applies the update
        train_right.append(accuracy(output, target))

        if batch_idx % 100 == 0:
            # Periodic evaluation over the full test set.
            net.eval()
            val_right = []
            # no_grad: evaluation needs no autograd graph — saves memory/compute
            # (the original tracked gradients over the whole test set).
            with torch.no_grad():
                # Distinct names so the training batch's data/target aren't clobbered.
                for (val_data, val_target) in test_loader:
                    val_output = net(val_data)
                    val_right.append(accuracy(val_output, val_target))

            # Aggregate the (correct, total) tuples into accuracy ratios.
            train_rate = (sum(tup[0] for tup in train_right),
                          sum(tup[1] for tup in train_right))
            val_rate = (sum(tup[0] for tup in val_right),
                        sum(tup[1] for tup in val_right))
            print('当前epoch:{}[{}/{}({:.0f}%)]\t 损失:{:.6f}\t 训练集准确率:{:.2f}%\t 测试集准确率:{:.2f}%'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                100.0 * batch_idx / len(train_loader),
                loss.item(),  # .item() replaces the deprecated .data access
                100.0 * train_rate[0].numpy() / train_rate[1],
                100.0 * val_rate[0].numpy() / val_rate[1]
            ))
当前epoch:0[0/60000(0%)] 损失:2.301737 训练集准确率:7.81% 测试集准确率:13.38%
当前epoch:0[6400/60000(11%)] 损失:0.340523 训练集准确率:76.83% 测试集准确率:92.94%
当前epoch:0[12800/60000(21%)] 损失:0.183014 训练集准确率:85.18% 测试集准确率:94.99%
当前epoch:0[19200/60000(32%)] 损失:0.292585 训练集准确率:88.67% 测试集准确率:95.91%
当前epoch:0[25600/60000(43%)] 损失:0.125986 训练集准确率:90.60% 测试集准确率:97.23%
当前epoch:0[32000/60000(53%)] 损失:0.148547 训练集准确率:91.87% 测试集准确率:97.43%
当前epoch:0[38400/60000(64%)] 损失:0.049122 训练集准确率:92.73% 测试集准确率:98.09%
当前epoch:0[44800/60000(75%)] 损失:0.111927 训练集准确率:93.38% 测试集准确率:97.85%
当前epoch:0[51200/60000(85%)] 损失:0.015015 训练集准确率:93.89% 测试集准确率:98.07%
当前epoch:0[57600/60000(96%)] 损失:0.009975 训练集准确率:94.33% 测试集准确率:98.03%
当前epoch:1[0/60000(0%)] 损失:0.060913 训练集准确率:96.88% 测试集准确率:98.29%
当前epoch:1[6400/60000(11%)] 损失:0.117949 训练集准确率:98.08% 测试集准确率:98.25%
当前epoch:1[12800/60000(21%)] 损失:0.037740 训练集准确率:98.02% 测试集准确率:98.35%
当前epoch:1[19200/60000(32%)] 损失:0.316654 训练集准确率:98.12% 测试集准确率:98.46%
当前epoch:1[25600/60000(43%)] 损失:0.032529 训练集准确率:98.23% 测试集准确率:98.41%
当前epoch:1[32000/60000(53%)] 损失:0.111036 训练集准确率:98.23% 测试集准确率:98.53%
当前epoch:1[38400/60000(64%)] 损失:0.117095 训练集准确率:98.25% 测试集准确率:98.54%
当前epoch:1[44800/60000(75%)] 损失:0.022684 训练集准确率:98.23% 测试集准确率:98.79%
当前epoch:1[51200/60000(85%)] 损失:0.004846 训练集准确率:98.25% 测试集准确率:98.48%
当前epoch:1[57600/60000(96%)] 损失:0.054813 训练集准确率:98.28% 测试集准确率:98.60%
当前epoch:2[0/60000(0%)] 损失:0.014723 训练集准确率:100.00% 测试集准确率:98.63%
当前epoch:2[6400/60000(11%)] 损失:0.067698 训练集准确率:98.87% 测试集准确率:98.83%
当前epoch:2[12800/60000(21%)] 损失:0.043804 训练集准确率:98.73% 测试集准确率:98.90%
当前epoch:2[19200/60000(32%)] 损失:0.033841 训练集准确率:98.72% 测试集准确率:98.52%
当前epoch:2[25600/60000(43%)] 损失:0.026768 训练集准确率:98.70% 测试集准确率:98.73%
当前epoch:2[32000/60000(53%)] 损失:0.049408 训练集准确率:98.66% 测试集准确率:98.86%
当前epoch:2[38400/60000(64%)] 损失:0.009131 训练集准确率:98.69% 测试集准确率:98.86%
当前epoch:2[44800/60000(75%)] 损失:0.035780 训练集准确率:98.68% 测试集准确率:98.76%
当前epoch:2[51200/60000(85%)] 损失:0.010556 训练集准确率:98.70% 测试集准确率:98.92%
当前epoch:2[57600/60000(96%)] 损失:0.044480 训练集准确率:98.72% 测试集准确率:98.89%
x_test = torch.zeros(5,3,4)
x_test.shape
torch.Size([5, 3, 4])
x_test.size(0)
5
# x_test = x_test.view(x_test.size(0), -1) # (5,12) -- 可以推断出是保留哪个维度作为第0维
x_test = x_test.view(x_test.size(1), -1) # (3,20)
x_test.shape
torch.Size([3, 20])
x_test.data # torch.data只获取数据
tensor([[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]])
2.DataLoader类型的数据展示的方法
train_loader # 这样不能看到数据
for batch_idx, (data, target) in enumerate(train_loader):
print(data,target)
print(target.shape,data.shape)
break
tensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
...,
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]]) tensor([7, 6, 5, 6, 5, 9, 7, 3, 4, 8, 0, 0, 5, 9, 3, 9, 1, 0, 0, 6, 2, 0, 0, 1,
2, 5, 1, 4, 2, 6, 7, 4, 0, 7, 8, 6, 1, 4, 0, 5, 3, 6, 9, 5, 3, 8, 7, 2,
6, 0, 2, 7, 4, 8, 7, 3, 4, 4, 1, 4, 2, 8, 0, 2])
torch.Size([64]) torch.Size([64, 1, 28, 28])
# 看下dataset是什么
test_dataset
Dataset MNIST
Number of datapoints: 10000
Root location: ./data
Split: Test
StandardTransform
Transform: ToTensor()
for batch_idx, (data, target) in enumerate(train_dataset):
print(data,target)
# print(target.shape,data.shape)
break
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0706, 0.0706, 0.0706,
0.4941, 0.5333, 0.6863, 0.1020, 0.6510, 1.0000, 0.9686, 0.4980,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.1176, 0.1412, 0.3686, 0.6039, 0.6667, 0.9922, 0.9922, 0.9922,
0.9922, 0.9922, 0.8824, 0.6745, 0.9922, 0.9490, 0.7647, 0.2510,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1922,
0.9333, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922,
0.9922, 0.9843, 0.3647, 0.3216, 0.3216, 0.2196, 0.1529, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706,
0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765, 0.7137,
0.9686, 0.9451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.3137, 0.6118, 0.4196, 0.9922, 0.9922, 0.8039, 0.0431, 0.0000,
0.1686, 0.6039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0549, 0.0039, 0.6039, 0.9922, 0.3529, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.5451, 0.9922, 0.7451, 0.0078, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0431, 0.7451, 0.9922, 0.2745, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.9451, 0.8824, 0.6275,
0.4235, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3176, 0.9412, 0.9922,
0.9922, 0.4667, 0.0980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294,
0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,
0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098,
0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922,
0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922,
0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765,
0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.6706,
0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.7647, 0.3137, 0.0353,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.2157, 0.6745, 0.8863, 0.9922,
0.9922, 0.9922, 0.9922, 0.9569, 0.5216, 0.0431, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.5333, 0.9922, 0.9922, 0.9922,
0.8314, 0.5294, 0.5176, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000]]]) 5