基于GAN实现MNIST图片生成-----Pytorch版

使用GAN生成MNIST手写数字

生成对抗网络介绍:

生成对抗网络(GAN)由两个网络组成,二者相互对抗、交替训练。GAN的主要部分为生成器Generator和判别器Discriminator。
生成器:输入一个随机向量,生成一张图片。希望生成的图片越逼真越好。
判别器:输入一张图片,判别图片的真伪。希望把生成器生成的图片判别为假,把数据集中的真实图片判别为真。

判别器:

class Discriminator(nn.Module):
    """MLP discriminator: scores how likely each input image is to be real.

    Each image is flattened to 28*28 = 784 values, passed through two
    256-unit hidden layers with LeakyReLU(0.2), and reduced to a single
    sigmoid score in (0, 1).
    """

    def __init__(self):
        super().__init__()
        hidden = 256
        # Attribute names are kept so state_dict keys (and saved checkpoints) stay the same.
        self.fc1 = nn.Linear(28 * 28, hidden)
        self.lr1 = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(hidden, hidden)
        self.lr2 = nn.LeakyReLU(0.2)
        self.fc3 = nn.Linear(hidden, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Return a (batch, 1) tensor of real-probabilities.

        ``input`` may have any shape whose element count is a multiple of 784;
        it is viewed as (batch, 784).
        """
        flat = input.view(-1, 28 * 28)
        h = self.lr1(self.fc1(flat))
        h = self.lr2(self.fc2(h))
        return self.sigmoid(self.fc3(h))

生成器:

class Generator(nn.Module):
    """MLP generator: maps a 100-dim noise vector to a flattened 28x28 image.

    Two 256-unit ReLU hidden layers feed a 784-unit output layer whose
    values are squashed into [-1, 1] by tanh.
    """

    def __init__(self):
        super().__init__()
        hidden = 256
        # Attribute names are kept so state_dict keys (and saved checkpoints) stay the same.
        self.fc1 = nn.Linear(100, hidden)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden, hidden)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden, 28 * 28)
        self.tanh = nn.Tanh()

    def forward(self, input):
        """Map noise whose last dimension is 100 to outputs whose last dimension is 784."""
        h = self.relu1(self.fc1(input))
        h = self.relu2(self.fc2(h))
        return self.tanh(self.fc3(h))

训练判别器:

# Discriminator step: real images are scored against real_label, generated
# images against fake_label (all-ones / all-zeros in the full listing).
real_out=D(real_img)
d_loss_real=criterion(real_out,real_label)
# detach() cuts the autograd graph so this loss does not backpropagate into G.
fake_img=G(z).detach()
fake_out=D(fake_img)
d_loss_fake=criterion(fake_out,fake_label)
d_loss=d_loss_real+d_loss_fake
# Clear stale gradients, backpropagate, then update the parameters held by
# d_optimizer (D.parameters() in the full listing).
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()

训练生成器:

# Generator step: fakes are NOT detached here, so gradients flow through D back into G.
fake_img=G(z)
output=D(fake_img)
# Target is real_label: minimizing this loss pushes G toward outputs that D scores as real.
g_loss=criterion(output,real_label)
g_optimizer.zero_grad()
g_loss.backward()
# Only the optimizer holding G's parameters steps, so D's weights are not changed here
# (assuming g_optimizer was built from G.parameters(), as in the full listing).
g_optimizer.step()

全部代码

import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader
from torch.autograd import Variable
import matplotlib.pyplot as plt

class Discriminator(nn.Module):
    """MLP discriminator: scores how likely each input image is to be real.

    Each image is flattened to 28*28 = 784 values, passed through two
    256-unit hidden layers with LeakyReLU(0.2), and reduced to a single
    sigmoid score in (0, 1).
    """

    def __init__(self):
        super().__init__()
        hidden = 256
        # Attribute names are kept so state_dict keys (and saved checkpoints) stay the same.
        self.fc1 = nn.Linear(28 * 28, hidden)
        self.lr1 = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(hidden, hidden)
        self.lr2 = nn.LeakyReLU(0.2)
        self.fc3 = nn.Linear(hidden, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Return a (batch, 1) tensor of real-probabilities.

        ``input`` may have any shape whose element count is a multiple of 784;
        it is viewed as (batch, 784).
        """
        flat = input.view(-1, 28 * 28)
        h = self.lr1(self.fc1(flat))
        h = self.lr2(self.fc2(h))
        return self.sigmoid(self.fc3(h))

class Generator(nn.Module):
    """MLP generator: maps a 100-dim noise vector to a flattened 28x28 image.

    Two 256-unit ReLU hidden layers feed a 784-unit output layer whose
    values are squashed into [-1, 1] by tanh, the same range that
    Normalize((0.5,), (0.5,)) gives the real training images.
    """

    def __init__(self):
        super().__init__()
        hidden = 256
        # Attribute names are kept so state_dict keys (and saved checkpoints) stay the same.
        self.fc1 = nn.Linear(100, hidden)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden, hidden)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden, 28 * 28)
        self.tanh = nn.Tanh()

    def forward(self, input):
        """Map noise whose last dimension is 100 to outputs whose last dimension is 784."""
        h = self.relu1(self.fc1(input))
        h = self.relu2(self.fc2(h))
        return self.tanh(self.fc3(h))

