基于PyTorch的生成对抗网络入门(5)——利用PyTorch搭建条件生成对抗网络(CGAN)超详解

目录

  • 一、案例描述
  • 二、代码详解
    • 2.1 构建真实数据集
    • 2.2 创建生成虚假数据
    • 2.3 构建判别器
      • 2.3.1 创建生成虚假标签
      • 2.3.2 判别器类
      • 2.3.3 测试判别器
    • 2.4 构建生成器
      • 2.4.1 生成器类
      • 2.4.2 测试生成器
    • 2.5 训练GAN
    • 2.6 测试生成器的生成效果

一、案例描述

本文介绍利用条件生成对抗网络CGAN)进行手写数字生成。比如说,要求生成不同的,但是都代表某一个数字的图像。

二、代码详解

2.1 构建真实数据集

加载数据集利用datasets.MNIST()函数,定义加载数据集函数:

#定义加载数据集函数
def load_data_mnist(batch_size):
    '''下载MNIST数据集然后加载到内存中'''
    train_dataset=datasets

你可能感兴趣的:(PyTorch,python,pytorch,生成对抗网络,深度学习,python,神经网络)