https://github.com/tdeboissiere/DeepLearningImplementations
使用mnist数据集。
train_GAN.py:
# Excerpt from train_GAN.py: the alternating discriminator / generator
# update loop.
# NOTE(review): indentation was lost when this excerpt was pasted, so the
# nesting has to be inferred from the statements themselves.
# Both losses start at the placeholder value 100 and are overwritten by the
# first train_on_batch calls below.
gen_loss = 100
disc_loss = 100
# Start training
print("Start training")
for e in range(nb_epoch):
# One pass of this loop is one epoch.
# Initialize progbar and batch counter
progbar = generic_utils.Progbar(epoch_size)
batch_counter = 1
start = time.time()
for X_real_batch in data_utils.gen_batch(X_real_train, batch_size):
# Each iteration takes one batch drawn from X_real_train (randomly sampled
# according to the original notes; gen_batch is not shown here). The loop
# is ended by the `batch_counter >= n_batch_per_epoch` check at the bottom.
# Create a batch to feed the discriminator model
X_disc, y_disc = data_utils.get_disc_batch(X_real_batch,
generator_model,
batch_counter,
batch_size,
noise_dim,
noise_scale=noise_scale,
label_smoothing=label_smoothing,
label_flipping=label_flipping)
# get_disc_batch builds one batch for the discriminator: X_disc holds the
# data and y_disc the labels.
# Per the original notes this helper differs from the WGAN one: it
# alternates on the parity of batch_counter between real and generated
# data (helper not shown here -- TODO confirm).
# Update the discriminator
disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)
# The discriminator is updated once per batch.
# Create a batch to feed the generator model
X_gen, y_gen = data_utils.get_gen_batch(batch_size, noise_dim, noise_scale=noise_scale)
# Per the original notes, this samples one batch of noise for the generator.
# Freeze the discriminator
discriminator_model.trainable = False
gen_loss = DCGAN_model.train_on_batch(X_gen, y_gen)
# The generator is updated once per batch (through DCGAN_model).
# Unfreeze the discriminator
discriminator_model.trainable = True
batch_counter += 1
progbar.add(batch_size, values=[("D logloss", disc_loss),
("G logloss", gen_loss)])
# Save images for visualization
if batch_counter % 100 == 0:
data_utils.plot_generated_batch(X_real_batch, generator_model,
batch_size, noise_dim, image_dim_ordering)
# Leave the batch loop once the per-epoch batch budget has been reached.
if batch_counter >= n_batch_per_epoch:
break
#总结:
#1、给判别器一个batch的真实或生成数据(按batch_counter奇偶交替),判别器更新一次;
#2、给生成器一个batch的采样噪声(而不是真实数据),生成器更新一次;
#3、生成器和判别器更新次数相同。
train_WGAN.py
############################
# Compile models
############################
# The generator on its own is compiled with an MSE loss and optimizer opt_G.
generator_model.compile(loss='mse', optimizer=opt_G)
# The discriminator is flagged non-trainable before DCGAN_model is compiled
# with the Wasserstein loss and the generator optimizer opt_G ...
discriminator_model.trainable = False
DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
# ... and flagged trainable again before it is compiled on its own with the
# Wasserstein loss and its own optimizer opt_D.
discriminator_model.trainable = True
discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
# Summary from the original notes: DCGAN_model is generator_model plus the
# discriminator with trainable = False, while discriminator_model is the
# discriminator with trainable = True (the construction of DCGAN_model is
# not shown in this excerpt).
# Global iteration counter for generator updates
gen_iterations = 0
#################
# Start training
################
# Excerpt from train_WGAN.py: several critic (discriminator) updates are run
# for every single generator update.
# NOTE(review): indentation was lost when this excerpt was pasted, so the
# nesting has to be inferred from the statements themselves.
for e in range(nb_epoch): # nb_epoch: total number of epochs
# Initialize progbar and batch counter
progbar = generic_utils.Progbar(epoch_size)
batch_counter = 1
start = time.time()
while batch_counter < n_batch_per_epoch: # n_batch_per_epoch: number of batches per epoch
# Keep going until the per-epoch batch budget is reached.
if gen_iterations < 25 or gen_iterations % 500 == 0:
disc_iterations = 100
else:
disc_iterations = kwargs["disc_iterations"]
# While gen_iterations < 25, every generator update is preceded by 100
# critic updates, each on its own batch. From gen_iterations >= 25 on, the
# critic gets kwargs["disc_iterations"] updates per generator update (5 by
# default according to the original notes). Whenever gen_iterations is a
# multiple of 500 the critic again gets 100 updates.
###################################
# 1) Train the critic / discriminator
###################################
list_disc_loss_real = []
list_disc_loss_gen = []
for disc_it in range(disc_iterations):
# disc_iterations critic updates are run per outer-loop step.
# Clip discriminator weights to [clamp_lower, clamp_upper], layer by layer.
for l in discriminator_model.layers:
weights = l.get_weights()
weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
l.set_weights(weights)
# A fresh gen_batch generator is created on every critic iteration and only
# its first batch is taken.
X_real_batch = next(data_utils.gen_batch(X_real_train, batch_size))
# Create a batch to feed the discriminator model
X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_real_batch,
generator_model,
batch_counter,
batch_size,
noise_dim,
noise_scale=noise_scale)
# Update the discriminator
disc_loss_real = discriminator_model.train_on_batch(X_disc_real, -np.ones(X_disc_real.shape[0]))
# Real data is fed with label -1.
disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.ones(X_disc_gen.shape[0]))
# Generated data is fed with label +1.
# The loss is the Wasserstein loss mean(y_true * y_pred). For real samples
# a higher prediction is better; with y_true = -1 the loss is
# mean(-y_pred), which decreases as the prediction grows. For generated
# samples a lower prediction is better; with y_true = +1 the loss is
# mean(y_pred), which is again what minimisation needs.
list_disc_loss_real.append(disc_loss_real)
list_disc_loss_gen.append(disc_loss_gen)
# Sample noise and feed it to the generator.
# NOTE(review): the statement that assigns X_gen is not part of this
# excerpt -- confirm against the full train_WGAN.py.
# After disc_iterations critic updates the generator is updated once.
# Freeze the discriminator
discriminator_model.trainable = False
gen_loss = DCGAN_model.train_on_batch(X_gen, -np.ones(X_gen.shape[0]))
# Unfreeze the discriminator
discriminator_model.trainable = True
gen_iterations += 1
batch_counter += 1
#总结:
#1、判别器的一次迭代中,先将所有层的weight进行clip,
# 再给入一个batch的真实数据,判别器更新;然后给入一个batch的生成数据,判别器再更新;
#2、判别器如此迭代多次,每一次均送入新的batch。
#3、判别器更新多次后,给生成器送入采样得到的一个batch噪声数据,进行一次更新;
#4、生成器更新一次后,本回合的batch_counter加1,
# 直到batch_counter达到每回合的batch数量;因此每回合中判别器
# 实际用到的batch数多于n_batch_per_epoch。
models_WGAN.py
def generator_upsampling(noise_dim, img_dim, bn_mode, model_name="generator_upsampling", dset="mnist"):
    """Build the upsampling-based generator (excerpt from models_WGAN.py).

    A noise vector goes through a dense layer, is reshaped into feature maps,
    passes through ``nb_upconv`` upsampling blocks and ends in a tanh
    convolution that produces the image.

    NOTE(review): this is an excerpt. The statements that define ``f``,
    ``start_dim``, ``gen_input``, ``reshape_shape``, ``bn_axis``,
    ``nb_upconv``, ``output_channels`` and ``conv2D_init`` are not shown
    here and must be confirmed against the full file.

    Args:
        noise_dim: Size of the noise vector (used as the Dense ``input_dim``).
        img_dim: Image dimensions (not referenced in the visible excerpt).
        bn_mode: ``mode`` passed to every BatchNormalization layer.
        model_name: Name given to the resulting Keras model.
        dset: Dataset name (not referenced in the visible excerpt).

    Returns:
        The generator as a Keras ``Model`` mapping ``gen_input`` to the image.
    """
    # Noise input and reshaping
    # Per the original notes (MNIST): 100-d noise -> dense 512*7*7 ->
    # reshaped to 512 feature maps of 7x7 -> BN -> ReLU. According to the
    # notes this part is the same as in models_GAN.
    x = Dense(f * start_dim * start_dim, input_dim=noise_dim)(gen_input)
    x = Reshape(reshape_shape)(x)
    x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
    x = Activation("relu")(x)

    # Upscaling blocks: Upsampling2D->Conv2D->BN->ReLU->Conv2D->ReLU
    # The convolutions use "same" padding, so only the upsampling step changes
    # the feature-map size; the filter count is halved in every block.
    for i in range(nb_upconv):
        x = UpSampling2D(size=(2, 2))(x)
        nb_filters = int(f / (2 ** (i + 1)))
        x = Convolution2D(nb_filters, 3, 3, border_mode="same", init=conv2D_init)(x)
        # Normalise along bn_axis, like every other BatchNormalization in
        # this file, instead of a hard-coded axis=1 that is only right when
        # bn_axis happens to be 1.
        x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
        x = Activation("relu")(x)
        x = Convolution2D(nb_filters, 3, 3, border_mode="same", init=conv2D_init)(x)
        x = Activation("relu")(x)

    # Last Conv to get the output image (3x3, tanh; 28x28 for MNIST per the
    # original notes)
    x = Convolution2D(output_channels, 3, 3, name="gen_conv2d_final", border_mode="same", activation='tanh', init=conv2D_init)(x)

    generator_model = Model(input=[gen_input], output=[x], name=model_name)
    visualize_model(generator_model)
    # Per the original notes, the difference from the plain GAN generator is
    # the conv2D_init weight initialisation.
    return generator_model