def to_img(x):
    """Convert values in [-1, 1] to displayable images shaped (N, 28, 28, 1).

    The values are rescaled via (x + 1) / 2, clipped to [0, 1], and viewed
    with a trailing single-channel axis.
    """
    rescaled = (x + 1) * 0.5
    return rescaled.clamp(min=0, max=1).view(-1, 28, 28, 1)

def save_images(fake_image, epoch, out_dir="images"):
    """Save a 5x5 grid of images to ``<out_dir>/epoch_<epoch + 1>.png``.

    Args:
        fake_image: batch indexable as ``fake_image[i, :, :, 0]`` (e.g. the
            (N, 28, 28, 1) output of ``to_img``); at most the first 25 images
            are drawn, in row-major order.
        epoch: zero-based epoch index; the file name uses ``epoch + 1``.
        out_dir: output directory; created if it does not already exist.

    Grid cells beyond ``len(fake_image)`` are left blank instead of raising
    IndexError.
    """
    import os

    rows, cols = 5, 5
    # The original wrote into "images/" without creating it, which fails on a fresh checkout.
    os.makedirs(out_dir, exist_ok=True)
    fig, axs = plt.subplots(rows, cols)
    try:
        # axs.flat walks the grid row by row, matching the original i/j loop order.
        for cnt, ax in enumerate(axs.flat):
            if cnt < len(fake_image):
                ax.imshow(fake_image[cnt, :, :, 0], cmap='gray')
            ax.axis('off')
        fig.savefig(os.path.join(out_dir, "epoch_{}.png".format(epoch + 1)))
    finally:
        # Close this specific figure even if saving fails, so figures do not accumulate.
        plt.close(fig)

if __name__ == '__main__':
    # Hyper-parameters.
    epochs = 100
    batch_size = 128
    z_dimension = 100  # must equal the input size of Generator.fc1

    # ToTensor() yields values in [0, 1]; Normalize((0.5,), (0.5,)) maps them
    # to [-1, 1], the same range as the generator's tanh output.
    transformer = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])
    mnist = torchvision.datasets.MNIST(root='data', train=True, transform=transformer, download=True)
    dataloader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    D = Discriminator().to(device)
    G = Generator().to(device)
    criterion = nn.BCELoss()  # D ends in a sigmoid, so its outputs are probabilities
    d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
    g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

    for epoch in range(epochs):
        for idx, (img, _) in enumerate(dataloader):
            num_img = img.size(0)
            # Plain tensors are used directly: torch.autograd.Variable has been a
            # deprecated no-op wrapper since PyTorch 0.4.
            real_img = img.view(num_img, -1).to(device)
            # BCE targets, created directly on the training device: 1 = real, 0 = fake.
            real_label = torch.ones(num_img, 1, device=device)
            fake_label = torch.zeros(num_img, 1, device=device)

            # ---- Train the discriminator ----
            real_out = D(real_img)
            d_loss_real = criterion(real_out, real_label)
            z = torch.randn(num_img, z_dimension, device=device)
            # detach() keeps this loss from backpropagating into G.
            fake_img = G(z).detach()
            fake_out = D(fake_img)
            d_loss_fake = criterion(fake_out, fake_label)
            d_loss = d_loss_real + d_loss_fake

            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # ---- Train the generator ----
            # Fresh noise; G is rewarded when D scores its output as real.
            z = torch.randn(num_img, z_dimension, device=device)
            fake_img = G(z)
            output = D(fake_img)
            g_loss = criterion(output, real_label)
            # g_loss.backward() also fills D's .grad fields; those are cleared by
            # d_optimizer.zero_grad() before D's next update, and only G steps here.
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            if idx % 100 == 0:
                print('epoch:{} ,d_loss:{} ,g_loss:{}'.format(epoch, d_loss.item(), g_loss.item()))

        # Save the last generated batch of this epoch for visual inspection.
        fake_image = to_img(fake_img.detach().cpu())
        save_images(fake_image, epoch)
    print('--------------------------')
    torch.save(G.state_dict(), 'generator.pth')
    torch.save(D.state_dict(), 'discriminator.pth')

训练500个epoch后的生成结果(注:上面代码中默认epochs=100,如需复现此结果请将其改为500):

基于GAN实现MNIST图片生成-----Pytorch版_第1张图片

你可能感兴趣的:(GAN,pytorch,深度学习,神经网络)