生成对抗网络(GAN):从概念到代码实践(附代码)

第一章:计算机视觉中图像的基础认知
第二章:计算机视觉:卷积神经网络(CNN)基本概念(一)
第三章:计算机视觉:卷积神经网络(CNN)基本概念(二)
第四章:搭建一个经典的LeNet5神经网络(附代码)
第五章:计算机视觉:神经网络实战之手势识别(附代码)
第六章:计算机视觉:目标检测从简单到容易(附代码)
第七章:MTCNN 人脸检测技术揭秘:原理、实现与实战(附代码)
第八章:探索YOLO技术:目标检测的高效解决方案
第九章:计算机视觉:主流数据集整理
第十章:生成对抗网络(GAN):从概念到代码实践(附代码)

一、GAN网络简介

生成对抗网络(GANs,Generative Adversarial Networks)是一种深度学习模型,通过两个神经网络——生成器和鉴别器之间的博弈来生成数据。

对抗学习的核心在于两个初始能力有限的模型(“两个弱鸡”):

  • 生成器(Generator):试图伪造数据
  • 鉴别器(Discriminator):负责检测伪造数据

二者初始表现都不理想,但通过对抗过程实现共同进化。这种设计突破传统单一模型优化的范式,创造性地引入动态博弈机制。

二、对抗学习基础

  • 前提:在GAN中,我们有两个“弱鸡”——生成器和鉴别器。它们都是初学者,在开始时都不太强大。
  • 过程:通过一种博弈的方式,即互啄(对抗),双方每次进步一点点。这个过程不是零和博弈,而是共同成长的过程。
  • 特点:GAN中的两个模型并不是互相抵消的,而是一起成长,生成器变得越来越擅长生成逼真的图像,而鉴别器则变得更擅长区分真假图像。

博弈过程解析

  • 训练过程可类比艺术品鉴定场景:
  • 造假者(生成器)初期制作粗糙赝品
  • 鉴定师(鉴别器)学习识别赝品特征
  • 造假者改进技术制作更逼真的赝品
  • 鉴定师被迫提升鉴别能力
  • 循环往复直至达到纳什均衡

核心特点

  • 非零和博弈:双方通过对抗实现共同提升
  • 渐进式优化:通过反向传播实现参数微调
  • 无监督特征学习:自动发现数据分布特征
  • 动态平衡:最终达到生成器≈鉴别器的状态

三、任务描述

我们的任务是给定一堆原始图像,训练一个模型来模仿这些图像并生成类似的图像。这通过使用对抗学习的方法实现:

  • 生成器:负责根据输入的噪声向量生成假图像。
  • 鉴别器:负责判断一张图像是真实的还是由生成器生成的假图像。
GAN
生成器Generator
判别器Discriminator
输入:随机噪声向量
输出:生成样本
输入:真实/生成样本
输出:判别概率

任务定义
给定MNIST手写数字数据集:

  • 输入:28x28灰度图像
  • 目标:生成逼真的手写数字图像

四、GAN代码实现

以下是基于PyTorch的一个简单GAN实现示例,用于生成类似MNIST数据集的手写数字图像:
生成对抗网络(GAN):从概念到代码实践(附代码)_第1张图片

4.1 导入依赖

import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from matplotlib import gridspec
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
from IPython import display

导入必要的库,包括PyTorch及其扩展torchvision用于数据处理和加载,以及matplotlib用于可视化结果。

4.2 数据加载

# 检查是否有可用的GPU,并设置运行设备为GPU或CPU。
device = "cuda" if torch.cuda.is_available() else "cpu"
"""
    图像预处理
    定义一个转换操作,将图像从原始格式转换为张量,并归一化到[-1, 1]范围。
"""
trans = transforms.Compose(transforms=[
    # 原始输入图像形状是[0, 255] [H, W, C]
    # 转换之后的形状是[0, 1] [C, H, W]
    transforms.ToTensor(),
    # [0, 1]
    # [-1, 1]
    transforms.Normalize(mean=[0.5], std=[0.5])
])
# 加载图像
# 加载MNIST训练数据集,并创建一个数据加载器,设置批次大小为128,打乱数据顺序。
data = datasets.MNIST(root="data", train=True, transform=trans, download=True)
dataloader = DataLoader(dataset=data, batch_size=128, shuffle=True)
# 测试图像
for batch_X, batch_y in dataloader:
    print(batch_X.shape) # 输出 torch.Size([128, 1, 28, 28])
    print(batch_y.shape) # 输出 torch.Size([128])
    #加载出图片,需要讲格式为[C,H,W],转换成为PIL图片格式[H,W,C]
    image = transforms.ToPILImage()(batch_X[0])
    print(image)
    plt.figure(figsize=(2, 2))
    plt.imshow(image)
    break

