PyTorch from Scratch: 04 - Building a Convolutional Neural Network with torch (Detailed Summary)

Building a Convolutional Neural Network

  • The input layer of a convolutional neural network differs from that of a traditional neural network and needs to be redesigned; the training module is basically the same

1. Read the data

1. dataset --> DataLoader to build batched data
    - Since we are building batches, a batch_size parameter is needed

2. Build the model

 - 1. It is a class
 - 2. Define the __init__() function
     - 1. Inherit from the parent class
     - 2. Define the layers to be used, in the form self.XXX = XXX
 - 3. Call the layers defined in __init__() inside forward(self, x)
     - Here x is the input data

3. Train with the data and the defined model

- 0. Training requires instantiating the defined model
- 1. Get the loaded data
- 2. Feed the data into the defined model to get the output
- 3. Compare the output with the ground truth to get the loss and accuracy
- 4. Backpropagate, and let the optimizer update the gradients
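Putting the steps above together, a minimal sketch of the pattern (with a hypothetical `MyModel` class and toy dimensions, not the CNN built later) looks like this:

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()  # 1. inherit from the parent class
        self.fc = nn.Linear(4, 2)        # 2. define the layers as self.XXX = XXX

    def forward(self, x):                # 3. call the layers defined in __init__
        return self.fc(x)                #    x is the input data

model = MyModel()                        # instantiate the model before training
out = model(torch.randn(8, 4))          # batch of 8 inputs with 4 features each
print(out.shape)                         # torch.Size([8, 2])
```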
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

1. Read the data

  • Build the training set and the test set (validation set) separately
  • Use DataLoader to iterate over the data
# Define hyperparameters
input_size = 28 # image size: 28 * 28
num_classes = 10 # number of label classes
num_epochs = 3 # number of training epochs
batch_size = 64 # size of each batch, i.e. 64 images are processed together per training step

# Load the training set
train_dataset = datasets.MNIST(
                                root='./data',
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True
                                )
# Load the test set
test_dataset = datasets.MNIST(root='./data',
                             train=False,
                             transform=transforms.ToTensor())
# Build batched data
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                         batch_size=batch_size,
                                         )

2. Build the convolutional neural network model

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        
        self.conv1 = nn.Sequential(    # input size (1,28,28)
            nn.Conv2d(
                in_channels=1,         # 1 channel, i.e. a grayscale image
                out_channels=16,       # number of feature maps to produce
                kernel_size=5,         # size of the convolution kernel
                stride=1,              # stride
                padding=2),            # amount of edge padding
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2) # pooling (2 * 2); output: (16,14,14)
            
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16,32,5,1,2),
            nn.ReLU(),
            nn.MaxPool2d(2))       # output (32, 7, 7)
        self.out = nn.Linear(32*7*7, 10) # fully connected layer for classification
    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1) # flatten; result is (batch_size, 32*7*7)
        output = self.out(x)
        return output
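As a sanity check on the feature-map sizes noted in the comments above, the standard convolution output-size formula, out = (in - kernel + 2*padding) // stride + 1, can be applied step by step (the `conv_out` helper below is just for illustration):

```python
# Verify the (16,14,14) and (32,7,7) sizes claimed in the comments above.
def conv_out(size, kernel, stride=1, padding=0):
    # standard conv/pool output-size formula
    return (size - kernel + 2 * padding) // stride + 1

h = conv_out(28, kernel=5, stride=1, padding=2)  # conv1: 28 -> 28
h = h // 2                                       # 2x2 max pool: 28 -> 14
h = conv_out(h, kernel=5, stride=1, padding=2)   # conv2: 14 -> 14
h = h // 2                                       # 2x2 max pool: 14 -> 7
print(h)  # 7 -- matches the 32*7*7 input size of the final Linear layer
```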

3. Define accuracy as the evaluation metric

def accuracy(predictions,labels):
    pred = torch.max(predictions.data,1)[1]
    rights = pred.eq(labels.data.view_as(pred)).sum()
    return rights,len(labels)
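A quick toy check of how accuracy() behaves (the function body is repeated here so the snippet runs on its own; the logits and labels are made up):

```python
import torch

def accuracy(predictions, labels):
    # index of the max logit in each row = the predicted class
    pred = torch.max(predictions.data, 1)[1]
    rights = pred.eq(labels.data.view_as(pred)).sum()
    return rights, len(labels)

logits = torch.tensor([[0.1, 0.9],   # predicts class 1
                       [0.8, 0.2],   # predicts class 0
                       [0.3, 0.7]])  # predicts class 1
labels = torch.tensor([1, 0, 0])     # so the last prediction is wrong
rights, total = accuracy(logits, labels)
print(rights.item(), total)  # 2 3
```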

4. Train the network

# Instantiate the model
net = CNN()
# Choose the loss function
criterion = nn.CrossEntropyLoss()
# Choose the optimizer
optimizer = optim.Adam(net.parameters(),lr=0.001) # define the optimizer, using the Adam algorithm

# Start training
for epoch in range(num_epochs):
    train_right = [] # store the results for the current epoch, same idea as keeping a list of losses
    
    for batch_idx,(data,target) in enumerate(train_loader):
        net.train()
        output = net(data)
        loss = criterion(output,target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step() # step() is called on the optimizer, not on the loss
        right = accuracy(output,target)
        train_right.append(right)
        
        if batch_idx % 100 == 0:
            net.eval()
            val_right = []
            for (data,target) in test_loader:
                output = net(data)
                right = accuracy(output,target)
                val_right.append(right)
            
            # compute the accuracies
            train_rate = (sum([tup[0] for tup in train_right]),sum(tup[1] for tup in train_right))
            val_rate = (sum([tup[0] for tup in val_right]),sum(tup[1] for tup in val_right))
            
            print('Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}\t Train acc: {:.2f}%\t Test acc: {:.2f}%'.format(
                epoch, batch_idx*batch_size, len(train_loader.dataset),
                100.0 * batch_idx / len(train_loader),
                loss.item(),
                100.0 * train_rate[0].numpy() / train_rate[1],
                100.0 * val_rate[0].numpy() / val_rate[1]
                ))
Epoch: 0 [0/60000 (0%)]	 Loss: 2.301737	 Train acc: 7.81%	 Test acc: 13.38%
Epoch: 0 [6400/60000 (11%)]	 Loss: 0.340523	 Train acc: 76.83%	 Test acc: 92.94%
Epoch: 0 [12800/60000 (21%)]	 Loss: 0.183014	 Train acc: 85.18%	 Test acc: 94.99%
Epoch: 0 [19200/60000 (32%)]	 Loss: 0.292585	 Train acc: 88.67%	 Test acc: 95.91%
Epoch: 0 [25600/60000 (43%)]	 Loss: 0.125986	 Train acc: 90.60%	 Test acc: 97.23%
Epoch: 0 [32000/60000 (53%)]	 Loss: 0.148547	 Train acc: 91.87%	 Test acc: 97.43%
Epoch: 0 [38400/60000 (64%)]	 Loss: 0.049122	 Train acc: 92.73%	 Test acc: 98.09%
Epoch: 0 [44800/60000 (75%)]	 Loss: 0.111927	 Train acc: 93.38%	 Test acc: 97.85%
Epoch: 0 [51200/60000 (85%)]	 Loss: 0.015015	 Train acc: 93.89%	 Test acc: 98.07%
Epoch: 0 [57600/60000 (96%)]	 Loss: 0.009975	 Train acc: 94.33%	 Test acc: 98.03%
Epoch: 1 [0/60000 (0%)]	 Loss: 0.060913	 Train acc: 96.88%	 Test acc: 98.29%
Epoch: 1 [6400/60000 (11%)]	 Loss: 0.117949	 Train acc: 98.08%	 Test acc: 98.25%
Epoch: 1 [12800/60000 (21%)]	 Loss: 0.037740	 Train acc: 98.02%	 Test acc: 98.35%
Epoch: 1 [19200/60000 (32%)]	 Loss: 0.316654	 Train acc: 98.12%	 Test acc: 98.46%
Epoch: 1 [25600/60000 (43%)]	 Loss: 0.032529	 Train acc: 98.23%	 Test acc: 98.41%
Epoch: 1 [32000/60000 (53%)]	 Loss: 0.111036	 Train acc: 98.23%	 Test acc: 98.53%
Epoch: 1 [38400/60000 (64%)]	 Loss: 0.117095	 Train acc: 98.25%	 Test acc: 98.54%
Epoch: 1 [44800/60000 (75%)]	 Loss: 0.022684	 Train acc: 98.23%	 Test acc: 98.79%
Epoch: 1 [51200/60000 (85%)]	 Loss: 0.004846	 Train acc: 98.25%	 Test acc: 98.48%
Epoch: 1 [57600/60000 (96%)]	 Loss: 0.054813	 Train acc: 98.28%	 Test acc: 98.60%
Epoch: 2 [0/60000 (0%)]	 Loss: 0.014723	 Train acc: 100.00%	 Test acc: 98.63%
Epoch: 2 [6400/60000 (11%)]	 Loss: 0.067698	 Train acc: 98.87%	 Test acc: 98.83%
Epoch: 2 [12800/60000 (21%)]	 Loss: 0.043804	 Train acc: 98.73%	 Test acc: 98.90%
Epoch: 2 [19200/60000 (32%)]	 Loss: 0.033841	 Train acc: 98.72%	 Test acc: 98.52%
Epoch: 2 [25600/60000 (43%)]	 Loss: 0.026768	 Train acc: 98.70%	 Test acc: 98.73%
Epoch: 2 [32000/60000 (53%)]	 Loss: 0.049408	 Train acc: 98.66%	 Test acc: 98.86%
Epoch: 2 [38400/60000 (64%)]	 Loss: 0.009131	 Train acc: 98.69%	 Test acc: 98.86%
Epoch: 2 [44800/60000 (75%)]	 Loss: 0.035780	 Train acc: 98.68%	 Test acc: 98.76%
Epoch: 2 [51200/60000 (85%)]	 Loss: 0.010556	 Train acc: 98.70%	 Test acc: 98.92%
Epoch: 2 [57600/60000 (96%)]	 Loss: 0.044480	 Train acc: 98.72%	 Test acc: 98.89%

5. Summary

1. Tensor.size()

2. Tensor.view()

3. Displaying DataLoader-type data

4. Displaying dataset data

  1. Define a Tensor to demonstrate the size() and view() methods
x_test = torch.zeros(5,3,4)
x_test.shape

torch.Size([5, 3, 4])
x_test.size(0)
5
# x_test = x_test.view(x_test.size(0), -1)  # (5,12) -- you can tell which dimension is kept as dim 0
x_test = x_test.view(x_test.size(1), -1) # (3,20)
x_test.shape
torch.Size([3, 20])
x_test.data # .data returns just the underlying data
tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]])
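One more detail worth knowing about view(): passing -1 lets PyTorch infer the remaining dimension, as long as the total number of elements stays the same. A small standalone illustration:

```python
import torch

x = torch.zeros(5, 3, 4)
flat = x.view(-1)                    # fully flatten: 5*3*4 = 60 elements
batch_flat = x.view(x.size(0), -1)   # keep dim 0 as the batch dimension
print(flat.shape, batch_flat.shape)  # torch.Size([60]) torch.Size([5, 12])
```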

2. How to display DataLoader-type data

  • Print the data to see what it looks like; data is batch_size three-dimensional arrays, each representing an image
  • target holds the label values, i.e. what each image depicts
train_loader # this alone does not show the data

for batch_idx, (data, target) in enumerate(train_loader): 
    print(data,target)
    print(target.shape,data.shape)
    break
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],

        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],

        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],

        ...,

        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],

        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],

        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]]) tensor([7, 6, 5, 6, 5, 9, 7, 3, 4, 8, 0, 0, 5, 9, 3, 9, 1, 0, 0, 6, 2, 0, 0, 1,
        2, 5, 1, 4, 2, 6, 7, 4, 0, 7, 8, 6, 1, 4, 0, 5, 3, 6, 9, 5, 3, 8, 7, 2,
        6, 0, 2, 7, 4, 8, 7, 3, 4, 4, 1, 4, 2, 8, 0, 2])
torch.Size([64]) torch.Size([64, 1, 28, 28])

  • As you can see, enumerate(dataset) lets you iterate over and display the data in a dataset
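The same enumerate() pattern works on any DataLoader. A standalone sketch using random tensors in place of MNIST (wrapped in a TensorDataset, so no download is needed):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# 100 fake "images" with the same shape as MNIST samples, plus random labels
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)

for batch_idx, (data, target) in enumerate(loader):
    print(data.shape, target.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])
    break
```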

# See what the dataset itself looks like
test_dataset
Dataset MNIST
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: ToTensor()
for batch_idx, (data, target) in enumerate(train_dataset): 
    print(data,target)
#     print(target.shape,data.shape)
    break
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0706, 0.0706, 0.0706,
          0.4941, 0.5333, 0.6863, 0.1020, 0.6510, 1.0000, 0.9686, 0.4980,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.1176, 0.1412, 0.3686, 0.6039, 0.6667, 0.9922, 0.9922, 0.9922,
          0.9922, 0.9922, 0.8824, 0.6745, 0.9922, 0.9490, 0.7647, 0.2510,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1922,
          0.9333, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922,
          0.9922, 0.9843, 0.3647, 0.3216, 0.3216, 0.2196, 0.1529, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706,
          0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765, 0.7137,
          0.9686, 0.9451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.3137, 0.6118, 0.4196, 0.9922, 0.9922, 0.8039, 0.0431, 0.0000,
          0.1686, 0.6039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0549, 0.0039, 0.6039, 0.9922, 0.3529, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.5451, 0.9922, 0.7451, 0.0078, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0431, 0.7451, 0.9922, 0.2745, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.9451, 0.8824, 0.6275,
          0.4235, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3176, 0.9412, 0.9922,
          0.9922, 0.4667, 0.0980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294,
          0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,
          0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098,
          0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922,
          0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922,
          0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765,
          0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.6706,
          0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.7647, 0.3137, 0.0353,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.2157, 0.6745, 0.8863, 0.9922,
          0.9922, 0.9922, 0.9922, 0.9569, 0.5216, 0.0431, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.5333, 0.9922, 0.9922, 0.9922,
          0.8314, 0.5294, 0.5176, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000]]]) 5
