G1周-生成对抗网络入门

 本文为365天深度学习训练营 中的学习记录博客
 原作者:K同学啊|接辅导、项目定制

我的环境:

1.语言:python3.7

2.编译器:pycharm

3.深度学习框架Pytorch 1.8.0+cu111


GAN网络介绍:

生成对抗网络(GAN)是一种深度学习模型,最初由Ian Goodfellow等人于2014年提出。GAN的核心思想是通过训练两个神经网络,一个生成器(Generator)和一个判别器(Discriminator),彼此博弈来学习生成与真实数据相似的样本。

  1. 生成器(Generator):

    • 生成器的任务是从随机噪声中生成逼真的数据样本。它接收一个随机向量(通常称为潜在空间的点)作为输入,并输出一个与真实数据相似的样本。生成器通过不断学习调整其参数,使其生成的样本越来越逼真。
  2. 判别器(Discriminator):

    • 判别器的任务是区分生成器生成的样本和真实数据。它接收两种类型的输入:一种是来自生成器的生成样本,另一种是来自真实数据集的真实样本。判别器通过学习识别哪些样本是真实的,哪些是生成的,不断提高自己的准确性。
  3. 对抗过程:

    • 在训练过程中,生成器和判别器相互竞争、相互博弈。生成器努力生成更逼真的样本,而判别器努力提高对生成样本和真实样本的辨别能力。这种对抗的过程最终导致生成器生成的样本越来越难以被判别器区分。
  4. 目标函数:

    • GAN的训练过程通过最小化生成器和判别器之间的对抗损失来实现。生成器希望最大化判别器无法区分生成样本和真实样本的概率,而判别器则希望最小化这个概率。

GAN的优点包括能够生成高质量、逼真的数据,无需显式规定生成规则。它在图像生成、风格迁移、图像超分辨率等任务上取得了显著的成功。然而,GAN的训练也面临一些挑战,如模式崩溃、训练不稳定等,需要仔细的调参和设计。

GAN网络整体示意图如下:
G1周-生成对抗网络入门_第1张图片

 结合整体模型图示,再以生成图片作为例子具体说明下面。我们有两个网络,G(Generator)和D(Discriminator)。Generator是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。Discriminator是一个判别网络,判别一张图片是不是“真实的”。它的输入是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

一、超参设置

import argparse
import os
import numpy as np
import torch.cuda
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import  DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch

os.makedirs("./images/",exist_ok = True) #记录训练过程的图片效果
os.makedirs("./save/",exist_ok = True) #训练完成时模型保存的位置
os.makedirs("./datasets/mnist",exist_ok = True) #下载数据集存放的位置

#超参数配置
n_epochs = 50
batch_size = 64
Ir = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 2
latent_dim = 100
img_size = 28
channels = 1
sample_interval = 500

img_shape = (channels,img_size,img_size) #图像尺寸(1,28,28)
img_area = np.prod(img_shape) #图像像素面积784

cuda = True if torch.cuda.is_available() else False
print(cuda)

二、下载数据集

mnist = datasets.MNIST(
    root = './datasets/',train = True,download =True,transform = transforms.Compose(
     [transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize([0.5],[0.5])]),)

 三、模型定义

1、鉴别器