生成对抗网络(GAN):从概念到代码实践(附代码)_第2张图片

代码解释:
上面代码创建一个data对象,它包含了整个MNIST训练集,并且每个图像都经过了预处理(转换为张量并进行了归一化)。

  • datasets.MNIST: 这是来自torchvision.datasets模块的一个类,专门用于加载MNIST手写数字数据集。

    • root="data": 指定下载的数据集存放的目录。如果指定的目录下没有数据集文件,将会自动下载到该目录。
    • train=True: 表示加载训练集。如果设置为False,则加载测试集。
    • transform=trans: 应用在每个图像上的转换操作。这里使用了之前定义的trans,它包括将图像转换为张量并进行归一化处理。
    • download=True: 如果本地不存在数据集,则自动从网络下载。
  • DataLoader: 来自torch.utils.data模块,用来遍历整个数据集,每次返回一个批次的数据(默认情况下是随机排序的),可以方便地以小批次的形式访问数据集,这对于训练过程中的内存管理和加速训练至关重要。

    • dataset=data: 指定了要加载的数据集,即前面定义的data对象。
    • batch_size=128: 每个批次包含的样本数量。这意味着每次迭代时,DataLoader会返回128个样本。
    • shuffle=True: 在每个epoch开始前是否打乱数据集。这对于训练模型来说非常重要,因为它有助于提高模型的泛化能力,防止模型学习到数据的特定顺序。

4.3 创建生成器

定义一个生成器(Generator)类,用于生成类似MNIST手写数字的新图像,接受一个随机噪声向量(默认维度为128)作为输入,并通过一系列线性变换和非线性激活函数将其转换为一个28x28像素的图像。

"""
    定义一个生成器
        - 造假图像的网络,
        - 生成器的作用是输入一个随机tensor张量,生成一个新的tensor张量,而这个新的张量代表的是一张图片
        输入:一个128维的向量,这个向量是一个随机数向量,128维也是自定义的,也可以是其他维
        输出:一个28*28维的向量,28*28是一个手写数字图片的大小
"""
class Generator(nn.Module):
    # 784
    def __init__(self, in_features=128, out_features=28 * 28):
    	# 调用父类nn.Module的构造函数,确保所有必要的初始化步骤都已完成。
        super().__init__()
        self.linear1 = nn.Linear(in_features=in_features, out_features=256)
        self.linear2 = nn.Linear(in_features=256, out_features=512)
        self.linear3 = nn.Linear(in_features=512, out_features=28 * 28)
    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        x = torch.relu(x)
        x = self.linear3(x)
        x = torch.tanh(x)
        return x

代码解释:

  • Generator(nn.Module):该类继承自PyTorch的nn.Module,是构建神经网络模型的基础类。
  • self.linear1, self.linear2, self.linear3: 定义了三层全连接层(也称为线性层或密集层),它们构成了生成器的核心部分:
    • 第一层将输入维度从128扩展到256。
    • 第二层进一步扩展到512维。
    • 最后一层将维度压缩回到784(即28x28像素的图像尺寸),这与MNIST数据集中的图像大小相匹配。
  • forward: 定义了前向传播过程,这是在进行推理或训练时实际使用的计算流程。
    • 输入x首先通过第一层全连接层linear1,然后应用ReLU激活函数以引入非线性。
    • 接着,输出经过第二层全连接层linear2和另一个ReLU激活函数。
    • 然后,通过第三层全连接层linear3,这次使用的是tanh激活函数,将输出范围限制在[-1, 1]之间,这与之前对输入数据的归一化处理相对应。
    • 最终,返回变换后的张量x作为生成器的输出,其形状为(batch_size, 784),对应于一批批的28x28像素图像。

