GAN与WassersteinGAN代码keras分析

https://github.com/tdeboissiere/DeepLearningImplementations
使用mnist数据集。
train_GAN.py:


        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
        # 每一回合
        # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_real_batch in data_utils.gen_batch(X_real_train, batch_size):
# 从X_real_train随机抽取一个batch;循环次数由下面的
# `if batch_counter >= n_batch_per_epoch: break` 控制,
# 即每个回合最多处理n_batch_per_epoch个batch
                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_real_batch,
                                                           generator_model,
                                                           batch_counter,
                                                           batch_size,
                                                           noise_dim,
                                                           noise_scale=noise_scale,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)
# get_disc_batch为判别器生成一个batch数据,X_disc, 
# y_disc分别为数据和标签
# data_utils.get_disc_batch 与WGAN中不同,这里是
# batch_counter奇偶交替,数据为真实数据和生成的数据交替


                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)
                # 判别器更新一次             
                # Create a batch to feed the generator model
                X_gen, y_gen = data_utils.get_gen_batch(batch_size, noise_dim, noise_scale=noise_scale)
                # 采样一个batch的噪声给生成器
                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, y_gen)
                # 更新一次生成器
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G logloss", gen_loss)])

                # Save images for visualization
                if batch_counter % 100 == 0:
                    data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                    batch_size, noise_dim, image_dim_ordering)

                if batch_counter >= n_batch_per_epoch:
                    break
# 总结:
# 1、给判别器一个batch的数据(真实数据与生成数据按batch_counter奇偶交替),判别器更新一次;
# 2、给生成器一个batch的噪声数据(此时冻结判别器),生成器更新一次;
# 3、生成器和判别器更新次数相同。

train_WGAN.py

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
    # 编译模型,DCGAN_model是generator_model+discriminator_model.trainable = False
    # discriminator_model是discriminator_model.trainable = True

    # Global iteration counter for generator updates
    gen_iterations = 0

    #################
    # Start training
    ################
    for e in range(nb_epoch): # nb_epoch:总共迭代多少回合
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 1
        start = time.time()

        while batch_counter < n_batch_per_epoch: # n_batch_per_epoch:每个回合包含的batch数量
        # 不断送入batch,直到每个回合的batch限定的数量
            if gen_iterations < 25 or gen_iterations % 500 == 0:
                disc_iterations = 100
            else:
                disc_iterations = kwargs["disc_iterations"]
# 最开始生成器迭代次数gen_iterations < 25时,生成器每
# 迭代1次,判别器迭代100次(判别器每次迭代送入1个新的batch);
# gen_iterations >= 25后,生成器迭代1次,判别器迭代
# disc_iterations次(disc_iterations默认是5);
# 在gen_iterations能被500整除时,仍是生成器迭代1次,
# 判别器迭代100次。



            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            for disc_it in range(disc_iterations):
                # One critic update per pass through this loop.

                # Clip every critic weight into [clamp_lower, clamp_upper]
                # before the update (the WGAN weight-clipping step).
                for layer in discriminator_model.layers:
                    clipped = [np.clip(w, clamp_lower, clamp_upper)
                               for w in layer.get_weights()]
                    layer.set_weights(clipped)

                # Draw a fresh real batch for this critic iteration.
                X_real_batch = next(data_utils.gen_batch(X_real_train, batch_size))

                # Create a batch to feed the discriminator model.
                # Unlike the GAN version, the helper is unpacked into two
                # data arrays (real, generated) rather than (data, labels).
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_real_batch,
                                                                    generator_model,
                                                                    batch_counter,
                                                                    batch_size,
                                                                    noise_dim,
                                                                    noise_scale=noise_scale)

                # Update the discriminator.
                # The loss is mean(y_true * y_pred) and training minimises it:
                #   real batch, label -1 -> loss = -mean(y_pred): the score on
                #                           real samples is pushed up;
                #   generated batch, label +1 -> loss = mean(y_pred): the score
                #                           on generated samples is pushed down.
                disc_loss_real = discriminator_model.train_on_batch(X_disc_real, -np.ones(X_disc_real.shape[0]))
                disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.ones(X_disc_gen.shape[0]))

                # Record both losses for every critic iteration. Both appends
                # belong inside the loop; the first one was dedented out of it,
                # which is an IndentationError and would otherwise keep only
                # the last real-batch loss.
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)
# 对噪声采样得到X_gen,送入生成器
# (注意:此处摘录省略了对噪声采样、生成X_gen的那一行代码,下面直接使用了X_gen)
# 判别器更新disc_iterations次后,生成器更新1次
            # Freeze the discriminator
            discriminator_model.trainable = False
            gen_loss = DCGAN_model.train_on_batch(X_gen, -np.ones(X_gen.shape[0]))
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            gen_iterations += 1
            batch_counter += 1 
