Generative Adversarial Networks (GANs) are a class of deep learning models and one of the most promising approaches of recent years to unsupervised learning on complex distributions. A GAN produces convincing outputs through an adversarial game between (at least) two modules in its framework: a generative model and a discriminative model. (Adapted from Baidu Baike.)
In short, a GAN pits a generator against a discriminator: the discriminator is trained to tell real images from generated ones, while the generator is trained, with real images as the reference data, to produce outputs realistic enough to fool the discriminator. Through repeated rounds of this adversarial training, the generator eventually outputs images close to the real data.
The original GAN theory does not require G and D to be neural networks; any functions that can fit the generation and discrimination mappings will do. In practice, however, deep neural networks are almost always used for both G and D. A good GAN application also needs a sound training procedure; otherwise the freedom of the neural network models can easily lead to poor outputs.
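For reference, the original GAN paper (Goodfellow et al., 2014) formalizes this adversarial game as a minimax problem over the discriminator D and the generator G, where z is noise drawn from a prior p_z and D(x) is the probability that x is a real sample:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]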
First, we set up the overall framework of our GAN inside the constructor. Don't worry if parts of this are unclear for now; the model's training process is explained in detail in the training section below.
class GAN():
    def __init__(self):
        # --------------------------------- #
        #   28 rows, 28 columns: the shape of an MNIST image
        # --------------------------------- #
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        # (28, 28, 1)
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        # Adam optimizer
        optimizer = Adam(0.0002, 0.5)
        # Discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])
        # Generator
        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        # Do not train the discriminator while training the generator
        self.discriminator.trainable = False
        # Let the discriminator score the generated (fake) images
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
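A detail worth noting: self.discriminator.trainable = False is set after the discriminator itself has been compiled, and Keras takes the trainable flag into account at compile time. So the standalone discriminator still learns, while its copy inside combined stays frozen. A minimal sketch to convince yourself of this (my own check, not part of the original tutorial; it assumes the imports from the complete listing at the end of this post):

gan = GAN()
x = np.random.uniform(-1, 1, (4, 28, 28, 1))                        # 4 dummy "images"
z = np.random.normal(0, 1, (4, gan.latent_dim))                     # 4 noise vectors
w_before = gan.discriminator.get_weights()[0].copy()                # first Dense kernel of D
gan.combined.train_on_batch(z, np.ones((4, 1)))                     # updates only the generator
print(np.allclose(w_before, gan.discriminator.get_weights()[0]))    # True: D untouched
gan.discriminator.train_on_batch(x, np.ones((4, 1)))                # trains D directly
print(np.allclose(w_before, gan.discriminator.get_weights()[0]))    # False: D's weights moved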
Input:
A random noise vector, whose length we set here to 100.
Hidden layers: the generator uses 3 fully connected hidden layers, each with:
Activation: LeakyReLU with α = 0.2
Batch normalization: momentum of the moving mean and variance set to 0.8
Output:
The output dimension matches the image dimensions (28 × 28 × 1).
def build_generator(self):
    # --------------------------------- #
    #   Generator: takes a vector of random numbers as input
    # --------------------------------- #
    model = Sequential()
    model.add(Dense(256, input_dim=self.latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(np.prod(self.img_shape), activation='tanh'))
    # np.prod multiplies all elements of img_shape, so the output size matches the image size
    model.add(Reshape(self.img_shape))
    noise = Input(shape=(self.latent_dim,))
    img = model(noise)
    return Model(noise, img)
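As a quick sanity check of the generator's output shape and value range, one noise vector can be pushed through the untrained network (my own sketch, assuming the GAN class and the imports from the complete listing at the end):

gan = GAN()
z = np.random.normal(0, 1, (1, gan.latent_dim))
fake = gan.generator.predict(z)
print(fake.shape)                             # (1, 28, 28, 1): one image-shaped output
print(fake.min() >= -1, fake.max() <= 1)      # True True, because of the tanh activation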
Input:
Hand-written digit images of size (batch, 28, 28), which are first flattened.
Hidden layers:
2 fully connected hidden layers with LeakyReLU activations.
Output:
A single value with a sigmoid activation: the discriminator's judgment of whether the input is real.
def build_discriminator(self):
    # ----------------------------------- #
    #   Discriminator: scores the images it is given
    # ----------------------------------- #
    model = Sequential()
    # Flatten the input image
    model.add(Flatten(input_shape=self.img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    # Real-or-fake decision
    model.add(Dense(1, activation='sigmoid'))
    img = Input(shape=self.img_shape)
    validity = model(img)
    return Model(img, validity)
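The discriminator can be probed the same way: it maps an image-shaped input to a single sigmoid score (again a sketch under the same assumptions as above):

gan = GAN()
img = np.random.uniform(-1, 1, (1, 28, 28, 1))
validity = gan.discriminator.predict(img)
print(validity.shape)            # (1, 1)
print(float(validity[0, 0]))     # a value in (0, 1): the predicted probability of being real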
The training procedure is where the key idea of a GAN lives, so let's go through it carefully.
After normalizing the data and creating the real/fake labels, we run the training loop (note that each "epoch" in this code is really one iteration on a single randomly sampled batch):
1. First, train the discriminator.
It receives two kinds of input:
one is the fake images produced by feeding random noise into the generator, paired with label 0 (in the very first iteration the generator's trainable parameters are still just their initialized values);
the other is real images, paired with label 1. Binary cross-entropy is computed on each of the two inputs to obtain d_loss_real and d_loss_fake, and the final discriminator loss is their average, d_loss = 0.5 * d_loss_real + 0.5 * d_loss_fake (the loss formulas are written out just before the code below).
2. Then, train the generator.
While training the generator the discriminator is frozen by setting self.discriminator.trainable = False. Random noise is fed to the generator, the generated images are labeled as real (label 1) and passed through the discriminator that was just updated in this iteration, and the generator is trained with the same binary cross-entropy loss. This pushes the generator's outputs toward images the discriminator accepts as real.
After each iteration the generator gets a bit better; we then return to step 1 to train the discriminator again, then the generator, and so on. Through this continual back-and-forth the generator eventually learns to produce images that can pass for real ones.
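Written out explicitly, with D(·) denoting the discriminator's sigmoid output, the binary cross-entropy losses described above amount to (notation added here for clarity; the code below only ever specifies 'binary_crossentropy'):

\mathcal{L}_D = -\tfrac{1}{2}\,\mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] - \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big],
\qquad
\mathcal{L}_G = -\,\mathbb{E}_{z \sim p_z}\big[\log D(G(z))\big]

Training the generator against the "real" label (rather than minimizing log(1 − D(G(z))) directly) is the usual non-saturating variant of the GAN objective; it gives stronger gradients early in training, when the discriminator easily rejects the generator's samples.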
def train(self, epochs, batch_size=128, sample_interval=50):
    # Load the data
    (X_train, _), (_, _) = mnist.load_data()
    # Normalize the data to [-1, 1]
    X_train = X_train / 127.5 - 1.
    X_train = np.expand_dims(X_train, axis=3)
    # Create the labels
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    for epoch in range(epochs):
        # --------------------------- #
        #   Randomly pick batch_size images
        #   and train the discriminator
        # --------------------------- #
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        # train_on_batch returns the loss (plus the accuracy here, since the
        # discriminator was compiled with metrics=['accuracy'])
        d_loss_real = self.discriminator.train_on_batch(imgs, valid)
        d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        # --------------------------- #
        #   Train the generator
        # --------------------------- #
        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
        g_loss = self.combined.train_on_batch(noise, valid)
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
        if epoch % sample_interval == 0:
            self.sample_images(epoch)
The complete code is as follows:
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
import os
import numpy as np
class GAN():
    def __init__(self):
        # --------------------------------- #
        #   28 rows, 28 columns: the shape of an MNIST image
        # --------------------------------- #
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        # (28, 28, 1)
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        # Adam optimizer
        optimizer = Adam(0.0002, 0.5)
        # Discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])
        # Generator
        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        # Do not train the discriminator while training the generator
        self.discriminator.trainable = False
        # Let the discriminator score the generated (fake) images
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
    def build_generator(self):
        # --------------------------------- #
        #   Generator: takes a vector of random numbers as input
        # --------------------------------- #
        model = Sequential()
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        # np.prod multiplies all elements of img_shape, so the output size matches the image size
        model.add(Reshape(self.img_shape))
        noise = Input(shape=(self.latent_dim,))
        img = model(noise)
        return Model(noise, img)
    def build_discriminator(self):
        # ----------------------------------- #
        #   Discriminator: scores the images it is given
        # ----------------------------------- #
        model = Sequential()
        # Flatten the input image
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        # Real-or-fake decision
        model.add(Dense(1, activation='sigmoid'))
        img = Input(shape=self.img_shape)
        validity = model(img)
        return Model(img, validity)
    def train(self, epochs, batch_size=128, sample_interval=50):
        # Load the data
        (X_train, _), (_, _) = mnist.load_data()
        # Normalize the data to [-1, 1]
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)
        # Create the labels
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        for epoch in range(epochs):
            # --------------------------- #
            #   Randomly pick batch_size images
            #   and train the discriminator
            # --------------------------- #
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)
            # train_on_batch returns the loss (plus the accuracy here, since the
            # discriminator was compiled with metrics=['accuracy'])
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            # --------------------------- #
            #   Train the generator
            # --------------------------- #
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
    def sample_images(self, epoch):
        # Draw a 5x5 grid of generated digits and save it to disk
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        # Rescale from the tanh range [-1, 1] to [0, 1] for display
        gen_imgs = 0.5 * gen_imgs + 0.5
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("gan_mnist/%d.png" % epoch)
        plt.close()
if __name__ == '__main__':
    if not os.path.exists("./gan_mnist"):
        os.makedirs("./gan_mnist")
    gan = GAN()
    gan.train(epochs=30000, batch_size=256, sample_interval=200)
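Once training has finished, new digits can be sampled directly from the generator. A minimal sketch (assuming the gan object from the __main__ block above is still in scope):

noise = np.random.normal(0, 1, (16, gan.latent_dim))
digits = 0.5 * gan.generator.predict(noise) + 0.5   # rescale from [-1, 1] back to [0, 1]
plt.imshow(digits[0, :, :, 0], cmap='gray')          # show the first generated digit
plt.axis('off')
plt.show()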