定义了鉴别器模型,负责判断输入图像是真实的还是由生成器生成的假图像。该类同样继承自PyTorch的nn.Module。接受一个28x28像素的图像(展平为784维向量)作为输入,并通过一系列线性变换和ReLU激活函数将其转换为两个值的输出。这两个输出可以用来表示输入图像属于“真实”类别和“伪造”类别的概率。

4.4 创建鉴别器

"""
    鉴别器
        - 判断真假!
        给定一个张量,代表一张图片,判断这个图片是真图片还是假图片
        输入:一个28*28的张量,格式为[C,H,W]
        输出:2个数字,一个代表真图片,一个代表假图片
"""
class Discriminator(nn.Module):
    def __init__(self, in_features=28 * 28, out_features=2):
    	# 调用了父类nn.Module的构造函数,确保所有必要的初始化步骤都已完成
        super().__init__()
        self.linear1 = nn.Linear(in_features=in_features, out_features=512)
        self.linear2 = nn.Linear(in_features=512, out_features=256)
        self.linear3 = nn.Linear(in_features=256, out_features=32)
        self.linear4 = nn.Linear(in_features=32, out_features=out_features)
    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        x = torch.relu(x)
        x = self.linear3(x)
        x = torch.relu(x)
        x = self.linear4(x)
        return x

代码解释:

  • self.linear1, self.linear2, self.linear3, self.linear4: 定义了四层全连接层(线性层),它们构成了鉴别器的核心部分:
    • 第一层将输入维度从784(即28x28像素的展平图像)扩展到512。
    • 接下来的两层分别将维度减少到256和32。
    • 最后一层将维度压缩到2,二分类问题的输出(例如:真实图像 vs 假图像)。
  • forward: 定义了前向传播过程,这是在进行推理或训练时实际使用的计算流程。
    • 输入x首先通过第一层全连接层linear1,然后应用ReLU激活函数以引入非线性。
    • 接着,输出经过第二层全连接层linear2和另一个ReLU激活函数。
    • 然后,通过第三层全连接层linear3,再次应用ReLU激活函数。
    • 最终,通过第四层全连接层linear4,这次没有应用激活函数,因为对于二分类任务,我们通常会在损失函数中使用softmax或sigmoid来计算概率分布。
    • 返回变换后的张量x作为鉴别器的输出,其形状为(batch_size, 2),对应于一批图像是否为真的概率分布。

4.5 初始化生成器、鉴别器

初始化生成器和鉴别器,并将其移动到选定的设备上。定义了两个Adam优化器分别用于生成器和鉴别器,并设置了交叉熵损失函数。

"""
    构建模型
"""
# 定义一个生成器
generator = Generator(in_features=128, out_features=28 * 28)
generator.to(device=device)
# 定义一个鉴别器
discriminator = Discriminator(in_features=28 * 28, out_features=2)
discriminator.to(device=device)
"""
    优化器
        - 执行梯度下降
        - 2个网络,2个优化器
"""
g_optimizer = torch.optim.Adam(params=generator.parameters(), lr=1e-4)
d_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=1e-4)
"""
    定义一个损失函数
"""
loss_fn = nn.CrossEntropyLoss()
"""
    训练样本构建:
        - 真实图像:1
            - 来自于待模仿的对象
                - 手写数字
        - 假图像:0
            - 生成器生成的结果
                - 输入随机数,生成的假图像
"""
def get_real_data_labels(size):
    """
        获取真实数据的标签
        函数返回一个大小为size的张量,所有元素都为1,表示这些数据是真实的图像。
        假设size=5,输出labels:tensor([1, 1, 1, 1, 1])
    """
    labels = torch.ones(size, device=device, dtype=torch.long)
    return labels
    
def get_fake_data_labels(size):
    """
        获取假数据的标签
        函数返回一个大小为size的张量,所有元素都为0,表示这些数据是由生成器生成的假图像。
        假设size=5,输出labels:tensor([0, 0, 0, 0, 0])
    """
    labels = torch.zeros(size, device=device, dtype=torch.long)
    return labels
