GAN实现MNIST

GAN是生成对抗网络,具体的原理这里就不详解了,该博文简要实现了GAN的搭建,在MNIST数据集上使用,亲测100%可用。

#导入相应的代码库
import tensorflow as tf
import numpy as np
import os
from matplotlib import pyplot as plt
#该部分是提取数据,假设MNIST数据已经在电脑里下载,下载位置为MNIST_data,trainX,testX均为数组
path="MNIST_data/"
f=open(os.path.join(path,"train-images.idx3-ubyte"))
loaded=np.fromfile(f,dtype=np.uint8)
trainX=loaded[16:].reshape((60000,-1)).astype(np.float)

f=open(os.path.join(path,"train-labels.idx1-ubyte"))
loaded=np.fromfile(file=f,dtype=np.uint8)
trainY=loaded[8:].reshape((60000,1)).astype(np.float)


f=open(os.path.join(path,"t10k-images.idx3-ubyte"))
loaded=np.fromfile(f,dtype=np.uint8)
testX=loaded[16:].reshape((10000,-1)).astype(np.float)

f=open(os.path.join(path,"t10k-labels.idx1-ubyte"))
loaded=np.fromfile(file=f,dtype=np.uint8)
testY=loaded[8:].reshape((10000,1)).astype(np.float)
#将训练集的和测试集的数据进行合并,以此来进行训练
X=np.concatenate((trainX,testX),axis=0)/255.0
Y=np.concatenate((trainY,testY),axis=0)
#显示图片,该篇代码没有涉及到模型保存部分,模型在每次迭代时会输出生成的图片
def show_image(batch_images,grid_size=(8,8)):
    batch_images=batch_images.reshape((BATCH_SIZE,28,28))
    merge_image=np.zeros((28*grid_size[0],28*grid_size[1]),dtype=np.uint8)
    for index,image in enumerate(batch_images):
        i=index//grid

你可能感兴趣的:(GAN实现MNIST)