#总结:
#1、判别器的一次迭代中,先将所有层weight进行clip,
#   给入真实数据一个batch,判别器更新;再给入一个batch
#   生成数据,判别器更新;
#2、判别器如此迭代多次,每一次均是送入新的batch。
#3、判别器更新多次后,生成器送入采样的一个batch噪声数据,
#   进行一次更新;
#4、生成器更新一次后,一次回合中的batch_counter加1,
#   直到batch_counter到达每回合的batch数量;由于生成器
#   每更新一次判别器要更新多次,判别器在一个回合中实际
#   送入的batch数量多于n_batch_per_epoch。

models_WGAN.py

def generator_upsampling(noise_dim, img_dim, bn_mode, model_name="generator_upsampling", dset="mnist"):
    """Build the upsampling-based WGAN generator and return it as a Keras Model.

    The model maps ``gen_input`` through a Dense projection, a stack of
    upsampling/convolution blocks and a final tanh convolution.

    NOTE(review): this is an abridged excerpt. The setup that defines
    ``gen_input``, ``f``, ``start_dim``, ``reshape_shape``, ``bn_axis``,
    ``nb_upconv``, ``conv2D_init`` and ``output_channels`` is not shown here.

    Args:
        noise_dim: size of the noise vector; used as the Dense ``input_dim``.
        img_dim: not read in the visible part of the function.
        bn_mode: forwarded as ``mode`` to every BatchNormalization layer.
        model_name: name given to the returned Model.
        dset: not read in the visible part of the function.

    Returns:
        The generator Model (``gen_input`` -> generated image tensor).
    """
    # Noise input and reshaping
    x = Dense(f * start_dim * start_dim, input_dim=noise_dim)(gen_input)
    x = Reshape(reshape_shape)(x)
    x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
    x = Activation("relu")(x)
    # Per the walkthrough for MNIST (values not visible in this excerpt):
    # 100-d noise -> fully connected to 512*7*7 -> reshaped to 512 feature
    # maps of 7x7 -> BN -> ReLU.
    # According to the post, everything up to here matches models_GAN.

    # Upscaling blocks: UpSampling2D->Conv2D->BN->ReLU->Conv2D->ReLU
    # The filter count is halved at each block: f / 2**(i + 1).
    for i in range(nb_upconv):
        x = UpSampling2D(size=(2, 2))(x)
        nb_filters = int(f / (2 ** (i + 1)))
        x = Convolution2D(nb_filters, 3, 3, border_mode="same", init=conv2D_init)(x)
        # NOTE(review): axis is hard-coded to 1 here, while the first
        # BatchNormalization above uses bn_axis - confirm this is intended.
        x = BatchNormalization(mode=bn_mode, axis=1)(x)
        x = Activation("relu")(x)
        x = Convolution2D(nb_filters, 3, 3, border_mode="same", init=conv2D_init)(x)
        x = Activation("relu")(x)
    # Each block upsamples by 2x and then applies two 3x3 convolutions with
    # ReLU; border_mode="same" pads, so the convolutions keep the feature-map
    # size. (The post states there are two such blocks for MNIST.)

    # Last Conv to get the output image
    x = Convolution2D(output_channels, 3, 3, name="gen_conv2d_final",                    border_mode="same", activation='tanh', init=conv2D_init)(x)
    # Final 3x3 convolution with tanh; per the post this yields a 28x28 image.
    generator_model = Model(input=[gen_input], output=[x], name=model_name)
    visualize_model(generator_model)
    # According to the post, the only difference from the GAN version is the
    # conv2D_init initializer passed to the convolutions.
    return generator_model
 #生成器总结:
 #1、输入噪声维数是100
 #2、进入全连接层,512*7*7,进而转成512个7*7特征图
 #3、上采样一次,变成512*14*14,256个3*3卷积核,特征
 #   图大小不变,再进行一次256个3*3卷积核卷积
 #4、上采样一次,变成256*28*28,128个3*3卷积核卷积,重
 #   复一次,均得到128*28*28
 #5、最后层卷积,一个特征图,变为1*28*28
 #整个网络结构:
 #100->7*7*512->上采样14*14*512->卷积(3*3)
 #14*14*256->卷积(3*3)14*14*256->上采样28*28*256
 #->卷积(3*3)28*28*128->卷积(3*3)28*28*128
 #->卷积(3*3)28*28*1
    # Get the list of number of conv filters
    # (first layer starts with 64), filters are subsequently doubled
    nb_conv = int(np.floor(np.log(min_s // 4) / np.log(2))) # 共几层卷积,计算公式 
def discriminator(img_dim, bn_mode, model_name="discriminator"):
    """Build the WGAN critic (discriminator) and return it as a Keras Model.

    The model applies strided convolutions with BatchNormalization and
    LeakyReLU, then a one-filter convolution and global average pooling, so
    the output is one unbounded score per sample.

    NOTE(review): this is an abridged excerpt. ``disc_input``, ``bn_axis``
    and ``conv2D_init`` are defined in lines not shown, and ``nb_conv`` is
    assigned on the line just above this ``def`` in the excerpt. In a
    runnable version it has to be computed inside this function - confirm
    against the original source.

    Args:
        img_dim: not read in the visible part of the function.
        bn_mode: forwarded as ``mode`` to every BatchNormalization layer.
        model_name: name given to the returned Model.

    Returns:
        The critic Model (``disc_input`` -> scalar score per sample).
    """
    list_f = [64 * min(8, (2 ** i)) for i in range(nb_conv)]
    # The first layer has 64 filters; the count doubles at each following
    # layer and is capped at 64 * 8.
    # According to the post, this list is built differently from the GAN
    # version.

    # First conv with 2x2 strides
    x = Convolution2D(list_f[0], 3, 3, subsample=(2, 2), name="disc_conv2d_1",
                     border_mode="same", bias=False, init=conv2D_init)(disc_input)
    x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
    x = LeakyReLU(0.2)(x)

    # Conv blocks: Conv2D(2x2 strides)->BN->LReLU
    # Layer names continue from "disc_conv2d_2".
    for i, f in enumerate(list_f[1:]):
        name = "disc_conv2d_%s" % (i + 2)
        x = Convolution2D(f, 3, 3, subsample=(2, 2), name=name, border_mode="same", bias=False, init=conv2D_init)(x)
        x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
        x = LeakyReLU(0.2)(x)
    # According to the post, the differences from the GAN version are the
    # conv2D_init initializer and bias=False.

    # Last convolution
    # This one does not pass init=conv2D_init; the variant that did is kept
    # commented out below.
    #x = Convolution2D(1, 3, 3, name="last_conv", border_mode="same", bias=False, init=conv2D_init)(x)
    x = Convolution2D(1, 3, 3, name="last_conv", border_mode="same", bias=False)(x)
    # Average pooling
    x = GlobalAveragePooling2D()(x)
    # The last convolution has a single feature map, which is globally
    # averaged into the score. Per the post, the GAN version ends in a
    # softmax instead.
    #
    # Critic summary (per the post, for 28x28 MNIST input):
    # 1. Input is 28x28.
    # 2. First conv: 64 filters, 3x3, stride (2, 2), "same" padding
    #    -> 14x14x64 (shapes written in TensorFlow channel-last order).
    # 3. Second conv: 128 filters, 3x3, stride (2, 2), "same" padding
    #    -> 7x7x128.
    # 4. Last conv: 1 filter, 3x3, "same" padding -> 7x7x1, then global
    #    average pooling -> a single score.
    # Overall: three convolutions,
    # 28x28 -> conv(3x3, stride 2) 14x14x64 -> conv(3x3, stride 2) 7x7x128
    # -> conv(3x3) 7x7x1 -> global pooling -> 1
    discriminator_model = Model(input=[disc_input], output=[x], name=model_name)
    visualize_model(discriminator_model)

    return discriminator_model
def wasserstein(y_true, y_pred):
    """Return the Wasserstein loss: the mean of ``y_true * y_pred``.

    Training minimises the loss, so the sign carried by ``y_true`` picks the
    direction. With ``y_true = -1`` the loss is ``-mean(y_pred)``, so
    minimising it maximises the mean score. With ``y_true = +1`` the loss is
    ``mean(y_pred)``, so minimising it lowers the mean score.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

你可能感兴趣的:(deep,learning)