图像生成之生成对抗网络(Generative adversarial networks,GAN)

简要介绍

原文:Generative Adversarial Networks (arxiv.org)

​ 以前的一些模型是通过显示训练数据的分布,比如上面的VAE。也就是可以得到最终训练好的均值与方差,而GAN是隐式训练。

​ GAN受博弈论中的零和博弈启发,将生成问题视作判别器和生成器这两个网络的对抗和博弈:生成器从给定噪声中(一般是指均匀分布或者正态分布)产生合成数据,判别器分辨生成器的的输出和真实数据。前者试图产生更接近真实的数据,后者试图更完美地分辨真实数据与生成数据。由此,两个网络在对抗中进步,在进步后继续对抗,由生成网络得到的数据也就越来越完美,逼近真实数据,从而可以生成想要得到的数据(图片、序列、视频等)。

训练原理与公式推导

图像生成之生成对抗网络(Generative adversarial networks,GAN)_第1张图片

图1 GAN模型体系结构图示

​ 设真实训练数据的分布服从 x ∼ p data x \sim p_\text{data} xpdata,先验噪声输入变量分布服从 z ∼ q z ( z ) z \sim q_z(z) zqz(z), 生成模型为 G ( z ; θ g ) G(z;\theta_g) G(z;θg),其中 θ g \theta_g θg是生成模型中的网络参数。设判别模型为 D ( z ; θ d ) D(z;\theta_d) D(z;θd),其中 θ d \theta_d θd是判别模型中的网络参数,生成模型通过学习数据 x x x得到生成模型分布 p g p_g pg。根据理论分析结果,如果给予生成模型足够的承载力和足够的训练时间,生成模型 G G G完全有可能拟合出类似于真实样本分布 p data p_\text{data} pdata的从假设随机分布 p z p_z pz得到的合成样本分布 p g p_g pg。 同时,当模型训练达到全局最优平衡点 p g = p data p_g=p_\text{data} pg=pdata的时候,说明生成对抗网络达到了 博弈的最佳状态。数学优化模型如下所示:
min ⁡   G max ⁡   D V ( D , G ) = E x ∼ p data ( x ) [ log ⁡ ( D ( x ) ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( x ) ) ] \min_{\ G} \max_{\ D} V(D, G) = {E}_{x\sim p_{\text{data}}(x)}[\log(D(x))] +{E}_{z\sim p_z(z)}[\log(1 - D(G(x))]  Gmin DmaxV(D,G)=Expdata(x)[log(D(x))]+Ezpz(z)[log(1D(G(x))]

  • 生成器的目标函数为:

min ⁡   G E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( x ) ) ] \min_{\ G}E_{z\sim p_z(z)}[\log(1 - D(G(x))]  GminEzpz(z)[log(1D(G(x))]

  • 判别器的目标函数为:

max ⁡   D E x ∼ p data ( x ) [ log ⁡ ( D ( x ) ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( x ) ) ] \max_{\ D}E_{x\sim p_{\text{data}}(x)}[\log(D(x))] + E_{z\sim p_z(z)}[\log(1 - D(G(x))]  DmaxExpdata(x)[log(D(x))]+Ezpz(z)[log(1D(G(x))]

​ 由于判别网络和生成网络得到的极值点可能并不是在同一个位置,那么便存在一个全局的平衡点,也就是纳什(Nash)均衡点。

​ 由上式可知,给定任意生成器 G G G,判别器 D D D 的训练标准都是最大化 V ( D , G ) V(D, G) V(D,G),那么, 第一条公式还可以表示为:
V ( D , G ) = ∫ x P data  ( x ) log ⁡ ( D ( x ) ) d x + ∫ x p z ( z ) log ⁡ ( 1 − D ( g ( z ) ) ) d z = ∫ x [ P data  ( x ) log ⁡ ( D ( x ) ) + P g ( x ) log ⁡ ( 1 − D ( x ) ) ] d x \begin{aligned} V(D, G) & =\int_x P_{\text {data }}(x) \log (D(x)) d x+\int_x p_z(z) \log (1-D(g(z))) d z \\ & =\int_x\left[P_{\text {data }}(x) \log (D(x))+P_g(x) \log (1-D(x))\right] d x \end{aligned} V(D,G)=xPdata (x)log(D(x))dx+xpz(z)log(1D(g(z)))dz=x[Pdata (x)log(D(x))+Pg(x)log(1D(x))]dx
​ 已知形如 alog ⁡ ( y ) + b log ⁡ ( 1 − y ) \operatorname{alog}(y)+b \log (1-y) alog(y)+blog(1y) 的函数, 当 y ∈ [ 0 , 1 ] y \in[0,1] y[0,1] 时, a , b a, b a,b 为非零实数时, 可以在 a a + b \frac{a}{a+b} a+ba 处取得最小值,所以当生成器网络的参数固定不变训练判别器时,判别器网络的当前全局最优值可以通过下面公式获得:
D G ∗ ( x ) = P data  ( x ) P data  ( x ) + P g ( x ) D_G^*(x)=\frac{P_{\text {data }}(x)}{P_{\text {data }}(x)+P_g(x)} DG(x)=Pdata (x)+Pg(x)Pdata (x)
​ 判别器最优时训练函数表示如下:
max ⁡ D V ( D , G ) = E x ∼ P data  ( x ) [ log ⁡ ( D G ∗ ( x ) ) ] + E z ∼ P z ( z ) [ log ⁡ ( 1 − D G ∗ ( G ( z ) ) ) ] = E x ∼ P data  ( x ) [ log ⁡ ( D G ∗ ( x ) ) ] + E x ∼ P g ( x ) [ log ⁡ ( 1 − D G ∗ ( x ) ) ] = E x ∼ P data  ( x ) [ log ⁡ P data  ( x ) P data  ( x ) + P g ( x ) ] + E x ∼ P g ( x ) [ log ⁡ P data  ( x ) P data  ( x ) + P g ( x ) ] \begin{aligned} \max _D V(D, G) & =E_{x \sim P_{\text {data }}(x)}\left[\log \left(D_G^*(x)\right)\right]+E_{z \sim P_z(z)}\left[\log \left(1-D_G^*(G(z))\right)\right] \\ & =E_{x \sim P_{\text {data }}(x)}\left[\log \left(D_G^*(x)\right)\right]+E_{x \sim P_g(x)}\left[\log \left(1-D_G^*(x)\right)\right] \\ & =E_{x \sim P_{\text {data }}(x)}\left[\log \frac{P_{\text {data }}(x)}{P_{\text {data }}(x)+P_g(x)}\right]+E_{x \sim P_g(x)}\left[\log \frac{P_{\text {data }}(x)}{P_{\text {data }}(x)+P_g(x)}\right] \end{aligned} DmaxV(D,G)=ExPdata (x)[log(DG(x))]+EzPz(z)[log(1DG(G(z)))]=ExPdata (x)[log(DG(x))]+ExPg(x)[log(1DG(x))]=ExPdata (x)[logPdata (x)+Pg(x)Pdata (x)]+ExPg(x)[logPdata (x)+Pg(x)Pdata (x)]
​ 当且仅当 p g = p data p_g=p_\text{data} pg=pdata时,此时的真实样本数 据与生成数据的分布相同, V ( D , G ) V(D, G) V(D,G)取得最小值 − l o g 4 -log4 log4,这时生成器就能生成无法同真实样本区分的“真实”数据,判别器已经无法再区分它们之间的差别。

训练代码

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

# 创建保存生成图片的目录
os.makedirs("images", exist_ok=True)

# 解析命令行参数
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="训练的轮数")
parser.add_argument("--batch_size", type=int, default=64, help="每批的样本数量")
parser.add_argument("--lr", type=float, default=0.0002, help="Adam优化器的学习率")
parser.add_argument("--b1", type=float, default=0.5, help="Adam优化器的梯度一阶矩的衰减系数")
parser.add_argument("--b2", type=float, default=0.999, help="Adam优化器的梯度二阶矩的衰减系数")
parser.add_argument("--n_cpu", type=int, default=8, help="用于生成批次的CPU线程数")
parser.add_argument("--latent_dim", type=int, default=100, help="潜在空间的维度")
parser.add_argument("--img_size", type=int, default=28, help="每个图像维度的大小")
parser.add_argument("--channels", type=int, default=1, help="图像的通道数")
parser.add_argument("--sample_interval", type=int, default=400, help="生成图像样本的间隔")
opt = parser.parse_args()
print(opt)

# 定义图像形状
img_shape = (opt.channels, opt.img_size, opt.img_size)

# 检查是否支持CUDA加速
cuda = True if torch.cuda.is_available() else False

# 定义生成器模型
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

# 定义鉴别器模型
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

# 定义对抗网络的损失函数
adversarial_loss = torch.nn.BCELoss()

# 初始化生成器和鉴别器
generator = Generator()
discriminator = Discriminator()

# 如果支持CUDA,则将模型移至GPU
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# 配置数据加载器
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# 定义优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# 定义张量类型
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  训练
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # 定义对抗网络的真实标签和生成标签
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # 配置输入
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  训练生成器
        # -----------------

        optimizer_G.zero_grad()

        # 生成器输入的噪声样本
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # 生成一批图像
        gen_imgs = generator(z)

        # 生成器的损失,度量其欺骗鉴别器的能力
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  训练鉴别器
        # ---------------------

        optimizer_D.zero_grad()

        # 鉴别器的损失,度量其区分真实图像和生成图像的能力
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        # 每隔一定间隔保存生成的图像样本
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25

你可能感兴趣的:(深度学习,计算机视觉,人工智能)