对抗生成网络GAN系列——WGAN原理及实战演练

作者简介:秃头小苏,致力于用最通俗的语言描述问题

往期回顾:对抗生成网络GAN系列——GAN原理及手写数字生成小案例   对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例   对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战   对抗生成网络GAN系列——EGBAD原理及缺陷检测实战

近期目标:写好专栏的每一篇文章

支持小苏:点赞、收藏⭐、留言

 

文章目录

  • 对抗生成网络GAN系列——WGAN原理及实战演练
    • 写在前面
    • WGAN原理详解
      • GAN为什么训练困难
      • EM distance(推土机距离)的引入
      • WGAN的实现
    • WGAN代码实战
    • 小结
    • 参考链接

对抗生成网络GAN系列——WGAN原理及实战演练

写在前面

  在前面我已经写过好几篇关于GAN的文章,感兴趣的可以点击下列链接了解详情:

  • [1]对抗生成网络GAN系列——GAN原理及手写数字生成小案例
  • [2]对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例
  • [3]对抗生成网络GAN系列——CycleGAN原理
  • [4] 对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战
  • [5]对抗生成网络GAN系列——EGBAD原理及缺陷检测实战

  在文章[5]的最后,我们从训练结果中可以看出效果并不是很理想,其实难以训练是GAN网络训练普遍存在的问题。本篇文章将为大家带来WGAN(Wasserstein Generative Adversarial Networks ),旨在解决GAN难train的问题。这篇论文中有大量的理论推导,我不会带着大家一个个推,当然很多我也不会,但是我尽可能的把一些关键部分给大家叙述清楚,让大家从心底认可WGAN,觉得WGAN是合理的,是美妙的。

  准备好了吗,就让我们一起来学学WGAN吧!

 

WGAN原理详解

GAN为什么训练困难

​  上文我们谈到GAN普遍存在训练困难的问题,而WGAN是用来解决此问题的。所谓对症下药,我们第一步应该要知道GAN为什么会存在训练难的问题,这样我们才能从本质上对其进行改进。下面就跟随我的步伐一起来看看吧!!!

​  还记得我们在[1]中给出的GAN网络的损失函数吗?如下图所示:

对抗生成网络GAN系列——WGAN原理及实战演练_第1张图片

​  为了方便表述,对上图公式做如下调整:

min ⁡ G max ⁡ D V = E x ∼ p r [ log ⁡ D ( x ) ] + E x ∼ p g [ log ⁡ ( 1 − D ( x ) ) ] \mathop {\min }\limits_G \mathop {\max }\limits_D V = {E_{x \sim {{\rm{p}}_r}}}[\log D(x)] + {E_{x \sim {{\rm{p}}_g}}}[\log (1 - D(x))] GminDmaxV=Expr[logD(x)]+Expg[log(1D(x))]

​  即把后一项的 G ( z ) G(z) G(z)用x来表示,x为生成器生成的分布。

  我们现在希望求得生成器固定是最大化的判别器D。首先,我们对于一个随机的样本,它可能是真实样本,也可能是生成的样本,那么对于这个样本的损失为:

P r ( x ) log ⁡ D ( x ) + P g ( x ) log ⁡ [ 1 − D ( x ) ] {{\rm{P}}_r}(x)\log D(x) + {P_g}(x)\log [1 - D(x)] Pr(x)logD(x)+Pg(x)log[1D(x)]

​ 预求最大化的D,则对上式的D(x)求导,并让上式导函数为0,即:

P r ( x ) D ( x ) − P g ( x ) 1 − D ( x ) = 0 \frac{{{{\rm{P}}_r}(x)}}{{D(x)}} - \frac{{{{\rm{P}}_g}(x)}}{{1 - D(x)}} = 0 D(x)Pr(x)1D(x)Pg(x)=0

​ 化简上式,得最优的D表达式为:

D ∗ ( x ) = P r ( x ) P r ( x ) + P g ( x ) ——式 1 {D^*}(x) = \frac{{{{\rm{P}}_r}(x)}}{{{{\rm{P}}_r}(x) + {{\rm{P}}_g}(x)}} \qquad \qquad \qquad—— 式1 D(x)=Pr(x)+Pg(x)Pr(x)——1