"""
    随机获取一批随机数
    函数用于生成一批随机噪声向量,每条向量的维度为128,数量由size决定。这是生成器用来生成新图像的输入。
"""
def get_noise(size):
    X = torch.randn(size, 128, device=device)
    return X

4.6 训练生成器、鉴别器

循环实现GAN的训练过程。每一轮中,首先训练鉴别器识别真假图像,然后训练生成器产生更逼真的图像。最后,每轮训练结束后,使用生成器生成一些样本并展示出来以监控训练进度。

"""
    训练过程
"""
# 获取一批测试噪声 num_test_samples = 16
num_test_samples = 16
test_noise = get_noise(num_test_samples)
# 模型将要训练的总轮次(epochs),即整个数据集将被遍历200次。
epochs = 200
# 外层循环遍历每一个epoch。
for epoch in range(1, epochs + 1):
    print(f"当前正在进行第{epoch}轮训练")
    # 逐批次训练
    # 内层循环通过dataloader逐批次加载数据集中的图像和标签。
    for batch_idx, (batch_X, batch_y) in enumerate(dataloader):
        # 1. 真实图像 将真实图像展平并移动到指定设备上。
        real_data = batch_X.view(batch_X.size(0), -1).to(device=device)
        # 2. 给生成器准备的输入噪声 
        # 生成与该批次大小相同的噪声向量
        noise = get_noise(batch_X.size(0)).to(device=device)
        # 使用生成器生成对应的假图像。得到一个常量假数据
        # .detach()方法确保计算图中不包含这部分操作,使得反向传播时不会更新生成器的参数。
        fake_data = generator(noise).detach()
        
        """
            互相踩在对方的肩上,一点点的进步!!!!!
        """
        """
            优化一步鉴别器
        """
        # 1. 训练过程
        # 清空鉴别器优化器的梯度。
        d_optimizer.zero_grad()
        # 把真实数据交给鉴别器去鉴别
        real_pred = discriminator(real_data)
        # 鉴别器对真实数据进行预测,并计算损失。
        real_loss = loss_fn(real_pred, get_real_data_labels(real_data.size(0)))
        real_loss.backward()
        # 对于生成的数据也执行相同的操作。
        # 反向传播误差,并更新鉴别器的参数。
        fake_pred = discriminator(fake_data)
        fake_loss = loss_fn(fake_pred, get_fake_data_labels(fake_data.size(0)))
        fake_loss.backward()
        # 鉴别器优化一步
        d_optimizer.step()

        """
            训练一步生成器
        """
        noise = get_noise(size=real_data.size(0))
        fake_pred = generator(noise)
        # 清空生成器优化器的梯度。
        g_optimizer.zero_grad()
        # 暂时冻结鉴别器的参数,防止其在生成器更新过程中被修改。
        for param in discriminator.parameters():
            param.requires_grad=False
        # 生成器生成新的图像,并让鉴别器评估这些图像。
        d_pred = discriminator(fake_pred)
        # 解冻鉴别器参数
        for param in discriminator.parameters():
            param.requires_grad=True
        # 互相博弈的关键
        # 计算生成器的损失并反向传播误差,更新生成器的参数。
        g_loss = loss_fn(d_pred, get_real_data_labels(d_pred.size(0)))
        g_loss.backward()
        g_optimizer.step()
    
    # 每训练一轮,看看训练结果,结果可视化
    # 生成器基于测试噪声生成一些图像,并展示出来以便直观地检查生成器的表现。使用matplotlib库绘制图像,并调整显示设置以获得清晰的可视化效果。
    with torch.no_grad():
        
        # 正向推理
        img_pred = generator(test_noise)
        img_pred = img_pred.view(img_pred.size(0), 28, 28).cpu().data
        
        # 画图
        display.clear_output(wait=True)
        
        # 设置画图的大小
        fig = plt.figure(1, figsize=(12, 8)) 
        # 划分为 4 x 4 的 网格
        gs = gridspec.GridSpec(4, 4)
        
        # 遍历每一个
        for i in range(4):
            for j in range(4):
                # 取每一个图
                X = img_pred[i * 4 + j, :, :]
                # 添加一个对应网格内的子图
                ax = fig.add_subplot(gs[i, j])
                # 在子图内绘制图像
                ax.matshow(X, cmap=plt.get_cmap("Greys"))
                ax.set_xticks(())
                ax.set_yticks(())
        plt.show()

