如何导入MNIST数据集

问题

当我使用github上别人的代码:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

结果会报错,原因就是read_data_sets这个函数会检查“/tmp/data/”这个目录下是否有其所需的四个文件,显然我们第一次运行的时候没有,它就会去它所写的地址去下载,但是国内可能因为某些问题就会出现网络错误。

解决办法

解决办法很简单,我们可以去lecnn大佬的主页将这四个文件下载下来,如闻如图所示:
如何导入MNIST数据集_第1张图片
我们将去下完之后放在某个文件下(不用解压,直接将四个压缩包放在那个文件下),之后将read_data_sets函数的路径改成刚才放文件的路径就可以了。
如何导入MNIST数据集_第2张图片
然后read_data_sets函数发现这个文件夹有这四个文件就不会去下载,结果就正常了:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./MNIST/",one_hot=True)
print (mnist.train.images.shape)
#(55000, 784)

你可能感兴趣的:(Python,Deep,Learning)