#鉴别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area,512),
            nn.LeakyReLU(0.2,inplace = True),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(256,1),
            nn.Sigmoid(),
        )

    def forward(self,img):
        img_flat = img.view(img.size(0),-1)
        validity = self.model(img_flat)
        return validity

  • 判别器是一个神经网络模型,它接受一个图像作为输入并输出一个标量值,表示输入图像是真实图像的概率。通过使用nn.Sequential定义了一个包含线性层、LeakyReLU激活函数和Sigmoid激活函数的模型。

  • 输入图像经过一系列线性变换和非线性激活,最终通过 Sigmoid 激活函数输出一个范围在 [0, 1] 内的值,表示输入图像是真实图像的概率。

  • 定义了名为 `Discriminator` 的类,它继承自 PyTorch 的 `nn.Module` 类。super(Discriminator,self).__init__() 调用了父类的构造函数
  • 在 `__init__` 方法中,定义了一个名为 `model` 的神经网络模型,使用了 `nn.Sequential` 封装了一系列层。
  • nn.Linear(img_area,512):线性层,从输入大小为 `img_area` 的数据映射到输出大小为 512。 nn.LeakyReLU(0.2,inplace=True):Leaky ReLU 激活函数,引入了一个小的负斜率,有助于处理梯度消失问题。
  • nn.Linear(512,256)` 和 `nn.LeakyReLU(0.2,inplace=True)`:另一组线性层和 Leaky ReLU 激活。
  • nn.Linear(256,1):最终线性层,输出一个单一的值,用于二分类(真实或伪造)。 - nn.Sigmoid():最后一层使用 Sigmoid 激活函数,将输出值压缩到 (0, 1) 范围,表示判别为真实的概率。
  • forward 方法定义了数据在模型中前向传播的流程。
  • img.view(img.size(0),-1):将输入图像展平为一维张量,保留 batch 的维度。
  • self.model(img_flat):将展平后的图像输入到定义的神经网络模型中。
  • return validity:返回模型的输出,即表示判别为真实的概率。

2、生成器

##生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()

        def block(in_feat,out_feat,normalize = True):
            layers = [nn.Linear(in_feat,out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat,0.8))
            layers.append(nn.LeakyReLU(0.2,inplace = True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim,128,normalize = False),
            *block(128,256),
            *block(256,512),
            *block(512,1024),
            nn.Linear(1024,img_area),
            nn.Tanh()
        )

    def forward(self,z):
        imgs = self.model(z)
        imgs = imgs.view(imgs.size(0),*img_shape)
        return imgs
  • Generator 类同样继承自 nn.Module
  • 定义了一个辅助函数 block,用于创建生成器中的一个块(block)。这个块由一个线性层、一个批归一化(如果需要)、和 Leaky ReLU 激活函数组成。
  • __init__ 方法中,通过堆叠多个块构建了整个生成器模型。
  • *block(latent_dim,128,normalize=False):输入层,从潜在向量(latent vector)映射到大小为 128 的隐藏层,不进行批归一化。
  • 后续的块,逐渐增加输出特征的维度,最终通过一个线性层将特征映射到输出大小为 img_area 的一维向量。
  • nn.Tanh():最后一层使用 Tanh 激活函数,将输出限制在 (-1, 1) 范围内。
  • forward 方法定义了数据在模型中前向传播的流程。
  • self.model(z):将输入潜在向量 z 通过生成器模型。
  • imgs.view(imgs.size(0), *img_shape):将生成器的输出展平为图像形状。
  • return imgs:返回生成的图像。

 四、训练模型

1、创建实例

generator = Generator()
discriminator = Discriminator()

criterion = torch.nn.BCELoss()

optimizer_G = torch.optim.Adam(generator.parameters(), lr=Ir, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr = Ir,betas = (b1,b2))

if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()

这两个模型是 GAN 中的关键组件,通过对抗训练的方式使得生成器生成逼真的图像,而鉴别器则尽可能准确地判别真实和生成的图像。在训练循环中,生成器和鉴别器通过对抗损失函数进行优化,以实现更好的生成图像和更准确的判别。 

2、训练模型

for epoch in range(n_epochs):
    for i,(imgs,_) in enumerate(dataloader):

        imgs = imgs.view(imgs.size(0),-1)
        real_img = Variable(imgs).cuda()
        real_label = Variable(torch.ones(imgs.size(0),1)).cuda()
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()

        real_out = discriminator(real_img)
        loss_real_D = criterion(real_out,real_label)
        real_scores = real_out

        z = Variable(torch.randn(imgs.size(0),latent_dim)).cuda()
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out,fake_label)
        fake_scores = fake_out

        loss_D = loss_real_D + loss_fake_D
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()
        fake_img = generator(z)
        fake_out = discriminator(fake_img)
        loss_G = criterion(fake_out, real_label)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        if (i + 1) % 300 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i ,len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
            )

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25],"./images/%d.png" % batches_done, nrow = 5, normalize = True)

#模型保存
torch.save(generator.state_dict(),'./save/generator.pth')
torch.save(discriminator.state_dict(),'./save/discriminator.pth')

[Epoch 0/50] [Batch 299/938] [D loss: 1.062147] [G loss: 0.793500] [D real: 0.557045] [D fake: 0.361011]
[Epoch 0/50] [Batch 599/938] [D loss: 1.059478] [G loss: 1.188343] [D real: 0.587030] [D fake: 0.359636]
[Epoch 0/50] [Batch 899/938] [D loss: 1.042897] [G loss: 1.466038] [D real: 0.653879] [D fake: 0.433240]
[Epoch 1/50] [Batch 299/938] [D loss: 0.843557] [G loss: 1.561677] [D real: 0.771608] [D fake: 0.421925]
[Epoch 1/50] [Batch 599/938] [D loss: 1.026156] [G loss: 2.394096] [D real: 0.862107] [D fake: 0.558821]
[Epoch 1/50] [Batch 899/938] [D loss: 0.836007] [G loss: 1.557716] [D real: 0.652867] [D fake: 0.245031]
[Epoch 2/50] [Batch 299/938] [D loss: 1.035034] [G loss: 0.812236] [D real: 0.550502] [D fake: 0.262485]
[Epoch 2/50] [Batch 599/938] [D loss: 1.063006] [G loss: 1.113460] [D real: 0.483566] [D fake: 0.111593]
[Epoch 2/50] [Batch 899/938] [D loss: 0.803735] [G loss: 1.973702] [D real: 0.765450] [D fake: 0.374167]
[Epoch 3/50] [Batch 299/938] [D loss: 0.825772] [G loss: 1.667717] [D real: 0.780164] [D fake: 0.376694]
[Epoch 3/50] [Batch 599/938] [D loss: 1.435523] [G loss: 0.494344] [D real: 0.369774] [D fake: 0.139426]
[Epoch 3/50] [Batch 899/938] [D loss: 0.917040] [G loss: 0.905851] [D real: 0.552047] [D fake: 0.157987]
[Epoch 4/50] [Batch 299/938] [D loss: 1.141817] [G loss: 2.367955] [D real: 0.860573] [D fake: 0.589422]
[Epoch 4/50] [Batch 599/938] [D loss: 0.996989] [G loss: 0.822017] [D real: 0.590626] [D fake: 0.270388]
[Epoch 4/50] [Batch 899/938] [D loss: 0.753533] [G loss: 1.717019] [D real: 0.742206] [D fake: 0.312312]
[Epoch 5/50] [Batch 299/938] [D loss: 0.989197] [G loss: 0.888657] [D real: 0.634080] [D fake: 0.284923]
[Epoch 5/50] [Batch 599/938] [D loss: 0.844832] [G loss: 1.811573] [D real: 0.747003] [D fake: 0.359605]
[Epoch 5/50] [Batch 899/938] [D loss: 0.694102] [G loss: 1.491613] [D real: 0.720356] [D fake: 0.237716]
 

你可能感兴趣的:(生成对抗网络,人工智能,神经网络)