生成对抗网络(GAN)可以用于生成、风格迁移、数据增强、超分辨率等任务。今天介绍一篇 ICCV 2019 的 paper: “AutoGAN: Neural Architecture Search for Generative Adversarial Networks”。这篇文章第一次把 NAS 和 GAN 结合,想要用神经网络结构搜索(NAS)的方法搜一个GAN 的网络结构。
作者是来自 Texas A&M University 和 MIT-IBM Watson AI Lab 的 Xinyu Gong, Shiyu Chang, Yifan Jiang, Zhangyang Wang,代码开源在这里。
作者指出,将 NAS 和 GAN 结合,会遇到下面的问题:
问题1: GAN 由生成器 G 和判别器 D 组成。那么应该先固定住其中一个的结构,搜另一个吗?还是说这两个应该一起同时搜呢?
如果先固定一个搜另一个,可能会导致两者之间的不平衡;而如果两个一起搜,GAN 的训练不稳定,可能会遇到 Mode collapse 之类的问题。
作者的解决方法是: 只搜 G ,但是 D 并不是一成不变的。当 G 变的越来越深的时候, D 也会堆叠一些预先确定的 block 来变深。
问题2: 没有一个好的评价指标来给搜索过程提供反馈。
GAN 常用的指标是 Inception score (IS) 和 FID score,由于 FID score 计算比较慢,作者就选择用 Inception score 作为 RL 的 award。
下面介绍 AutoGAN 的搜索空间。作者采用了 Multi-Level Architecture Search (MLAS) 的策略,就是生成器由不同的 cell 组成,在搜索的时候一个 cell 对应一个 RNN controller 。搜索空间如下图所示:
第 s 个 cell 的搜索空间就可以由一个形状为 (s+5) 的元组 ( s k i p 1 , . . . , s k i p s , C , N , U , S C ) (skip_1, ..., skip_s, C, N, U, SC) (skip1,...,skips,C,N,U,SC) 来表示。其中 s k i p i skip_i skipi 表示当前的 cell 和前面坐标为 i − 1 i-1 i−1 的 cell 之间的 skip connection。C 表示卷积 block 的类型,有激活函数放在 conv 前/后两种;N 表示 normalization 的类型,有 BN / IN / 不加 normalization 三种;U 表示上采样操作的类型,有双线性差值、最近邻差值、stride 为 2 的反卷积三种;SC 表示 cell 内部要不要加 shortcut 连接。
这个搜索空间其实还算是比较简单的,一个 cell 里面就有一次上采样,再过两个 conv,中间加一些 skip-connection。
下面介绍 AutoGAN 的搜索策略:作者用的是 RL + RNN controller 的方法,训练的时候一共要更新两组参数:一个是 RNN controller 的参数 θ \theta θ,一个是 GAN 的生成器和判别器的参数 ω \omega ω。
由于训练 GAN 的时候是不稳定的,如果模型已经 collapsed 的话,就没有必要继续训练了。作者根据经验归纳出一个结论:如果训练 loss 的方差变得比较小,那么很可能就是发生了 mode collapse。作者提出了一种 dynamic-resetting 的策略:用一个滑动窗口来存储生成器和判别器的 training loss,如果方差小于一个阈值,当前 GAN 的训练就会终止,生成器和判别器的参数会重新初始化。不过 RNN controller 的参数还是会保留的。
固定住 θ \theta θ,只更新 GAN 的参数 ω \omega ω。从 RNN controller 得到出一堆候选的结构,用 hinge adversarial loss 来训练 GAN。同时会根据 training loss 的计算来提前终止已经发生 mode collapse 的模型。
固定住 GAN 的参数 ω \omega ω,只更新 RNN 的参数 θ \theta θ。作者用的是一个 LSTM,首先采样出 K 个 child models,然后计算对应的 inception score 作为 reward,随后用强化学习的方式更新 LSTM 的权重。
作者用的数据集是 CIFAR-10 (分辨率 32 x 32) 和 STL-10 (分辨率 48 x 48)。作者在 CIFAR-10 上搜到的生成器结构如下图所示:
可以看出,这个搜出来的结构倾向于把激活函数加在卷积的前面、使用双线性插值而不是反卷积、不使用 normalization 运算、加很多 skipping connections 运算。
和其他方法的对比如下表所示:
作者指出,本文方法的搜索空间和 SN-GAN 比较像,因此和 SN-GAN 对比能看出来确实比人工设计的方法好。而有些方法用到了本文的搜索空间中没有的运算,例如 WGAN-GP 用到了 Wasserstein loss ,和他们比其实不太公平。而 SN-GAN 是基于 ResNet Block 的,并且在判别器部分移除了 BN 。这么看的话,本文搜出来的结构主要是多了 cell 之间的 skip connection。
作者为了证明搜出来的结构没有过拟合 CIFAR 这个小数据集,还提供了在 STL-10 上面的结果。作者保留在 CIFAR 上得到的结构不变,在 STL-10 上面重新训练,得到的结果如下表所示:
这个结果在 FIN 指标上比 Improving MMD GAN 要好。
AutoGAN 在 CIFAR 数据集上生成得到的图片效果如下图所示:
作者指出,AutoGAN 还有很大的提升空间,都是后面可以继续做的点:
综合来看,这个 NAS + GAN 的坑还有很多可以填的地方。本文的贡献主要在于第一次把 NAS + GAN 这种东西搞 work。
如果有什么理解不到位的地方,欢迎在评论区指正。