​  我们可以来看看这个式1是否符合我们的直觉。若 p r ( x ) = 0 、 p g ( x ) ≠ 0 p_r(x)=0、p_g(x)\ne 0 pr(x)=0pg(x)=0表示x完全服从生成分布,是假图像,所以最优判别器D给出的概率为0;若 p r ( x ) ≠ 0 、 p g ( x ) = 0 p_r(x) \ne 0、p_g(x)= 0 pr(x)=0pg(x)=0表示x完全服从真实分布,是真图像,所以最优判别器D给出的概率为1;若 p r ( x ) = p g ( x ) p_r(x)=p_g(x) pr(x)=pg(x),表示x服从生成分布和服从真实分布的数据一样多,因此最优判别器D给出的概率为50%。

​​  如今我们已经找到了最大化的D,现要将其固定,寻找最小化的G,即求下式:

min ⁡ G V = E x ∼ p r [ log ⁡ D ( x ) ] + E x ∼ p g [ log ⁡ ( 1 − D ( x ) ) ] ——式 2 \mathop {\min }\limits_G V = {E_{x \sim {{\rm{p}}_r}}}[\log D(x)] + {E_{x \sim {{\rm{p}}_g}}}[\log (1 - D(x))] \qquad \qquad —— 式2 GminV=Expr[logD(x)]+Expg[log(1D(x))]——2

​  将最大化的D即式1的 D ∗ ( x ) D^*(x) D(x)代入式2得:

min ⁡ G V = E x ∼ p r [ log ⁡ ( p r ( x ) p r ( x ) + p g ( x ) ) ] + E x ∼ p g [ log ⁡ ( p g ( x ) p r ( x ) + p g ( x ) ) ] \mathop {\min }\limits_G V = {E_{x \sim {{\rm{p}}_r}}}[\log (\frac{{{{\rm{p}}_r}(x)}}{{{{\rm{p}}_r}(x) + {{\rm{p}}_g}(x)}})] + {E_{x \sim {{\rm{p}}_g}}}[\log (\frac{{{{\rm{p}}_g}(x)}}{{{{\rm{p}}_r}(x) + {{\rm{p}}_g}(x)}})] GminV=Expr[log(pr(x)+pg(x)pr(x))]+Expg[log(pr(x)+pg(x)pg(x))]

​​  再化简一步得式3:

min ⁡ G V = E x ∼ p r [ log ⁡ ( p r ( x ) 1 2 p r ( x ) + p g ( x ) ) ] + E x ∼ p g [ log ⁡ ( p g ( x ) 1 2 p r ( x ) + p g ( x ) ) ] − 2 log ⁡ 2 ——式 3 \mathop {\min }\limits_G V = {E_{x \sim {{\rm{p}}_r}}}[\log (\frac{{{{\rm{p}}_r}(x)}}{{\frac{1}{2}{{\rm{p}}_r}(x) + {{\rm{p}}_g}(x)}})] + {E_{x \sim {{\rm{p}}_g}}}[\log (\frac{{{{\rm{p}}_g}(x)}}{{\frac{1}{2}{{\rm{p}}_r}(x) + {{\rm{p}}_g}(x)}})] - 2\log 2 \qquad \qquad —— 式3 GminV=Expr[log(21pr(x)+pg(x)pr(x))]+Expg[log(21pr(x)+pg(x)pg(x))]2log2——3

​  化简成式3的形式是为了和KL散度和JS散度相联系。我们先来介绍KL散度和JS散度的定义:【大家不要觉得散度这个词好像很高大上一样,大家其实就可以理解为距离】

  • KL散度

    K L ( P 1 ∣ ∣ P 2 ) = E x ∼ P 1 l o g P 1 P 2 KL(P_1||P_2)=E_{x \sim P_1}log \frac {P_1}{P_2} KL(P1∣∣P2)=ExP1logP2P1

    上式是用期望来表示KL散度的,我们也可以用积分或求和形式来表达,如下:

    K L ( P 1 ∣ ∣ P 2 ) = ∫ x P 1 log ⁡ P 1 P 2 d x KL(P_1||P_2)=\int\limits_x {{P_1}\log \frac{{{P_1}}}{{{P_2}}}} dx KL(P1∣∣P2)=xP1logP2P1dx K L ( P 1 ∣ ∣ P 2 ) = ∑ p 1 log ⁡ P 1 P 2 KL(P_1||P_2)= \sum {{p_1}\log } \frac{{{P_1}}}{{{P_2}}} KL(P1∣∣P2)=p1logP2P1

    我来大致的解释一下KL散度的含义,用 K L ( P 1 ∣ ∣ P 2 ) = ∑ P 1 ( − l o g P 2 − ( − l o g P 1 ) ) KL(P_1||P_2)=\sum {{P_1}( - log{P_2} - ( - log{P_1}))} KL(P1∣∣P2)=P1(logP2(logP1))解释,如下图:

    对抗生成网络GAN系列——WGAN原理及实战演练_第2张图片

    ​ 关于KL散度我就介绍这么多,它还有一些性质,像非对称性等等,我这里就不过多介绍了,感兴趣的可以自己去阅读阅读相关资料。

  • JS散度
    J S ( P 1 ∣ ∣ P 2 ) = 1 2 K L ( P 1 ∣ ∣ P 1 + P 2 2 ) + 1 2 K L ( P 2 ∣ ∣ P 1 + P 2 2 ) JS({P_1}||{P_2}) = \frac{1}{2}KL({P_1}||\frac{{{P_1} + {P_2}}}{2}) + \frac{1}{2}KL({P_2}||\frac{{{P_1} + {P_2}}}{2}) JS(P1∣∣P2)=21KL(P1∣∣2P1+P2)+21KL(P2∣∣2P1+P2)

    可以看出,JS散度是根据KL散度来定义的,它具有对称性。


