CycleGAN can be used for style transfer, and it tackles the problem of unpaired data in image-to-image translation. As shown in the figure below, an ordinary horse can be turned into a zebra while the background stays unchanged. In essence it is two symmetric GANs forming a ring network: the two GANs share the two generators, and each has its own discriminator, so there are two generators and two discriminators in total.
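To make the ring structure concrete, here is a rough sketch of the four networks. The layer stacks are placeholders (the original model uses ResNet-style generators and 70x70 PatchGAN discriminators); only the names genA2B, genB2A, discA, and discB are what the training code below relies on, and 256x256 RGB images scaled to [-1, 1] are assumed.

from tensorflow import keras

def make_generator():
    # placeholder for a ResNet/U-Net style generator
    return keras.Sequential([
        keras.layers.Conv2D(64, 7, padding='same', activation='relu', input_shape=(256, 256, 3)),
        keras.layers.Conv2D(3, 7, padding='same', activation='tanh'),  # back to image range [-1, 1]
    ])

def make_discriminator():
    # placeholder for a PatchGAN discriminator: outputs a grid of real/fake logits
    return keras.Sequential([
        keras.layers.Conv2D(64, 4, strides=2, padding='same', input_shape=(256, 256, 3)),
        keras.layers.LeakyReLU(0.2),
        keras.layers.Conv2D(1, 4, padding='same'),
    ])

genA2B, genB2A = make_generator(), make_generator()        # A -> B and B -> A mappings
discA, discB = make_discriminator(), make_discriminator()  # judge domain-A / domain-B images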
Computing the loss
Underneath, this is still a GAN, so the key is to understand how its loss functions are computed.
# Discriminator loss
def discriminator_loss(disc_of_real_output, disc_of_gen_output, lsgan=True):
    if lsgan:  # Use least-squares loss: push scores on real images towards 1 and on fakes towards 0
        real_loss = tf.reduce_mean(tf.square(disc_of_real_output - 1))
        generated_loss = tf.reduce_mean(tf.square(disc_of_gen_output))
        total_disc_loss = (real_loss + generated_loss) * 0.5  # 0.5 slows down rate that D learns compared to G
    else:  # Use vanilla GAN loss (sigmoid cross-entropy on the logits)
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(disc_of_real_output), logits=disc_of_real_output))
        generated_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(disc_of_gen_output), logits=disc_of_gen_output))
        total_disc_loss = real_loss + generated_loss
    return total_disc_loss
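A quick sanity check with made-up score maps (a 30x30 patch of logits is assumed, roughly what a PatchGAN discriminator outputs):

import tensorflow as tf

perfect_real = tf.ones((1, 30, 30, 1))    # discriminator already says "real" everywhere for the real image
perfect_fake = tf.zeros((1, 30, 30, 1))   # and "fake" everywhere for the generated image
print(float(discriminator_loss(perfect_real, perfect_fake, lsgan=True)))  # 0.0, nothing to correct
print(float(discriminator_loss(perfect_fake, perfect_real, lsgan=True)))  # 1.0, worst case: labels flipped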
# Generator loss
def generator_loss(disc_of_gen_output, lsgan=True):
    if lsgan:  # Use least squares loss: push the discriminator's score on the fake towards 1
        gen_loss = tf.reduce_mean(tf.square(disc_of_gen_output - 1))
    else:  # Use vanilla GAN loss
        gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(disc_of_gen_output), logits=disc_of_gen_output))
    # l1_loss = tf.reduce_mean(tf.abs(target - gen_output))  # Look up pix2pix loss
    return gen_loss
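The same kind of check for the generator, which is rewarded when the discriminator scores its output as real (again with made-up score maps, reusing the tf import above):

fooled = tf.ones((1, 30, 30, 1))      # discriminator scores the fake as real
caught = tf.zeros((1, 30, 30, 1))     # discriminator scores the fake as fake
print(float(generator_loss(fooled, lsgan=True)))  # 0.0, the generator already fools D
print(float(generator_loss(caught, lsgan=True)))  # 1.0, maximal penalty for the generator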
# Cycle-consistency loss
def cycle_consistency_loss(data_A, data_B, reconstructed_data_A, reconstructed_data_B, cyc_lambda=10):
    # L1 distance between each original image and its reconstruction after a full cycle (A->B->A, B->A->B)
    loss = tf.reduce_mean(tf.abs(data_A - reconstructed_data_A) + tf.abs(data_B - reconstructed_data_B))
    return cyc_lambda * loss
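And a numerical check of the cycle term, with hypothetical constant-valued images so the arithmetic is easy to follow:

imgA = tf.fill((1, 4, 4, 3), 0.50)
imgB = tf.fill((1, 4, 4, 3), 0.20)
recA = tf.fill((1, 4, 4, 3), 0.55)   # A comes back 0.05 off after the A -> B -> A cycle
recB = tf.fill((1, 4, 4, 3), 0.15)   # B comes back 0.05 off after the B -> A -> B cycle
print(float(cycle_consistency_loss(imgA, imgB, recA, recB)))  # ≈ 1.0, i.e. 10 * (0.05 + 0.05)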
Some training details also deserve attention. A single training step looks like this:
# One training step (this block sits inside the per-epoch batch loop)
with tf.GradientTape() as genA2B_tape, tf.GradientTape() as genB2A_tape, \
        tf.GradientTape() as discA_tape, tf.GradientTape() as discB_tape:
    try:
        trainA = next(train_datasetA)
        trainB = next(train_datasetB)
    except tf.errors.OutOfRangeError:
        print("Error, run out of data")
        break

    # CycleGAN cycle
    # Generators: A -> fake B -> reconstructed A, and B -> fake A -> reconstructed B
    # Discriminators: discA judges real A vs. the A generated from B; discB judges real B vs. the B generated from A
    genA2B_output = genA2B(trainA, training=True)
    genB2A_output = genB2A(trainB, training=True)
    discA_real_output = discA(trainA, training=True)
    discB_real_output = discB(trainB, training=True)
    discA_fake_output = discA(genB2A_output, training=True)
    discB_fake_output = discB(genA2B_output, training=True)
    reconstructedA = genB2A(genA2B_output, training=True)
    reconstructedB = genA2B(genB2A_output, training=True)
    # How the losses are computed
    # Discriminator A: distance between real A and 1, plus distance between the fake A (generated from B) and 0
    #                  (the distance can be sigmoid cross-entropy or MSE)
    # Generator A2B: distance between the fake B and 1, plus the cycle-consistency loss
    # Cycle-consistency loss = (|real A - reconstructed A| + |real B - reconstructed B|) * weight
    discA_loss = discriminator_loss(discA_real_output, discA_fake_output, lsgan=lsgan)
    discB_loss = discriminator_loss(discB_real_output, discB_fake_output, lsgan=lsgan)
    genA2B_loss = generator_loss(discB_fake_output, lsgan=lsgan) + \
                  cycle_consistency_loss(trainA, trainB, reconstructedA, reconstructedB,
                                         cyc_lambda=cyc_lambda)
    genB2A_loss = generator_loss(discA_fake_output, lsgan=lsgan) + \
                  cycle_consistency_loss(trainA, trainB, reconstructedA, reconstructedB,
                                         cyc_lambda=cyc_lambda)

# Compute the gradients, then update the generators first and the discriminators second
genA2B_gradients = genA2B_tape.gradient(genA2B_loss, genA2B.trainable_variables)
genB2A_gradients = genB2A_tape.gradient(genB2A_loss, genB2A.trainable_variables)
discA_gradients = discA_tape.gradient(discA_loss, discA.trainable_variables)
discB_gradients = discB_tape.gradient(discB_loss, discB.trainable_variables)
genA2B_optimizer.apply_gradients(zip(genA2B_gradients, genA2B.trainable_variables))
genB2A_optimizer.apply_gradients(zip(genB2A_gradients, genB2A.trainable_variables))
discA_optimizer.apply_gradients(zip(discA_gradients, discA.trainable_variables))
discB_optimizer.apply_gradients(zip(discB_gradients, discB.trainable_variables))
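The step above relies on a few names that are defined elsewhere: the lsgan and cyc_lambda settings, the two dataset iterators, and the four optimizers. A plausible setup is sketched below, using Adam with lr 2e-4 and beta_1 0.5 (roughly the settings of the original implementation); train_images_A / train_images_B are placeholders for whatever image tensors you actually load, and note that on newer TensorFlow versions an exhausted iterator raises StopIteration rather than tf.errors.OutOfRangeError.

import tensorflow as tf

lsgan, cyc_lambda = True, 10   # least-squares GAN loss and the cycle-loss weight used above
batch_size = 1

# train_images_A / train_images_B: float32 tensors of shape (N, 256, 256, 3), assumed loaded elsewhere
train_datasetA = iter(tf.data.Dataset.from_tensor_slices(train_images_A).shuffle(1000).batch(batch_size))
train_datasetB = iter(tf.data.Dataset.from_tensor_slices(train_images_B).shuffle(1000).batch(batch_size))

genA2B_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
genB2A_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discA_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discB_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)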
When facing different datasets and real-world requirements, there are still plenty of tuning tricks to learn; I will cover those in a future update.