输出结果,看看效果:
当前正在进行第3轮训练,从结果上看,生成器生成的数字图片啥也不是。
生成对抗网络(GAN):从概念到代码实践(附代码)_第3张图片
当前正在进行第33轮训练,从结果上看,生成器生成的数字图片初见成效。
生成对抗网络(GAN):从概念到代码实践(附代码)_第4张图片
当前正在进行第100轮训练,从结果上看,生成器生成的数字图片越来越逼真。
生成对抗网络(GAN):从概念到代码实践(附代码)_第5张图片
200 轮训练完成之后的效果
生成对抗网络(GAN):从概念到代码实践(附代码)_第6张图片

这段逻辑的核心思想是通过生成对抗网络(GAN)的训练机制,让生成器和鉴别器互相博弈以提升各自的性能。具体来说,这段代码实现了以下思想:

训练的核心思想

  1. 对抗学习:生成器尝试生成尽可能逼真的图像来欺骗鉴别器;而鉴别器则试图准确地区分真实图像和生成器生成的假图像。两者在这个过程中相互促进、共同进步。

  2. 逐步优化

    • 鉴别器的优化:在每个批次中,首先使用当前的真实图像和由生成器生成的假图像分别对鉴别器进行训练。这一步的目标是提高鉴别器区分真假图像的能力。
    • 生成器的优化:接着,通过冻结鉴别器参数的方式,让生成器基于随机噪声生成图像,并试图使这些图像能够“骗过”鉴别器。这个过程的目标是提高生成器生成逼真图像的能力。
  3. 测试与监控:每轮训练结束后,利用固定的测试噪声输入生成器,生成一批图像并展示出来。这样做可以帮助我们直观地观察生成器随训练进程的进步情况,以及模型最终的生成效果。

训练的实现细节

  • 准备阶段:定义了总的训练轮次epochs,以及用于测试生成器性能的一批固定噪声test_noise

  • 训练循环:外层循环遍历每一个epoch,内层循环通过数据加载器逐批次处理数据集中的图像和标签。

  • 数据处理:将真实图像展平并移动到指定设备上;同时为生成器准备与该批次大小相同的随机噪声向量。

  • 鉴别器训练:通过计算鉴别器对真实和假图像的预测损失,并执行反向传播更新鉴别器参数,旨在增强其分辨能力。

  • 生成器训练:生成新图像并通过鉴别器评估这些图像的质量,根据结果调整生成器的参数,目标是使其能生成更难被鉴别器识别出来的图像。

  • 可视化检查:每轮训练结束时,生成并展示一批由生成器生成的图像,以便于直观评估生成器的表现。

这种训练方式使得生成器和鉴别器在相互竞争的过程中不断改进自身,最终达到生成高质量图像的目的。

五、总结

GAN 网络通过对抗学习的方法训练生成器和鉴别器,使生成器能够生成与真实数据分布相似的样本。对抗学习的思想不仅可以应用于 GAN 网络,还可以融合到其他模型设计中,提高模型的性能和泛化能力。

深刻理解对抗生成的逻辑

  • 如何在对抗中去进步的 :生成器和鉴别器通过博弈的方式不断进步,生成器试图生成更逼真的假图像来欺骗鉴别器,而鉴别器则试图更准确地鉴别图像的真假。
  • 这种思想很重要 :对抗学习的思想可以应用于其他模型设计中,提高模型的性能。
  • 这个模型不重要 :GAN 网络的具体实现可以有很多种,关键在于理解其对抗学习的思想。
  • 把对抗学习的思想,融合到其它模型设计中,可以提高模型的性能和泛化能力。

你可能感兴趣的:(人工智能,计算机视觉,人工智能,GAN,网络,对抗学习,手势识别,生成器与鉴别器,生成对抗网络)