# -*- coding: utf-8 -*-
"""
Created on Sat Jul 18 12:27:15 2020
@author: 陨星落云
"""
#%%
from torchvision import datasets
import torch
#%% Load the MNIST training set (downloaded into path2data on first run)
path2data = "./data"
# download=True only fetches the files when they are not already present
# under path2data, so this works on a fresh checkout and with a cached copy
# (download=False would fail with "Dataset not found" on a fresh checkout).
train_data = datasets.MNIST(path2data, train=True, download=True)
#%% Extract training images and labels as tensors
# (recorded output below: images (60000, 28, 28), labels (60000,))
x_train, y_train = train_data.data, train_data.targets
print("x_train:", x_train.shape)
print("y_train:", y_train.shape)
#%% Load the validation split (MNIST's train=False split)
val_data = datasets.MNIST(path2data, train=False, download=True)
#%% Extract validation images and labels
# (recorded output below: images (10000, 28, 28), labels (10000,))
x_val, y_val = val_data.data, val_data.targets
print("x_val:", x_val.shape)
print("y_val:", y_val.shape)
#%% Add a channel dimension: (N, H, W) -> (N, 1, H, W)
# Guarded by the rank check so re-running this cell does not keep adding
# dimensions. The shape is printed unconditionally so it is always reported.
if x_train.dim() == 3:
    x_train = x_train.unsqueeze(1)
print(x_train.shape)
if x_val.dim() == 3:
    x_val = x_val.unsqueeze(1)
print(x_val.shape)
#%% 导入需要的包
from torchvision import utils
import matplotlib.pylab as plt
import numpy as np
#%% Image display helper
def show(img):
    """Display a channels-first (C, H, W) image tensor with matplotlib.

    Intended for grids built by ``utils.make_grid``. In this script that
    call yields 3 channels even for grayscale input (see the recorded shape
    (3, 152, 242)). Single-channel (1, H, W) tensors are also accepted and
    drawn with a gray colormap.

    Args:
        img: image tensor laid out as (C, H, W), with C being 1, 3 or 4.
    """
    # detach()/cpu() let this work for tensors that require grad or live on
    # a GPU; for a plain CPU tensor they leave the data unchanged.
    npimg = img.detach().cpu().numpy()
    # matplotlib expects channels-last: (C, H, W) -> (H, W, C)
    npimg_tr = np.transpose(npimg, (1, 2, 0))
    if npimg_tr.shape[2] == 1:
        # imshow rejects an (H, W, 1) array; drop the channel axis instead.
        plt.imshow(npimg_tr[:, :, 0], interpolation="nearest", cmap="gray")
    else:
        plt.imshow(npimg_tr, interpolation="nearest")
    plt.show()
#%% Display a batch of training images as one tiled grid
# Take the first 40 training images and tile them 8 per row with 2 px of
# padding; the recorded output shows the resulting grid shape.
first_batch = x_train[:40]
x_grid = utils.make_grid(first_batch, padding=2, nrow=8)
print(x_grid.shape)
show(x_grid)
# Output:
# x_train: torch.Size([60000, 28, 28])
# y_train: torch.Size([60000])
# x_val: torch.Size([10000, 28, 28])
# y_val: torch.Size([10000])
# torch.Size([60000, 1, 28, 28])
# torch.Size([10000, 1, 28, 28])
# torch.Size([3, 152, 242])