#生成器总结:
#1、输入噪声维数是100
#2、进入全连接层,512*7*7,进而转成512个7*7特征图
#3、上采样一次,变成512*14*14,256个3*3卷积核,特征
# 图大小不变,再进行一次256个3*3卷积核卷积
#4、上采样一次,变成256*28*28,128个3*3卷积核卷积,重
# 复一次,均得到128*28*28
#5、最后层卷积,一个特征图,变为1*28*28
#整个网络结构:
#100->7*7*512->上采样14*14*512->卷积(3*3)
#14*14*256->卷积(3*3)14*14*256->上采样28*28*256
#->卷积(3*3)28*28*128->卷积(3*3)28*28*128
#->卷积(3*3)28*28*1
# Get the list of number of conv filters
# (first layer starts with 64), filters are subsequently doubled
# NOTE(review): min_s is not defined in this excerpt -- presumably the
# smaller spatial dimension of the image; confirm against the full file.
nb_conv = int(np.floor(np.log(min_s // 4) / np.log(2))) # number of conv layers: floor(log2(min_s // 4))
def discriminator(img_dim, bn_mode, model_name="discriminator"):
    """Build the WGAN critic (excerpt from models_WGAN.py).

    The network is a stack of stride-2 3x3 convolutions, each followed by
    BatchNormalization and LeakyReLU(0.2), then a single-filter 3x3
    convolution whose output is globally averaged into one score.

    NOTE(review): this is an excerpt. ``disc_input``, ``bn_axis``,
    ``conv2D_init`` and ``nb_conv`` come from statements outside this body,
    and ``img_dim`` is not referenced in the visible part.

    Differences from the plain GAN discriminator, per the original notes:
    the filter list is defined differently, the strided convolutions use
    ``conv2D_init`` and ``bias=False``, and the output is an averaged score
    where the GAN version ends in a softmax.

    Shapes for a 28x28 MNIST input, per the original notes (three
    convolutions in total):
        28x28 -> conv 3x3 stride 2 -> 14x14x64
              -> conv 3x3 stride 2 -> 7x7x128
              -> conv 3x3          -> 7x7x1
              -> global average    -> 1

    Args:
        img_dim: Image dimensions (not referenced in the visible excerpt).
        bn_mode: ``mode`` passed to every BatchNormalization layer.
        model_name: Name given to the resulting Keras model.

    Returns:
        The critic as a Keras ``Model`` mapping ``disc_input`` to the score.
    """
    # Filter counts: 64 for the first layer, doubled per layer, capped at 64 * 8.
    list_f = [min(8, 2 ** depth) * 64 for depth in range(nb_conv)]
    first_filters = list_f[0]
    later_filters = list_f[1:]

    # First conv with 2x2 strides ("same" padding halves the spatial size).
    x = Convolution2D(first_filters, 3, 3, subsample=(2, 2), name="disc_conv2d_1",
                      border_mode="same", bias=False, init=conv2D_init)(disc_input)
    x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
    x = LeakyReLU(0.2)(x)

    # Remaining conv blocks: Conv2D(2x2 strides)->BN->LReLU, named
    # disc_conv2d_2, disc_conv2d_3, ...
    for layer_idx, nb_filters in enumerate(later_filters, start=2):
        x = Convolution2D(nb_filters, 3, 3, subsample=(2, 2),
                          name="disc_conv2d_%s" % layer_idx,
                          border_mode="same", bias=False, init=conv2D_init)(x)
        x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
        x = LeakyReLU(0.2)(x)

    # Last convolution: one filter, no conv2D_init (the variant that passed
    # init=conv2D_init was left commented out in the original).
    score_map = Convolution2D(1, 3, 3, name="last_conv", border_mode="same", bias=False)(x)
    # Average pooling over the whole map gives the critic score.
    x = GlobalAveragePooling2D()(score_map)

    discriminator_model = Model(input=[disc_input], output=[x], name=model_name)
    visualize_model(discriminator_model)
    return discriminator_model
def wasserstein(y_true, y_pred):
    """Wasserstein loss: the mean of ``y_true * y_pred``.

    Keras always minimises its loss, so the sign of the label selects the
    direction. For real samples the goal is to maximise the score
    ``K.mean(y_pred)``; with ``y_true = -1`` the loss becomes
    ``-K.mean(y_pred)``, so minimising it does exactly that.
    """
    # A normalised variant, K.mean(y_true * y_pred) / K.mean(y_true), was
    # left disabled in the original.
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)