​ 有了JS散度的定义,我们就可以将式3变换成如下形式:

min ⁡ G V = 2 J S ( P r ∣ ∣ P g ) − 2 log ⁡ 2 ——式 4 \mathop {\min }\limits_G V = 2JS({P_r}||{P_g}) - 2\log 2 \qquad \qquad ——式4 GminV=2JS(Pr∣∣Pg)2log2——4

​​  这时候肯定很多人会想了,为什么我们需要把式3化成式4这个JS散度的形式,这是因为我们的GAN训练困难很大原因就是这个JS散度捣的鬼。

​​  为什么说训练困难是JS散度捣的鬼呢?通过上文我们知道,我们将判别器训练的越好,即判别器越接近最优判别器,此时生成器的loss会等价为式4中的JS散度的形式,当我们训练生成器来最小化生成器损失时也就是最小化式4中的JS散度。这这过程看似非常的合理,只要我们不断的训练,真实分布和生成分布越来越接近,JS散度应该越来越小,直到两个分布完全一致,此时JS散度为0。

​ 但是理想很丰满,现实很骨感,JS散度并不会随着真实分布和生成分布越来越近而使其值越来越小,而是很大概率保持 l o g 2 log2 log2不变。

对抗生成网络GAN系列——WGAN原理及实战演练_第3张图片

​​  如上图所示,随着 p g p_g pg p d a t a p_{data} pdata(也就是上文中所述的 p r p_r pr)越来越近,它们的JS散度却不发生变化,一直等于log2,只有当 p g p_g pg p d a t a p_{data} pdata有重叠时才会发生变化,当完全重叠时,JS散度的值为0。也就是说,不论 p g p_g pg p d a t a p_{data} pdata相距多远,只要它们没有重叠,那么它们的JS散度就一直是常数log2,也就是在训练过程中loss一直不发生变化,这也意味着生成器的梯度一直为0,即生成器产生了梯度消失的现象。

