# 生成对抗网络做生成手写数字 (Generating handwritten digits with a generative adversarial network)

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

# Load MNIST. Only the training images are used by the GAN below; the labels
# and the test split are loaded but not used.
(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.mnist.load_data()
# Add a channel axis -> (N, 28, 28, 1) and scale pixels from [0, 255] to [-1, 1]
# so real images live in the same range as the generator's tanh output.
train_images=train_images.reshape(train_images.shape[0],28,28,1).astype("float32")
train_images=(train_images-127.5)/127.5
noise_size=100  # length of each noise vector fed to the generator
batch_size=32
datasets=tf.data.Dataset.from_tensor_slices(train_images)
# No .repeat(): train_op() iterates this dataset once per epoch, and a repeated
# (infinite) dataset would make that loop never finish, so training would never
# get past the first epoch.
datasets=datasets.shuffle(train_images.shape[0]).batch(batch_size)
generator_optimizer=tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer=tf.keras.optimizers.Adam(1e-4)
epochs=50
num_examples_to_show=8
# Fixed noise reused after every epoch so the sample plots are comparable over time.
show_data=tf.random.normal([num_examples_to_show,noise_size])

def generator_model():
	"""Build the generator network.

	Input: noise of shape (batch_size, noise_size).
	Output: images of shape (batch_size, 28, 28, 1) with values in [-1, 1].

	The final layer ends with tanh and is NOT followed by BatchNormalization:
	normalizing after tanh would move values outside [-1, 1], breaking the match
	with the real training images (scaled to [-1, 1]) and with the plotting code
	that maps outputs back via (x + 1) / 2.
	"""
	model=keras.Sequential()
	# Hidden layer: bias is omitted because BatchNormalization supplies its own shift.
	model.add(layers.Dense(128,input_shape=(noise_size,),use_bias=False))
	model.add(layers.BatchNormalization())
	model.add(layers.LeakyReLU())

	# Output layer: tanh bounds every pixel to [-1, 1]. With no normalization
	# after it, the Dense bias is meaningful again, so it is kept.
	model.add(layers.Dense(28*28*1,activation="tanh"))

	# (batch_size, 784) -> (batch_size, 28, 28, 1), the same layout as the real images.
	model.add(layers.Reshape((28,28,1)))
	return model

def discriminator_model():
	"""Build the discriminator network.

	Flatten collapses every axis except the leading batch axis, e.g.
	(batch_size, 28, 28, 1) -> (batch_size, 784). The last Dense(1) has no
	activation, so the model returns one unbounded score per image with shape
	(batch_size, 1); get_loss treats these scores as logits (from_logits=True).
	"""
	stack=[
		layers.Flatten(),
		layers.Dense(64,use_bias=False),
		layers.BatchNormalization(),
		layers.LeakyReLU(),
		layers.Dense(1),
	]
	return keras.Sequential(stack)

def get_loss(fake_output,real_output):
	"""Compute discriminator and generator losses from discriminator logits.

	Args:
		fake_output: discriminator logits for generated images.
		real_output: discriminator logits for real images.

	Returns:
		(D_loss, G_loss):
		- D_loss is the binary cross-entropy with real images labelled 1 plus
		  fake images labelled 0.
		- G_loss is the binary cross-entropy of the fake logits against label 1,
		  i.e. the generator is rewarded for fooling the discriminator.
	"""
	bce=tf.keras.losses.BinaryCrossentropy(from_logits=True)

	def _against(make_target,logits):
		# Cross-entropy of `logits` against a constant target of the same shape.
		return bce(make_target(logits),logits)

	discriminator_loss=_against(tf.ones_like,real_output)+_against(tf.zeros_like,fake_output)
	generator_loss=_against(tf.ones_like,fake_output)
	return discriminator_loss,generator_loss

# Build both networks once; train_op(), generate_plot_image() callers and train()
# use these module-level instances. The generator is constructed first.
generator,discriminator=generator_model(),discriminator_model()

def train_op():
	"""Run one pass over `datasets`, updating both networks once per batch.

	For every batch of real images this draws a matching batch of noise,
	generates fake images, scores real and fake with the discriminator, and
	applies one optimizer step to each network. A "." is printed per step as a
	progress indicator, and the line is terminated after the pass.

	Relies on the module-level `datasets`, `generator`, `discriminator`,
	`generator_optimizer`, `discriminator_optimizer` and `noise_size`.
	"""
	for real_data in datasets:
		# One noise vector per real image, so real and fake batches always have
		# the same size, including a final batch smaller than batch_size.
		current_batch=tf.shape(real_data)[0]
		noise_data=tf.random.normal([current_batch,noise_size])
		print(".",end="")
		# Two tapes record the same forward pass; each one is used for exactly
		# one network's gradients.
		with tf.GradientTape() as gen_tape,tf.GradientTape() as dis_tape:
			fake_data=generator(noise_data,training=True)
			real_output=discriminator(real_data,training=True)
			fake_output=discriminator(fake_data,training=True)
			dis_loss,gen_loss=get_loss(fake_output,real_output)

		gen_gradient=gen_tape.gradient(gen_loss,generator.trainable_variables)
		dis_gradient=dis_tape.gradient(dis_loss,discriminator.trainable_variables)

		generator_optimizer.apply_gradients(zip(gen_gradient,generator.trainable_variables))
		discriminator_optimizer.apply_gradients(zip(dis_gradient,discriminator.trainable_variables))
	# End the line of progress dots (this also flushes line-buffered stdout).
	print()

def generate_plot_image(generator,show_data,n_cols=4):
	"""Show a grid of images produced by `generator` from the noise in `show_data`.

	Args:
		generator: model called as generator(show_data, training=False); its
			output is indexed as (k, H, W, 1) with values expected in [-1, 1].
		show_data: batch of k noise vectors to feed the generator.
		n_cols: images per grid row (default 4, giving the original 2x4 layout
			for 8 samples).

	One subplot is drawn per generated image; the number of rows is derived
	from the number of images, so any batch size fits the grid.
	"""
	# Inference mode (e.g. BatchNormalization layers use their stored statistics).
	predict_images=generator(show_data,training=False)
	n_images=int(predict_images.shape[0])
	n_rows=max(1,(n_images+n_cols-1)//n_cols)
	# figsize is (width, height): one inch per cell keeps each image square.
	plt.figure(figsize=(n_cols,n_rows))
	for i in range(n_images):
		plt.subplot(n_rows,n_cols,i+1)
		# Map [-1, 1] back to [0, 1]; pinning vmin/vmax stops imshow from
		# auto-rescaling each image, so the true output intensities are shown.
		plt.imshow((predict_images[i,:,:,0]+1)/2,cmap="gray",vmin=0,vmax=1)
		plt.axis("off")
	plt.show()

def train():
	"""Train for `epochs` passes over the data, plotting samples after each pass.

	After every epoch the generator is run on the fixed `show_data` noise so
	progress can be compared across epochs.
	"""
	for _ in range(epochs):
		train_op()
		generate_plot_image(generator,show_data)

# Start training only when run as a script (or notebook), not when this
# module is imported for its model builders or loss function.
if __name__=="__main__":
	train()

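# 你可能感兴趣的:(生成对抗网络做生成手写数字) — web-page footer residue: "You may also be interested in: (Generating handwritten digits with a GAN)"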