Generating MNIST Data with a Generative Adversarial Network (GAN)

A GAN is a generative network based on game theory. It was proposed by Ian Goodfellow in 2014, and its main goal is to learn how to produce new samples from a set of training samples.

A GAN neither relies on labels for optimization nor adjusts its parameters according to rewards and penalties; instead, it improves continuously through the game played between a generator and a discriminator.

An intuitive way to understand a GAN: imagine a forger who wants to fake a painting by Leonardo da Vinci. At first the forger's skills are poor, but he mixes some of his forgeries in with genuine da Vinci works and asks an art dealer to assess the authenticity of each painting, giving him feedback on which ones look like genuine works and which do not.
The forger improves his forgeries based on this feedback. As time goes on, the forger becomes more and more skilled, and the art dealer becomes better and better at spotting fakes. In the end, the forger is left holding some remarkably convincing forgeries.

This is the basic idea behind a GAN: there are two roles, a forger and an art expert, and each is trying to beat the other.

A GAN therefore consists of two main parts (their training objective is written out after the list):
1) Generator network: takes a random vector from a latent space as input and decodes it into an image.
2) Discriminator network: takes an image (either real or generated) as input and predicts whether it comes from the training set or was synthesized by the generator.
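For reference (this formula is not spelled out in the original post, but it is the standard objective from Goodfellow et al., 2014), the two networks play the minimax game

$$\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

The code below implements this with binary cross-entropy: the discriminator is trained to output 1 on real images and 0 on generated ones, while the generator uses the common non-saturating variant, i.e. it is trained to make $D(G(z))$ match the "real" label 1.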

Its structure is shown in the figure below:
[Figure 1: GAN architecture diagram]

We use the PyTorch framework to train the discriminator and the generator. The code is as follows:

import torch
import torchvision
import os
from torch import nn
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import optim
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

image_dir = "gan_images"
if not os.path.exists(image_dir):
    os.makedirs(image_dir)

# Hyperparameters
image_size = 28 * 28
hidden_size = 400
latent_size = 100
batch_size = 100
epoch_num = 100

# Normalize(0.5, 0.5) maps pixel values from [0, 1] to [-1, 1], matching the generator's Tanh output.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
data_loader = DataLoader(dataset=data, shuffle=True, batch_size=batch_size)


# Discriminator: maps a flattened 28x28 image to the probability that it is real.
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size,1),
    nn.Sigmoid()
)

# Generator: maps a latent vector to a flattened 28x28 image with values in [-1, 1].
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh()
)

# Move the networks to the target device before creating their optimizers.
D = D.to(device)
G = G.to(device)

criteria = nn.BCELoss()
d_optimizer = optim.Adam(params=D.parameters(), lr=0.0001)
g_optimizer = optim.Adam(params=G.parameters(), lr=0.0001)

total_step = len(data_loader)
for epoch in range(epoch_num):
    for i, (images, _) in enumerate(data_loader, 0):

        images = images.reshape(batch_size, -1).to(device)

        # ---- Train the discriminator ----
        # Real images are labeled 1 and generated images are labeled 0.
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        outputs = D(images)
        d_loss_real = criteria(outputs, real_labels)

        # Score generated images against the "fake" label.
        # Note: G(z) is not detached here, so d_loss.backward() also computes
        # gradients for G; they are wasted and cleared by g_optimizer.zero_grad()
        # below. Many implementations use D(G(z).detach()) to skip that work.
        z = torch.randn(batch_size, latent_size).to(device)
        outputs = D(G(z))
        d_loss_fake = criteria(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()


        # ---- Train the generator ----
        # The generator is rewarded when D classifies its samples as real (label 1).
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criteria(outputs, real_labels)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if i % 200 == 0:
            print("Epoch [{} / {}],Step [{} / {}] d_loss = {:.4f}, g_loss = {:.4f}".format(epoch + 1, epoch_num,i, total_step, d_loss.item(), g_loss.item()))

    fake_images = fake_images.reshape(batch_size, 1, 28, 28)
    save_image(fake_images, os.path.join(image_dir, "gan_images-{}.png".format(epoch + 1)))
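    # Note: the generator's Tanh output lies in [-1, 1], while save_image assumes
    # values in [0, 1] and clamps anything outside that range. A common refinement
    # (not in the original post) is to rescale first: fake_images = (fake_images + 1) / 2.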


torch.save(G.state_dict(), "Generator.pth")
torch.save(D.state_dict(), "Discriminator.pth")




The generated images are saved in the gan_images folder. The output during training looks like this:
Epoch [1 / 100], d_loss = 1.4005, g_loss = 0.6798
Epoch [1 / 100], d_loss = 0.0621, g_loss = 3.5801
Epoch [1 / 100], d_loss = 0.2698, g_loss = 3.5682
Epoch [2 / 100], d_loss = 0.1811, g_loss = 3.5154
Epoch [2 / 100], d_loss = 0.3187, g_loss = 3.1001
Epoch [2 / 100], d_loss = 0.3439, g_loss = 2.7816
Epoch [3 / 100], d_loss = 1.3116, g_loss = 1.6221
Epoch [3 / 100], d_loss = 0.0951, g_loss = 4.1097
Epoch [3 / 100], d_loss = 0.2743, g_loss = 3.0359
Epoch [4 / 100], d_loss = 0.3159, g_loss = 2.6536
Epoch [4 / 100], d_loss = 0.0741, g_loss = 3.8504
Epoch [4 / 100], d_loss = 0.2342, g_loss = 3.4751
Epoch [5 / 100], d_loss = 0.1974, g_loss = 3.5920
Epoch [5 / 100], d_loss = 0.2368, g_loss = 3.2966
Epoch [5 / 100], d_loss = 0.1879, g_loss = 3.7705
Epoch [6 / 100], d_loss = 0.3655, g_loss = 3.9668
Epoch [6 / 100], d_loss = 0.2644, g_loss = 3.5875
Epoch [6 / 100], d_loss = 0.1215, g_loss = 3.8600
Epoch [7 / 100], d_loss = 0.1748, g_loss = 3.5253
Epoch [7 / 100], d_loss = 0.1105, g_loss = 3.8256
Epoch [7 / 100], d_loss = 0.0883, g_loss = 4.8804
Epoch [8 / 100], d_loss = 0.0891, g_loss = 6.2839
Epoch [8 / 100], d_loss = 0.0507, g_loss = 4.4470
Epoch [8 / 100], d_loss = 0.2816, g_loss = 4.0533
Epoch [9 / 100], d_loss = 0.2605, g_loss = 5.1191
Epoch [9 / 100], d_loss = 0.2144, g_loss = 3.8165
Epoch [9 / 100], d_loss = 0.2019, g_loss = 4.5753
Epoch [10 / 100], d_loss = 0.3241, g_loss = 3.3218
Epoch [10 / 100], d_loss = 0.3779, g_loss = 4.1461
Epoch [10 / 100], d_loss = 0.2594, g_loss = 3.6480
Epoch [11 / 100], d_loss = 0.3161, g_loss = 3.1858
Epoch [11 / 100], d_loss = 0.2094, g_loss = 3.6961
Epoch [11 / 100], d_loss = 0.2777, g_loss = 5.4229
Epoch [12 / 100], d_loss = 0.1255, g_loss = 4.2613
Epoch [12 / 100], d_loss = 0.4155, g_loss = 3.6585
Epoch [12 / 100], d_loss = 0.1338, g_loss = 5.9054
Epoch [13 / 100], d_loss = 0.3724, g_loss = 4.3520
Epoch [13 / 100], d_loss = 0.2869, g_loss = 3.2829
Epoch [13 / 100], d_loss = 0.2133, g_loss = 4.1905
Epoch [14 / 100], d_loss = 0.2999, g_loss = 2.7607
Epoch [14 / 100], d_loss = 0.3578, g_loss = 4.3012
Epoch [14 / 100], d_loss = 0.1794, g_loss = 4.8473
Epoch [15 / 100], d_loss = 0.3180, g_loss = 3.7902
Epoch [15 / 100], d_loss = 0.1475, g_loss = 5.3947
Epoch [15 / 100], d_loss = 0.1403, g_loss = 3.6762
Epoch [16 / 100], d_loss = 0.2932, g_loss = 3.4348
Epoch [16 / 100], d_loss = 0.1673, g_loss = 4.8111
Epoch [16 / 100], d_loss = 0.2986, g_loss = 4.8031
Epoch [17 / 100], d_loss = 0.2396, g_loss = 4.7224
Epoch [17 / 100], d_loss = 0.2053, g_loss = 4.3120
Epoch [17 / 100], d_loss = 0.1204, g_loss = 4.5077
Epoch [18 / 100], d_loss = 0.2661, g_loss = 5.1059
Epoch [18 / 100], d_loss = 0.2851, g_loss = 4.5754
Epoch [18 / 100], d_loss = 0.2176, g_loss = 5.4797
Epoch [19 / 100], d_loss = 0.1762, g_loss = 5.0284
Epoch [19 / 100], d_loss = 0.1988, g_loss = 4.0182
Epoch [19 / 100], d_loss = 0.1350, g_loss = 5.8161
Epoch [20 / 100], d_loss = 0.2850, g_loss = 5.0353
Epoch [20 / 100], d_loss = 0.3551, g_loss = 2.9458
Epoch [20 / 100], d_loss = 0.1139, g_loss = 4.3131
Epoch [21 / 100], d_loss = 0.1293, g_loss = 5.3033
Epoch [21 / 100], d_loss = 0.0710, g_loss = 5.8201
Epoch [21 / 100], d_loss = 0.1737, g_loss = 4.3040
Epoch [22 / 100], d_loss = 0.1824, g_loss = 3.8893
Epoch [22 / 100], d_loss = 0.1078, g_loss = 4.6621
Epoch [22 / 100], d_loss = 0.1687, g_loss = 4.8674
Epoch [23 / 100], d_loss = 0.1675, g_loss = 5.4555
Epoch [23 / 100], d_loss = 0.2172, g_loss = 4.6650
Epoch [23 / 100], d_loss = 0.2184, g_loss = 4.7891
Epoch [24 / 100], d_loss = 0.1516, g_loss = 6.1614
Epoch [24 / 100], d_loss = 0.1665, g_loss = 4.0922
Epoch [24 / 100], d_loss = 0.0997, g_loss = 5.2850
Epoch [25 / 100], d_loss = 0.3317, g_loss = 4.0964
Epoch [25 / 100], d_loss = 0.3065, g_loss = 4.0313
Epoch [25 / 100], d_loss = 0.2348, g_loss = 3.9021
Epoch [26 / 100], d_loss = 0.3027, g_loss = 4.8271
Epoch [26 / 100], d_loss = 0.2931, g_loss = 3.4956
Epoch [26 / 100], d_loss = 0.3160, g_loss = 6.4097
Epoch [27 / 100], d_loss = 0.3170, g_loss = 4.9517
Epoch [27 / 100], d_loss = 0.3300, g_loss = 5.1625
Epoch [27 / 100], d_loss = 0.2237, g_loss = 3.8433
Epoch [28 / 100], d_loss = 0.2623, g_loss = 3.8917
Epoch [28 / 100], d_loss = 0.2885, g_loss = 6.3781
Epoch [28 / 100], d_loss = 0.1820, g_loss = 4.7686
Epoch [29 / 100], d_loss = 0.1731, g_loss = 5.4061
Epoch [29 / 100], d_loss = 0.1728, g_loss = 6.1285
Epoch [29 / 100], d_loss = 0.4072, g_loss = 4.6868
Epoch [30 / 100], d_loss = 0.1897, g_loss = 4.2557
Epoch [30 / 100], d_loss = 0.2685, g_loss = 3.9201
Epoch [30 / 100], d_loss = 0.1200, g_loss = 5.9766
Epoch [31 / 100], d_loss = 0.2807, g_loss = 4.2405
Epoch [31 / 100], d_loss = 0.3687, g_loss = 3.5048
Epoch [31 / 100], d_loss = 0.1658, g_loss = 4.2518
Epoch [32 / 100], d_loss = 0.2290, g_loss = 5.4087
Epoch [32 / 100], d_loss = 0.2649, g_loss = 4.3162
Epoch [32 / 100], d_loss = 0.2490, g_loss = 5.2993
Epoch [33 / 100], d_loss = 0.2252, g_loss = 4.2555
Epoch [33 / 100], d_loss = 0.2828, g_loss = 4.4996
Epoch [33 / 100], d_loss = 0.2613, g_loss = 3.8531
Epoch [34 / 100], d_loss = 0.2915, g_loss = 5.0766
Epoch [34 / 100], d_loss = 0.2113, g_loss = 4.2702
Epoch [34 / 100], d_loss = 0.3235, g_loss = 3.2965
Epoch [35 / 100], d_loss = 0.2930, g_loss = 3.7018
Epoch [35 / 100], d_loss = 0.4591, g_loss = 3.6850
Epoch [35 / 100], d_loss = 0.4780, g_loss = 3.0125
Epoch [36 / 100], d_loss = 0.2119, g_loss = 3.7448
Epoch [36 / 100], d_loss = 0.3707, g_loss = 4.0800
Epoch [36 / 100], d_loss = 0.3892, g_loss = 3.2267
Epoch [37 / 100], d_loss = 0.3581, g_loss = 3.3309
Epoch [37 / 100], d_loss = 0.3379, g_loss = 3.7177
Epoch [37 / 100], d_loss = 0.4734, g_loss = 2.6390
Epoch [38 / 100], d_loss = 0.2984, g_loss = 2.9607
Epoch [38 / 100], d_loss = 0.4436, g_loss = 3.1230
Epoch [38 / 100], d_loss = 0.2265, g_loss = 3.7585
Epoch [39 / 100], d_loss = 0.4792, g_loss = 3.1574
Epoch [39 / 100], d_loss = 0.3267, g_loss = 3.8206
Epoch [39 / 100], d_loss = 0.3506, g_loss = 3.2502
Epoch [40 / 100], d_loss = 0.3300, g_loss = 4.0757
Epoch [40 / 100], d_loss = 0.3422, g_loss = 3.9873
Epoch [40 / 100], d_loss = 0.3272, g_loss = 3.8506
Epoch [41 / 100], d_loss = 0.4380, g_loss = 3.2595
Epoch [41 / 100], d_loss = 0.3456, g_loss = 3.1325
Epoch [41 / 100], d_loss = 0.4930, g_loss = 2.5089
Epoch [42 / 100], d_loss = 0.4362, g_loss = 3.1982
Epoch [42 / 100], d_loss = 0.3139, g_loss = 3.2215
Epoch [42 / 100], d_loss = 0.3479, g_loss = 2.4273
Epoch [43 / 100], d_loss = 0.3033, g_loss = 3.5534
Epoch [43 / 100], d_loss = 0.3428, g_loss = 3.2007
Epoch [43 / 100], d_loss = 0.4883, g_loss = 3.9243
Epoch [44 / 100], d_loss = 0.4175, g_loss = 3.3569
Epoch [44 / 100], d_loss = 0.5574, g_loss = 3.3162
Epoch [44 / 100], d_loss = 0.5981, g_loss = 2.5613
Epoch [45 / 100], d_loss = 0.4480, g_loss = 3.5681
Epoch [45 / 100], d_loss = 0.6369, g_loss = 3.0833
Epoch [45 / 100], d_loss = 0.4301, g_loss = 3.7558
Epoch [46 / 100], d_loss = 0.5784, g_loss = 2.2023
Epoch [46 / 100], d_loss = 0.4192, g_loss = 2.9009
Epoch [46 / 100], d_loss = 0.4688, g_loss = 2.1012
Epoch [47 / 100], d_loss = 0.4303, g_loss = 2.4235
Epoch [47 / 100], d_loss = 0.5326, g_loss = 2.5725
Epoch [47 / 100], d_loss = 0.5327, g_loss = 2.7579
Epoch [48 / 100], d_loss = 0.2568, g_loss = 3.3153
Epoch [48 / 100], d_loss = 0.5148, g_loss = 2.9677
Epoch [48 / 100], d_loss = 0.4422, g_loss = 2.4869
Epoch [49 / 100], d_loss = 0.5628, g_loss = 2.7539
Epoch [49 / 100], d_loss = 0.4655, g_loss = 2.4238
Epoch [49 / 100], d_loss = 0.5316, g_loss = 3.2095
Epoch [50 / 100], d_loss = 0.4632, g_loss = 3.4854
Epoch [50 / 100], d_loss = 0.4265, g_loss = 2.5565
Epoch [50 / 100], d_loss = 0.3153, g_loss = 3.5203
Epoch [51 / 100], d_loss = 0.6447, g_loss = 2.3761
Epoch [51 / 100], d_loss = 0.6159, g_loss = 2.5303
Epoch [51 / 100], d_loss = 0.5911, g_loss = 2.6402
Epoch [52 / 100], d_loss = 0.6623, g_loss = 2.7971
Epoch [52 / 100], d_loss = 0.6024, g_loss = 2.1936
Epoch [52 / 100], d_loss = 0.7278, g_loss = 2.6819
Epoch [53 / 100], d_loss = 0.5315, g_loss = 2.4350
Epoch [53 / 100], d_loss = 0.6086, g_loss = 2.1015
Epoch [53 / 100], d_loss = 0.6430, g_loss = 1.9207
Epoch [54 / 100], d_loss = 0.6781, g_loss = 2.5795
Epoch [54 / 100], d_loss = 0.5939, g_loss = 2.1577
Epoch [54 / 100], d_loss = 0.6015, g_loss = 2.1501
Epoch [55 / 100], d_loss = 0.6770, g_loss = 2.2563
Epoch [55 / 100], d_loss = 0.3797, g_loss = 2.9337
Epoch [55 / 100], d_loss = 0.6671, g_loss = 2.2520
Epoch [56 / 100], d_loss = 0.4809, g_loss = 1.9355
Epoch [56 / 100], d_loss = 0.6161, g_loss = 1.8455
Epoch [56 / 100], d_loss = 0.4937, g_loss = 2.6282
Epoch [57 / 100], d_loss = 0.6174, g_loss = 1.8371
Epoch [57 / 100], d_loss = 0.6436, g_loss = 2.5613
Epoch [57 / 100], d_loss = 0.7414, g_loss = 2.4130
Epoch [58 / 100], d_loss = 0.5596, g_loss = 2.3446
Epoch [58 / 100], d_loss = 0.4977, g_loss = 2.6491
Epoch [58 / 100], d_loss = 0.3481, g_loss = 2.7025
Epoch [59 / 100], d_loss = 0.6528, g_loss = 2.1182
Epoch [59 / 100], d_loss = 0.5975, g_loss = 2.1660
Epoch [59 / 100], d_loss = 0.4634, g_loss = 2.7633
Epoch [60 / 100], d_loss = 0.6411, g_loss = 2.2056
Epoch [60 / 100], d_loss = 0.7455, g_loss = 2.1277
Epoch [60 / 100], d_loss = 0.5846, g_loss = 2.3214
Epoch [61 / 100], d_loss = 0.4130, g_loss = 2.6413
Epoch [61 / 100], d_loss = 0.5043, g_loss = 2.7461
Epoch [61 / 100], d_loss = 0.5027, g_loss = 2.7873
Epoch [62 / 100], d_loss = 0.5225, g_loss = 2.8119
Epoch [62 / 100], d_loss = 0.5937, g_loss = 2.2405
Epoch [62 / 100], d_loss = 0.6624, g_loss = 2.0931
Epoch [63 / 100], d_loss = 0.5260, g_loss = 2.0877
Epoch [63 / 100], d_loss = 0.4278, g_loss = 2.4840
Epoch [63 / 100], d_loss = 0.7284, g_loss = 1.8716
Epoch [64 / 100], d_loss = 0.4951, g_loss = 1.9208
Epoch [64 / 100], d_loss = 0.5356, g_loss = 1.8108
Epoch [64 / 100], d_loss = 0.5486, g_loss = 2.5334
Epoch [65 / 100], d_loss = 0.6209, g_loss = 2.4244
Epoch [65 / 100], d_loss = 0.5548, g_loss = 2.1217
Epoch [65 / 100], d_loss = 0.7993, g_loss = 1.9576
Epoch [66 / 100], d_loss = 0.6239, g_loss = 2.3424
Epoch [66 / 100], d_loss = 0.7308, g_loss = 1.9419
Epoch [66 / 100], d_loss = 0.4720, g_loss = 2.6218
Epoch [67 / 100], d_loss = 0.7112, g_loss = 2.3217
Epoch [67 / 100], d_loss = 0.6114, g_loss = 2.4066
Epoch [67 / 100], d_loss = 0.7292, g_loss = 1.8385
Epoch [68 / 100], d_loss = 0.6430, g_loss = 2.1762
Epoch [68 / 100], d_loss = 0.6845, g_loss = 2.2916
Epoch [68 / 100], d_loss = 0.7470, g_loss = 1.6900
Epoch [69 / 100], d_loss = 0.8242, g_loss = 1.5990
Epoch [69 / 100], d_loss = 0.6194, g_loss = 2.1995
Epoch [69 / 100], d_loss = 0.7138, g_loss = 2.2843
Epoch [70 / 100], d_loss = 0.7451, g_loss = 1.7273
Epoch [70 / 100], d_loss = 0.6697, g_loss = 2.6672
Epoch [70 / 100], d_loss = 0.5215, g_loss = 2.3656
Epoch [71 / 100], d_loss = 0.7273, g_loss = 1.7724
Epoch [71 / 100], d_loss = 0.7882, g_loss = 2.0707
Epoch [71 / 100], d_loss = 0.8792, g_loss = 2.1253
Epoch [72 / 100], d_loss = 0.6391, g_loss = 1.8659
Epoch [72 / 100], d_loss = 0.5978, g_loss = 2.3649
Epoch [72 / 100], d_loss = 0.4909, g_loss = 2.1047
Epoch [73 / 100], d_loss = 0.6429, g_loss = 2.7016
Epoch [73 / 100], d_loss = 0.6738, g_loss = 2.3777
Epoch [73 / 100], d_loss = 0.6250, g_loss = 1.8886
Epoch [74 / 100], d_loss = 0.7762, g_loss = 1.9123
Epoch [74 / 100], d_loss = 0.6819, g_loss = 1.9295
Epoch [74 / 100], d_loss = 0.7057, g_loss = 2.3590
Epoch [75 / 100], d_loss = 0.7284, g_loss = 1.9105
Epoch [75 / 100], d_loss = 0.6256, g_loss = 1.8463
Epoch [75 / 100], d_loss = 0.5224, g_loss = 1.5472
Epoch [76 / 100], d_loss = 0.6944, g_loss = 1.8389
Epoch [76 / 100], d_loss = 0.6051, g_loss = 2.2255
Epoch [76 / 100], d_loss = 0.7042, g_loss = 1.4677
Epoch [77 / 100], d_loss = 0.6009, g_loss = 1.6764
Epoch [77 / 100], d_loss = 0.5430, g_loss = 2.2582
Epoch [77 / 100], d_loss = 0.5698, g_loss = 1.9660
Epoch [78 / 100], d_loss = 0.5518, g_loss = 2.4821
Epoch [78 / 100], d_loss = 0.7566, g_loss = 2.3749
Epoch [78 / 100], d_loss = 0.5386, g_loss = 2.1131
Epoch [79 / 100], d_loss = 0.6013, g_loss = 2.0279
Epoch [79 / 100], d_loss = 0.7488, g_loss = 1.7669
Epoch [79 / 100], d_loss = 0.6356, g_loss = 1.4959
Epoch [80 / 100], d_loss = 0.7612, g_loss = 1.6782
Epoch [80 / 100], d_loss = 0.6220, g_loss = 2.1223
Epoch [80 / 100], d_loss = 0.9306, g_loss = 1.9309
Epoch [81 / 100], d_loss = 0.5956, g_loss = 2.0545
Epoch [81 / 100], d_loss = 0.6892, g_loss = 1.6276
Epoch [81 / 100], d_loss = 0.7528, g_loss = 1.8048
Epoch [82 / 100], d_loss = 0.6198, g_loss = 1.7642
Epoch [82 / 100], d_loss = 0.7554, g_loss = 1.8449
Epoch [82 / 100], d_loss = 0.5543, g_loss = 1.7551
Epoch [83 / 100], d_loss = 0.7482, g_loss = 1.7634
Epoch [83 / 100], d_loss = 0.5848, g_loss = 2.4618
Epoch [83 / 100], d_loss = 0.6526, g_loss = 1.8014
Epoch [84 / 100], d_loss = 0.7416, g_loss = 2.2330
Epoch [84 / 100], d_loss = 0.7117, g_loss = 2.1474
Epoch [84 / 100], d_loss = 0.6715, g_loss = 1.5732
Epoch [85 / 100], d_loss = 0.6500, g_loss = 2.1351
Epoch [85 / 100], d_loss = 0.6907, g_loss = 1.5845
Epoch [85 / 100], d_loss = 0.6423, g_loss = 1.9624
Epoch [86 / 100], d_loss = 0.6468, g_loss = 2.2248
Epoch [86 / 100], d_loss = 0.7964, g_loss = 1.7578
Epoch [86 / 100], d_loss = 0.8308, g_loss = 2.1643
Epoch [87 / 100], d_loss = 0.8049, g_loss = 1.9290
Epoch [87 / 100], d_loss = 0.7781, g_loss = 1.7725
Epoch [87 / 100], d_loss = 0.6872, g_loss = 1.8172
Epoch [88 / 100], d_loss = 0.7625, g_loss = 1.6974
Epoch [88 / 100], d_loss = 0.7031, g_loss = 2.0148
Epoch [88 / 100], d_loss = 0.8161, g_loss = 2.3086
Epoch [89 / 100], d_loss = 0.5941, g_loss = 1.6677
Epoch [89 / 100], d_loss = 0.6878, g_loss = 1.6104
Epoch [89 / 100], d_loss = 0.5485, g_loss = 2.0137
Epoch [90 / 100], d_loss = 0.8351, g_loss = 1.6406
Epoch [90 / 100], d_loss = 0.7346, g_loss = 2.1254
Epoch [90 / 100], d_loss = 0.7062, g_loss = 2.0978
Epoch [91 / 100], d_loss = 0.8811, g_loss = 2.2355
Epoch [91 / 100], d_loss = 0.6416, g_loss = 2.0354
Epoch [91 / 100], d_loss = 0.6701, g_loss = 2.2361
Epoch [92 / 100], d_loss = 0.8181, g_loss = 1.5149
Epoch [92 / 100], d_loss = 0.6995, g_loss = 1.7809
Epoch [92 / 100], d_loss = 0.9302, g_loss = 1.6532
Epoch [93 / 100], d_loss = 0.7537, g_loss = 2.2442
Epoch [93 / 100], d_loss = 0.7595, g_loss = 1.5303
Epoch [93 / 100], d_loss = 0.6609, g_loss = 2.1224
Epoch [94 / 100], d_loss = 0.7008, g_loss = 1.6132
Epoch [94 / 100], d_loss = 0.8136, g_loss = 2.3829
Epoch [94 / 100], d_loss = 0.6874, g_loss = 1.7750
Epoch [95 / 100], d_loss = 0.7074, g_loss = 1.3372
Epoch [95 / 100], d_loss = 0.6983, g_loss = 1.7831
Epoch [95 / 100], d_loss = 0.6664, g_loss = 2.4457
Epoch [96 / 100], d_loss = 0.6281, g_loss = 2.2233
Epoch [96 / 100], d_loss = 0.6056, g_loss = 1.6816
Epoch [96 / 100], d_loss = 0.6814, g_loss = 1.8850
Epoch [97 / 100], d_loss = 0.6050, g_loss = 1.9024
Epoch [97 / 100], d_loss = 0.7528, g_loss = 1.6076
Epoch [97 / 100], d_loss = 0.6339, g_loss = 2.4748
Epoch [98 / 100], d_loss = 0.6732, g_loss = 2.0769
Epoch [98 / 100], d_loss = 0.6004, g_loss = 2.0239
Epoch [98 / 100], d_loss = 0.8085, g_loss = 1.9665
Epoch [99 / 100], d_loss = 0.7437, g_loss = 1.9152
Epoch [99 / 100], d_loss = 0.6655, g_loss = 1.8907
Epoch [99 / 100], d_loss = 0.6616, g_loss = 1.7388
Epoch [100 / 100], d_loss = 0.8149, g_loss = 1.6021
Epoch [100 / 100], d_loss = 0.8716, g_loss = 1.8420
Epoch [100 / 100], d_loss = 0.8323, g_loss = 1.6049

After 100 epochs of training, the generated images look like this:

[Figure 2: images generated after 100 epochs]

For comparison, here is a real image from the dataset:
[Figure 3: a real MNIST image for comparison]

As you can see, the generated images still contain quite a bit of noise, but the digits in them are already recognizable.

After the 100 epochs finished, we saved the networks with torch.save(), so they can be reused directly later.
Below we load the generator's parameters, feed it a random latent vector, and see what it generates:

from torch import nn
from torchvision.utils import save_image
import torch
import os

image_size = 28 * 28
hidden_size = 400
latent_size = 100

image_dir = "fake_images"
if not os.path.exists(image_dir):
    os.makedirs(image_dir)

# Must match the generator architecture used during training.
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh()
)

G.load_state_dict(torch.load("Generator.pth"))
G.eval()

# A single latent vector produces a single image; no gradients are needed here.
with torch.no_grad():
    z = torch.randn(1, latent_size)
    fake_image = G(z)

fake_image = fake_image.reshape(1, 1, 28,28)
save_image(fake_image, os.path.join(image_dir, "fake_image.png"))



Here the generator's input is a single latent vector, so it produces just one image, which is saved in the fake_images folder. The result is shown below:
[Figure: the image generated by the loaded generator]
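If you would rather sample a whole grid of digits than a single image, a minimal sketch along the same lines is shown below (it reuses Generator.pth and the same network definition; the batch size of 64 and the output file name are arbitrary choices):

import torch
from torch import nn
from torchvision.utils import save_image

latent_size = 100
hidden_size = 400
image_size = 28 * 28

# Must match the architecture that produced Generator.pth.
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh()
)
G.load_state_dict(torch.load("Generator.pth"))
G.eval()

with torch.no_grad():
    z = torch.randn(64, latent_size)               # 64 random latent vectors
    samples = G(z).reshape(64, 1, 28, 28)
    samples = (samples + 1) / 2                    # rescale from [-1, 1] to [0, 1]
    save_image(samples, "fake_grid.png", nrow=8)   # save as an 8 x 8 grid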

As you can see, once the generator is trained, we only need to feed it a random vector drawn from the right distribution and it will produce an image for us, which feels almost magical. However, there is one notable drawback: we cannot control what the generated image will be. Taking MNIST as an example, when we feed the generator a random vector, the result is random, so we cannot choose which digit is produced. If we want to control the generated digit, we can use a GAN variant, the conditional GAN (CGAN), to generate images.
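As a rough illustration of the CGAN idea (this sketch is not part of the original post, and the layer sizes are arbitrary): the generator additionally receives the class label, typically as an embedding or one-hot vector concatenated to the latent vector, and the discriminator is conditioned on the label in the same way.

import torch
from torch import nn

latent_size, num_classes, hidden_size, image_size = 100, 10, 400, 28 * 28

class ConditionalGenerator(nn.Module):
    """Generator conditioned on a digit label (illustrative sketch only)."""
    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)  # label -> 10-dim vector
        self.net = nn.Sequential(
            nn.Linear(latent_size + num_classes, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, image_size),
            nn.Tanh()
        )

    def forward(self, z, labels):
        # Concatenate the noise with the label embedding, then decode to an image.
        x = torch.cat([z, self.label_emb(labels)], dim=1)
        return self.net(x)

# Ask the (untrained) conditional generator for the digit "7".
G = ConditionalGenerator()
z = torch.randn(1, latent_size)
labels = torch.tensor([7])
image = G(z, labels).reshape(1, 1, 28, 28)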

Of course, this is only the most basic GAN model; many GAN variants have since been proposed to improve its performance.