​​  读了上文,我觉得你至少应该有两个疑问,其一是为什么两个分布不重合,它们JS散度一直为log2,其二为既然只要 p g p_g pg p d a t a p_{data} pdata有重叠,那么它们的JS散度就不是log2,这样我们是不是就可以更新梯度了呢?那么它们重叠的概率大吗?我们一个一个的来解决大家的疑惑。

  1. 什么两个分布不重合,它们JS散度一直为log2

    这个其实是由JS散度的性质来决定的,这里带大家推一推。首先JS散度的表达式如下:

    J S ( P 1 ∣ ∣ P 2 ) = 1 2 K L ( P 1 ∣ ∣ P 1 + P 2 2 ) + 1 2 K L ( P 2 ∣ ∣ P 1 + P 2 2 ) JS({P_1}||{P_2}) = \frac{1}{2}KL({P_1}||\frac{{{P_1} + {P_2}}}{2}) + \frac{1}{2}KL({P_2}||\frac{{{P_1} + {P_2}}}{2}) JS(P1∣∣P2)=21KL(P1∣∣2P1+P2)+21KL(P2∣∣2P1+P2)

    上文我也给出了KL散度的公式,将KL散度公式代入上式,得:

    J S ( P 1 ∣ ∣ P 2 ) = 1 2 ∑ p 1 ( x ) log ⁡ 2 p 1 ( x ) p 1 ( x ) + p 2 ( x ) + 1 2 ∑ p 2 ( x ) log ⁡ 2 p 2 ( x ) p 1 ( x ) + p 2 ( x ) JS({P_1}||{P_2}) = \frac{1}{2}\sum {{p_1}(x)\log } \frac{{2{p_1}(x)}}{{{p_1}(x) + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{2{p_2}(x)}}{{{p_1}(x) + {p_2}(x)}} JS(P1∣∣P2)=21p1(x)logp1(x)+p2(x)2p1(x)+21p2(x)logp1(x)+p2(x)2p2(x)

    再化简上式,将log2提出来,注意这里可以提是因为 ∑ p 1 ( x ) = ∑ p 2 ( x ) = 1 \sum {{p_1}(x) = } \sum {{p_2}(x) = } 1 p1(x)=p2(x)=1,提取后公式如下:

    J S ( P 1 ∣ ∣ P 2 ) = 1 2 ∑ p 1 ( x ) log ⁡ p 1 ( x ) p 1 ( x ) + p 2 ( x ) + 1 2 ∑ p 2 ( x ) log ⁡ p 2 ( x ) p 1 ( x ) + p 2 ( x ) + log ⁡ 2 JS({P_1}||{P_2}) = \frac{1}{2}\sum {{p_1}(x)\log } \frac{{{p_1}(x)}}{{{p_1}(x) + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{{p_2}(x)}}{{{p_1}(x) + {p_2}(x)}} + \log 2 JS(P1∣∣P2)=21p1(x)logp1(x)+p2(x)p1(x)+21p2(x)logp1(x)+p2(x)p2(x)+log2

​   现在大家注意到上式已经有一个log2了,为证明我们的结论,我们只需证明两个分布不重合时上式左侧两项为0即可。我们以下图为例进行讲解:

对抗生成网络GAN系列——WGAN原理及实战演练_第4张图片

​  对于上图中的两个分布 P 1 和 P 2 P_1和P_2 P1P2,可以看出它们没有重叠部分。我们设 p 1 ( x ) p_1(x) p1(x)为x落在 P 1 P_1 P1分布上的概率,设 p 2 ( x ) p_2(x) p2(x)为x落在 P 2 P_2 P2分布上的概率,我们来看 J S ( P 1 ∣ ∣ P 2 ) JS(P_1||P_2) JS(P1∣∣P2)的左边两项,即 1 2 ∑ p 1 ( x ) log ⁡ p 1 ( x ) p 1 ( x ) + p 2 ( x ) + 1 2 ∑ p 2 ( x ) log ⁡ p 2 ( x ) p 1 ( x ) + p 2 ( x ) \frac{1}{2}\sum {{p_1}(x)\log } \frac{{{p_1}(x)}}{{{p_1}(x) + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{{p_2}(x)}}{{{p_1}(x) + {p_2}(x)}} 21p1(x)logp1(x)+p2(x)p1(x)+21p2(x)logp1(x)+p2(x)p2(x)

​ 当 x < 5 x<5 x<5时, p 2 ( x ) p_2(x) p2(x)近似为0,则:

1 2 ∑ p 1 ( x ) log ⁡ p 1 ( x ) p 1 ( x ) + p 2 ( x ) + 1 2 ∑ p 2 ( x ) log ⁡ p 2 ( x ) p 1 ( x ) + p 2 ( x ) = 1 2 ∑ p 1 ( x ) log ⁡ p 1 ( x ) p 1 ( x ) + 0 + 1 2 ∑ 0 log ⁡ 0 p 1 ( x ) + 0 = 0 \frac{1}{2}\sum {{p_1}(x)\log } \frac{{{p_1}(x)}}{{{p_1}(x) + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{{p_2}(x)}}{{{p_1}(x) + {p_2}(x)}} = \frac{1}{2}\sum {{p_1}(x)\log } \frac{{{p_1}(x)}}{{{p_1}(x) + 0}} + \frac{1}{2}\sum {0\log } \frac{0}{{{p_1}(x) + 0}} = 0 21p1(x)logp1(x)+p2(x)p1(x)+21p2(x)logp1(x)+p2(x)p2(x)=21p1(x)logp1(x)+0p1(x)+210logp1(x)+00=0

​ 当 x ≥ 5 x \ge 5 x5时, p 1 ( x ) p_1(x) p1(x)近似为0,则:

1 2 ∑ p 1 ( x ) log ⁡ p 1 ( x ) p 1 ( x ) + p 2 ( x ) + 1 2 ∑ p 2 ( x ) log ⁡ p 2 ( x ) p 1 ( x ) + p 2 ( x ) = 1 2 ∑ 0 log ⁡ 0 0 + p 2 ( x ) + 1 2 ∑ p 2 ( x ) log ⁡ p 2 ( x ) 0 + p 2 ( x ) = 0 \frac{1}{2}\sum {{p_1}(x)\log } \frac{{{p_1}(x)}}{{{p_1}(x) + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{{p_2}(x)}}{{{p_1}(x) + {p_2}(x)}} = \frac{1}{2}\sum {0\log } \frac{0}{{0 + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{{p_2}(x)}}{{0 + {p_2}(x)}} = 0 21p1(x)logp1(x)+p2(x)p1(x)+21p2(x)logp1(x)+p2(x)p2(x)=210log0+p2(x)0+21p2(x)log0+p2(x)p2(x)=0

​ 因此对于任意的x都会有 1 2 ∑ p 1 ( x ) log ⁡ p 1 ( x ) p 1 ( x ) + p 2 ( x ) + 1 2 ∑ p 2 ( x ) log ⁡ p 2 ( x ) p 1 ( x ) + p 2 ( x ) = 0 \frac{1}{2}\sum {{p_1}(x)\log } \frac{{{p_1}(x)}}{{{p_1}(x) + {p_2}(x)}} + \frac{1}{2}\sum {{p_2}(x)\log } \frac{{{p_2}(x)}}{{{p_1}(x) + {p_2}(x)}}=0 21p1(x)logp1(x)+p2(x)p1(x)+21p2(x)logp1(x)+p2(x)p2(x)=0 ,进而会有 J S ( P 1 ∣ ∣ P 2 ) = l o g 2 JS(P_1||P_2)=log2 JS(P1∣∣P2)=log2


这部分主要参考了这篇文章,我认为其对JS散度为什么等于log2解释的非常清楚


  1. 任意的两个分布 p g p_g pg p d a t a p_{data} pdata重叠的概率有多大

    ​ 首先,我直接给大家明确结论,重叠概率很小,特别小。我说明一下这里所说的重叠概率指的是重叠部分不可忽略的概率,也就是说 p g p_g pg p d a t a p_{data} pdata可能会重叠,但是往往重叠部分都是可以忽略的。这部分大家可以参考这篇文章,写的非常好,对于这部分的解释也比较全面。

    ​ 我从两个角度帮助大家理解,其一, p g p_g pg p d a t a p_{data} pdata假设为一条曲线, p g p_g pg开始往往都是随机生成的,所以很难和 p d a t a p_{data} pdata重合,往往都是有一些交点,而这些交点是可以忽略的。实际上, p g p_g pg p d a t a p_{data} pdata的维度往往很高,它们完全重叠的概率非常之低。

    对抗生成网络GAN系列——WGAN原理及实战演练_第5张图片

    ​ 其二,我们往往是从 p g p_g pg p d a t a p_{data} pdata中采样一些点出来训练,即使 p g p_g pg p d a t a p_{data} pdata有重叠,但是经过采样后,它们很大可能就没有重叠了。

    对抗生成网络GAN系列——WGAN原理及实战演练_第6张图片


​  解释了这么多,大家应该知道为什么GAN会训练困难了吧。最后总结一下,这是因为当判别器训练最优时,生成器的损失函数等价于JS散度,其梯度往往一直为0,得不到更新,所以很难train。


​​  我不知道大家是否会有这样的疑问,既然说生成器梯度得不大更新,为什么我们在训练GAN时还是能得到一些较好的图片呢,不应该什么也得不到呢。我是这样认为的,我们上述的推导都是建立在最优判别器的基础上的,但是在我们实操过程中往往一开始判别器性能是不理想的,所以生成器还是有梯度更新的。还有一点,就是我之前的文章代码中生成器的损失其实使用的 − l o g D ( x ) -logD(x) logD(x),这种损失函数也会存在一些问题,往往出现梯度不稳定,模式崩溃问题,但是梯度消失问题出现比较少,所以还是能训练的。关于使用这种损失函数带来的缺陷我这里也不叙述了,大家感兴趣的可以看看这篇文章。

 

EM distance(推土机距离)的引入

​​  我们上文详细分析了普通GAN存在的缺陷,主要是由于和JS散度相关的损失函数导致的。大佬们就在思考能否有一种损失能够代替JS散度呢?于是,WGAN应运而生,其提出了一种新的度量两个分布距离的标准——Wasserstein Metric,也叫推土机距离(Earth-Mover distance)。下面就让我们来看看什么是推土机距离吧!!!


【这部分参考此篇文章】

​  下图左侧有6个盒子,我们期望将它们都移动到右侧的虚线框内。比如将盒子1从位置1移动到位置7,移动了6步,我们就将此步骤的代价设为6;同理,将盒子5从位置3移动到位置10,移动了7步,那么此步骤的代价为7,依此类推。

对抗生成网络GAN系列——WGAN原理及实战演练_第7张图片

图1 EM distance示例

​ 很显然,我们有很多种不同的方案,下图给出了两种不同的方案:

对抗生成网络GAN系列——WGAN原理及实战演练_第8张图片

​​  上图中右侧的表格表示盒子是如何移动的。比如在 γ 1 {\gamma _1} γ1中,其第一行第一列的值为1,表示有1个框从位置1移动到了位置7;第一行的第四列的值为2,表示有2个框从位置1移动到了位置10。上图中给出的两种方案的代价总和都为42,但是对于一个问题并不是所有的移动方案的代价都是固定的,比如下图:

对抗生成网络GAN系列——WGAN原理及实战演练_第9张图片

​​  在这个例子中,上图展示的两种方案的代价是不同的,一个为2,一个为6。,而推土机距离就是穷举所有的移动方案,最小的移动代价对应的就是推土机距离。对应本列来说,推土机距离等于2。

​​  相信通过上文的表述,你已经对推土机距离有了一定了了解。现给出推土机距离是数学定义,如下:

W ( P r , P g ) = inf ⁡ γ ∈ ∏ ( P r , P g ) E ( x , y ) [ ∣ ∣ x − y ∣ ∣ ] {\rm{W}}({P_r},{P_g}) = \mathop {\inf }\limits_{\gamma \in \prod ({P_r},{P_g})} {E_{(x,y)}}[||x - y||] W(Pr,Pg)=γ(Pr,Pg)infE(x,y)[∣∣xy∣∣]

​​  看到这个公式大家是不是都懵逼了呢,这里做相关的解释。 ∏ ( P r , P g ) {\prod ({P_r},{P_g})} (Pr,Pg)表示边缘分布 P r 和 P g P_r和P_g PrPg所有组合起来的联合分布 γ ( x , y ) \gamma(x,y) γ(x,y)的集合。我们还是用图1中的例子来解释, ∏ ( P r , P g ) {\prod ({P_r},{P_g})} (Pr,Pg)就表示所有的运输方案 γ \gamma γ,下图仅列举了两种方案:

对抗生成网络GAN系列——WGAN原理及实战演练_第10张图片

​   E ( x , y ) [ ∣ ∣ x − y ∣ ∣ ] {E_{(x,y)}}[||x - y||] E(x,y)[∣∣xy∣∣] 可以看成是对于一个方案 γ \gamma γ 移动的代价, inf ⁡ γ ∈ ∏ ( P r , P g ) E ( x , y ) [ ∣ ∣ x − y ∣ ∣ ] \mathop {\inf }\limits_{\gamma \in \prod ({P_r},{P_g})} {E_{(x,y)}}[||x - y||] γ(Pr,Pg)infE(x,y)[∣∣xy∣∣] 就表示在所有的方案中的最小代价,这个最小大家就是 W ( P r , P g ) W(P_r,P_g) W(Pr,Pg),即推土机距离。


​  现在我们已经知道了推土机距离是什么,但是我们还没解释清楚我们为什么要用推土机距离,即推土机距离为什么可以代替JS散度成为更优的损失函数?我们来看这样的一个例子,如下图所示:

对抗生成网络GAN系列——WGAN原理及实战演练_第11张图片

​​  上图有两个分布 P 1 和 P 2 P_1和P_2 P1P2 P 1 P_1 P1在线段AB上均匀分布, P 2 P_2 P2在CD上均匀分布,参数 θ \theta θ可以控制两个分布的距离。我们由前文对JS散度的解释,可以得到:

J S ( P 1 ∣ ∣ P 2 ) = { log ⁡ 2 θ ≠ 0 0    θ = 0 JS({P_1}||{P_2}) = \left\{ \begin{array}{l} \log 2{\rm{ }} \quad \theta \ne {\rm{0}}\\ 0{\rm{ }} \quad \quad \ \ \theta {\rm{ = 0}} \end{array} \right. JS(P1∣∣P2)={log2θ=00  θ=0

​​  而对于推土机距离来说,可以得到:

W ( P 1 , P 2 ) = ∣ θ ∣ W(P_1,P_2)=|\theta| W(P1,P2)=θ

​  这样对比可以看出,推土机距离是平滑的,这样在训练时,即使两个分布不重叠推土机距离仍然可以提高梯度,这一点是JS散度无法实现的。
 


WGAN的实现

​​  现在我们已经有了推土机距离的定义,同时也解释了推土机距离相较于JS散度的优势。但是我们想要直接使用推土机距离来定义生成器的损失似乎是困难的,因为这个式子 W ( P r , P g ) = inf ⁡ γ ∈ ∏ ( P r , P g ) E ( x , y ) [ ∣ ∣ x − y ∣ ∣ ] {\rm{W}}({P_r},{P_g}) = \mathop {\inf }\limits_{\gamma \in \prod ({P_r},{P_g})} {E_{(x,y)}}[||x - y||] W(Pr,Pg)=γ(Pr,Pg)infE(x,y)[∣∣xy∣∣] 是难以直接求解的,但是呢,作者大大用了一个定理将上式变化成了如下形式:

W ( P r , P g ) = 1 k sup ⁡ ∣ ∣ f ∣ ∣ L ≤ K E x ∼ P r [ f ( x ) ] − E x ∼ P g [ f ( x ) ] {\rm{W}}({P_r},{P_g}) = \frac{1}{k}\mathop {\sup }\limits_{||f|{|_L} \le K} {E_{x \sim {{\rm{P}}_r}}}[f(x)] - {E_{x \sim {{\rm{P}}_g}}}[f(x)] W(Pr,Pg)=k1∣∣fLKsupExPr[f(x)]ExPg[f(x)]

​​  注意上式中的f有一个限制,即 ∣ ∣ f ∣ ∣ L ≤ K {||f|{|_L} \le K} ∣∣fLK,我们称为lipschitz连续条件。这个限制其实就是限制了函数f的导数。它的定义如下:

∣ f ( x 1 ) − f ( x 2 ) ∣ ≤ K ∣ x 1 − x 2 ∣ |f(x_1)-f(x_2)| \le K|x_1-x2| f(x1)f(x2)Kx1x2∣

​ 即:

∣ f ( x 1 ) − f ( x 2 ) ∣ ∣ x 1 − x 2 ∣ ≤ K \frac{|f(x_1)-f(x_2)|}{|x_1-x2|} \le K x1x2∣f(x1)f(x2)K

​​  很显然,lipschitz连续就限制了f的斜率的绝对值小于等于K,这个K称为Libschitz常数。我们来举个例子帮助大家理解,如下图所示:

对抗生成网络GAN系列——WGAN原理及实战演练_第12张图片

​  上图中log(x)的斜率无界,故log(x)不满足lipschitz连续条件;而sin(x)斜率的绝对值都小于1,故sin(x)满足lipschitz连续条件。


​  这样,我们只需要找到一个lipschitz函数,就可以计算推土机距离了。至于怎么找这个lipschitz函数呢,就是我们搞深度学习的那一套啦,只需要建立一个深度学习网络来进行学习就好啦。实际上,我们新建立的判别器网络和之前的的基本是一致的,只是最后没有使用sigmoid函数,而是直接输出一个分数,这个分数可以反应输入图像的真实程度。

​​  呼呼呼~~,WGAN的原理就为大家介绍到这里了,大家掌握了多少呢?其实我认为只看一篇文章是很难把WGAN的所有细节都理解的,大家可以看看本文的参考文献,结合多篇文章看看能不能帮助大家解决一些困惑。

 

WGAN代码实战

​  WGAN的代码实战我不打算贴出一堆代码了,只说明一下WGAN相较于普通GAN做了哪些改变。首先我们给出论文中训练WGAN的流程图,如下:

对抗生成网络GAN系列——WGAN原理及实战演练_第13张图片

​ 其实WGAN相较于原始GAN只做了4点改变,分别如下:

  1. 判别器最后不使用sigmoid函数
  2. 生成器和判别器的loss不取log
  3. 每次更新判别器参数后将判别器的权重截断
  4. 不适应基于动量的优化算法,推荐使用RMSProp

​  现在代码中分别对上述的4点做相关解释:【很简单,所以大家想要将原始GAN修改为WGAN就按照下列的几点来修改就好了】

  1. 判别器最后不使用sigmoid函数

    这个我们一般只需要删除判别器网络中的最后一个sigmoid层就可以了,非常简单。但是我还想提醒大家一下,有时候你在看别人的原始GAN时,他的判别器网络中并没有sigmoid函数,而是在定义损失函数时使用了BCEWithLogitsLoss函数,这个函数会先对数据做sigmoid,相关代码如下:

    # 定义损失函数
    criterion = nn.BCEWithLogitsLoss(reduction='mean')
    

    如果这时你想删除sigmoid函数,只需将BCEWithLogitsLoss函数修改成BCELoss即可。关于BCEWithLogitsLossBCELoss的区别我这篇文章有相关讲解,感兴趣的可以看看。

  2. 生成器和判别器的loss不取log

    我们先来看原始GAN判别器的loss是怎么定义的,如下:

    d_loss_real = criterion(d_out_real.view(-1), label_real)
    d_loss_fake = criterion(d_out_fake.view(-1), label_fake)
    d_loss = d_loss_real + d_loss_fake
    

    原始GAN使用了criterion函数,这就是我们上文定义的BCEWithLogitsLossBCELoss,其内部是一个log函数,不清楚的同样可以参考我的这篇文章。

    而WGAN判别器的损失直接是两个期望(均值)的差,就对应理论中 E x ∼ P r [ f ( x ) ] − E x ∼ P g [ f ( x ) ] {E_{x \sim {{\rm{P}}_r}}}[f(x)] - {E_{x \sim {{\rm{P}}_g}}}[f(x)] ExPr[f(x)]ExPg[f(x)] 这部分啦【但是需要加个负号,因为我们都是使用的梯度下降方法,要最小化这个损失,上式是最大化的公式】,我们来看看代码中的实现:

    d_loss = -(torch.mean(d_out_real.view(-1))-torch.mean(d_out_fake.view(-1)))
    

    看完判别器的损失,我们再来看生成器的损失:

    g_loss = criterion(d_out_fake.view(-1), label_real)  #原始GAN损失
    ---------------------------------------------------
    g_loss = -torch.mean(d_out_fake.view(-1))            #WGAN损失
    
  3. 每次更新判别器参数后将判别器的权重截断

    首先来说说没什么要进行权重截断,这是因为lipschitz连续条件不好确定,作者为了方便直接简单粗暴的限制了权重参数在[-c,c]这个范围了,这样就一定会存在一个常数K使得函数f满足lipschitz连续条件,具体的实现代码如下:

    # clip D weights between -0.01, 0.01  权重剪裁
      for p in D.parameters():
      	  p.data.clamp_(-0.01, 0.01)
    

    注意这步是在判别器每次反向传播更新梯度结束后进行的。还需要提醒大家一点,在训练WGAN时往往会多训练几次判别器然后再训练一次生成器,论文中是训练5次判别器后训练一次生成器,关于这一点在上文WGAN的流程图中也有所体现。为什么要这样做呢?我思考这样做可能是想把判别器训练的更好后再来训练生成器,因为在前面的理论部分我们的推导都是建立在最优判别器的前提下的。

  4. 不适应基于动量的优化算法,推荐使用RMSProp

    这部分就属于玄学部分了,作者做实验发现像Adam这类基于动量的优化算法效果不好,然后使用了RMSProp优化算法。我决定这部分大家也不要纠结,直接用就好。同样我们来看看代码是怎么写的,如下:

    optimizerG = torch.optim.RMSprop(G.parameters(), lr=5e-5)
    optimizerD = torch.optim.RMSprop(D.parameters(), lr=5e-5)
    

​  这样,WGAN的代码实战我就为大家介绍到这里,不管你前文的WGAN原理听明白了否,但是WGAN代码相信你是一定会修改的,改动的非常之少,大家快去试试吧。

 

小结

​  这部分的理论确实是有一定难度的,我也看了非常非常多的视频和博客,写了很多笔记。我觉得大家不一定要弄懂每一个细节,只要对里面的一些关键公式,关键思想有清晰的把握即可;而这部分的实验较原始GAN需要修改的仅有四点,非常简单,大家都可以试试。最后希望大家都能够有所收获,就像WGAN一样稳定的进步,一起加油吧!!!

 

参考链接

令人拍案叫绝的Wasserstein GAN

GAN:两者分布不重合JS散度为log2的数学证明

GAN — Wasserstein GAN & WGAN-GP

GAN — Spectral Normalization

李宏毅【機器學習2021】生成式對抗網路

 
 
如若文章对你有所帮助,那就

在这里插入图片描述

你可能感兴趣的:(视觉,生成对抗网络,深度学习,WGAN)