基于PyTorch的DCGAN代码实现：简易图像数据生成

使用DCGAN简易实现图像数据生成

前言：个人电脑算力有限（GPU仅2G显存），因此使用numpy自行生成单通道（灰度）图像数据；DCGAN生成对抗网络训练50轮（耗时不到20分钟）后，生成的假图像（fake image）已肉眼难辨。

生成结果如下

基于pytorch的dcgan代码实现,进行简易图像数据生成_第1张图片

  1. 导入

    import torchvision
    import torchvision.transforms as transform
    from torchvision.datasets import ImageFolder
    import torchvision.utils as vutils
    
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torch.nn import BCELoss
    from torch.optim import Adam,SGD
    
    
    import matplotlib.pyplot as plt
    from matplotlib import animation # 用于生成gif图像
    import numpy as np
    
  2. 训练参数设置

    # ---- Hyper-parameters ----
    # synthetic data
    images_num = 3000       # number of training images (length of the dataset)
    image_size = 32         # side length of the square images; also reused as the base channel width of G_net / D_net
    image_channel = 1       # channels per image (1 = grayscale)

    # optimisation
    batch_size = 128        # images per training batch
    noise_size = 100        # length of the random noise vector the generator turns into a fake image
    lr = 0.0002             # learning rate
    epochs = 50             # number of training epochs

    # run on the GPU when one is available, otherwise on the CPU
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    
  3. 批量生成图像训练集

    # 1. Build a simple synthetic image dataset
    def gen_images():
        """Generate one synthetic grayscale image of shape (image_size, image_size).

        The background is near-white: ones minus uniform noise, so values lie in
        (0.9, 1]. On top of it, a dark parabola-shaped stroke is drawn with a
        single random intensity in [0, 0.1). For every horizontal position
        ``x`` in ``[min_x, max_x)``, the stroke covers rows ``y``, ``y + 1`` and
        ``y + 2``, giving a 3-pixel-thick line.

        The hard-coded ranges (3..7 for ``min_x``, 25..29 for ``max_x``) assume
        ``image_size == 32``. With them, ``y`` stays in [3, 26], so the stroke is
        always inside the image.

        Returns:
            numpy.ndarray of shape (image_size, image_size), dtype float64.
        """
        image = np.ones(shape=(image_size, image_size)) - np.random.rand(image_size, image_size) / 10
        min_x = np.random.choice(range(3, 8))
        max_x = np.random.choice(range(25, 30))
        # One scalar intensity for the whole stroke. np.random.rand(1) would give a
        # shape-(1,) array, and assigning that to a single element is deprecated
        # since NumPy 1.25.
        value = np.random.rand() / 10
        for x in range(min_x, max_x):
            # Parabola with its minimum at x = 0.55 * max_x, shifted down by 3 rows.
            y = int(np.power((x / max_x - 0.55) * 2, 2) * max_x) + 3
            # The original wrote row y+2 twice (copy-paste slip); paint y, y+1, y+2.
            image[y:y + 3, x] = value
        return image
    
    # Dataset of synthetic images
    class Images:
        """Map-style dataset that produces a freshly generated image on every access.

        ``__getitem__`` ignores ``index`` and calls ``gen_images()`` each time.
        The dataset therefore has a fixed length (``images_num``, read when the
        object is constructed) but no fixed contents.
        """

        def __init__(self):
            self.len_ = images_num

        def __len__(self):
            return self.len_

        def __getitem__(self, index):
            # numpy image -> float32 tensor with a leading channel axis; '' is a dummy label
            pixels = torch.from_numpy(gen_images())
            return pixels.float().unsqueeze(0), ''
    
    # Build the dataset and a shuffled loader that yields full batches only
    images = Images()
    # drop_last=True discards an incomplete final batch, so every batch holds batch_size images
    dataloader = DataLoader(dataset=images, batch_size=batch_size, shuffle=True, drop_last=True)

    # Preview: show the first 64 images of one batch as an 8x8 grid
    real_batch = next(iter(dataloader))
    fig, axs = plt.subplots()
    fig.set_size_inches(w=12, h=8)
    # make_grid returns (C, H, W); imshow expects (H, W, C)
    axs.imshow(vutils.make_grid(tensor=real_batch[0][:64], nrow=8, padding=2).permute(1, 2, 0))
    axs.axis('off')
    

基于pytorch的dcgan代码实现,进行简易图像数据生成_第2张图片

  1. 构建网络

    # Generator network
    class G_net(nn.Module):
        """DCGAN generator: noise (N, noise_dim, 1, 1) -> image (N, out_channels, 32, 32).

        Four transposed convolutions upsample the spatial size 1 -> 4 -> 8 -> 16 -> 32.
        The final Sigmoid keeps pixel values in (0, 1).

        Args:
            noise_dim: length of the input noise vector. Defaults to the
                module-level ``noise_size``.
            feature_maps: base width of the hidden layers (16x, 8x and 4x this
                value). Defaults to the module-level ``image_size``, which the
                original code hard-wired here. It does NOT change the output
                resolution, which is always 32x32.
            out_channels: channels of the generated image. Defaults to the
                module-level ``image_channel``.

        Calling ``G_net()`` with no arguments builds exactly the original network:
        same layers, same order, same state_dict keys.
        """

        def __init__(self, noise_dim=None, feature_maps=None, out_channels=None):
            super().__init__()
            # Fall back to the script-level settings, resolved at construction time.
            if noise_dim is None:
                noise_dim = noise_size
            nf = image_size if feature_maps is None else feature_maps
            if out_channels is None:
                out_channels = image_channel
            self.main = nn.Sequential(
                # 1x1 -> 4x4
                nn.ConvTranspose2d(in_channels=noise_dim, out_channels=nf * 16, kernel_size=4, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(nf * 16),
                nn.ReLU(inplace=True),
                # 4x4 -> 8x8
                nn.ConvTranspose2d(in_channels=nf * 16, out_channels=nf * 8, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(nf * 8),
                nn.ReLU(inplace=True),
                # 8x8 -> 16x16
                nn.ConvTranspose2d(in_channels=nf * 8, out_channels=nf * 4, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(nf * 4),
                nn.ReLU(inplace=True),
                # 16x16 -> 32x32; no BatchNorm on the output layer
                nn.ConvTranspose2d(in_channels=nf * 4, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=False),
                nn.Sigmoid(),
            )

        def forward(self, x):
            """Generate images from noise ``x``, expected shaped (N, noise_dim, 1, 1)."""
            return self.main(x)
            
            
    # Discriminator network, laid out as the reverse of the generator's convolutions
    class D_net(nn.Module):
        """DCGAN discriminator: image (N, in_channels, 32, 32) -> (N, 1, 1, 1) probability.

        The spatial size shrinks 32 -> 15 -> 6 -> 2 -> 1. The final 2x2
        convolution turns the remaining 2x2 map into a single score, so inputs
        must be 32x32. A Sigmoid squashes the score to (0, 1).

        Args:
            feature_maps: base width of the hidden layers (2x, 4x and 8x this
                value). Defaults to the module-level ``image_size``, which the
                original code hard-wired here.
            in_channels: channels of the input image. Defaults to the
                module-level ``image_channel``.

        Calling ``D_net()`` with no arguments builds exactly the original network:
        same layers, same order, same state_dict keys.
        """

        def __init__(self, feature_maps=None, in_channels=None):
            super().__init__()
            # Fall back to the script-level settings, resolved at construction time.
            nf = image_size if feature_maps is None else feature_maps
            if in_channels is None:
                in_channels = image_channel
            self.main = nn.Sequential(
                # 32x32 -> 15x15; no BatchNorm on the input layer
                nn.Conv2d(in_channels=in_channels, out_channels=nf * 2, kernel_size=4, stride=2, padding=0, bias=False),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                # 15x15 -> 6x6
                nn.Conv2d(in_channels=nf * 2, out_channels=nf * 4, kernel_size=4, stride=2, padding=0, bias=False),
                nn.BatchNorm2d(nf * 4),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                # 6x6 -> 2x2
                nn.Conv2d(in_channels=nf * 4, out_channels=nf * 8, kernel_size=4, stride=2, padding=0, bias=False),
                nn.BatchNorm2d(nf * 8),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                # 2x2 -> 1x1 score, squashed to a probability
                nn.Conv2d(in_channels=nf * 8, out_channels=1, kernel_size=2, stride=1, padding=0, bias=False),
                nn.Sigmoid(),
            )

        def forward(self, x):
            """Return the real/fake probability for each image in ``x``, shape (N, 1, 1, 1)."""
            return self.main(x)
    
  2. 网络参数初始化

    # Parameter initialisation, applied to every sub-module via Module.apply
    def weight_init(m):
        """Initialise a layer's parameters in place, based on its class name.

        * Class name contains ``'Conv'`` (Conv2d, ConvTranspose2d, ...):
          weight ~ N(0, 0.02). Any bias is left untouched.
        * Class name contains ``'Batch'`` (BatchNorm2d, ...):
          weight ~ N(1, 0.02), bias = 0.
        * Any other module is left unchanged.

        Parameters that are absent or ``None`` (e.g. ``BatchNorm2d(affine=False)``)
        are skipped instead of raising ``AttributeError``.

        Args:
            m: the module to initialise.
        """
        class_name = type(m).__name__
        # getattr with a default: nn.Module raises AttributeError for unknown names,
        # and affine-free norm layers register weight/bias as None.
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if 'Conv' in class_name:
            if weight is not None:
                nn.init.normal_(weight, mean=0, std=0.02)
        elif 'Batch' in class_name:
            if weight is not None:
                nn.init.normal_(weight, mean=1, std=0.02)
            if bias is not None:
                nn.init.constant_(bias, val=0)
        
    
  3. 设置优化器和损失函数

    # Build the networks, the loss function and the optimisers
    G = G_net().to(device)
    D = D_net().to(device)
    # re-initialise the Conv / BatchNorm layers of G first, then of D
    for net in (G, D):
        net.apply(weight_init)

    # D ends in a Sigmoid, so its outputs are probabilities and plain BCE applies
    loss = BCELoss()

    # Adam for both networks, with beta1 lowered to 0.5
    adam_betas = (0.5, 0.999)
    G_opt = Adam(params=G.parameters(), lr=lr, betas=adam_betas)
    D_opt = Adam(params=D.parameters(), lr=lr, betas=adam_betas)

    # Alternative optimiser (disabled):
    # G_opt = SGD(params=G.parameters(), lr=lr)
    # D_opt = SGD(params=D.parameters(), lr=lr)
    
  4. 训练

    # Training
    # Loss values of both objectives, sampled every 5 epochs
    loss_list_d = []
    loss_list_g = []
    # Mean discriminator outputs (probabilities). The goal is images D cannot tell
    # apart, so these values should drift towards 0.5 as training progresses.
    outputs_real_d_list = []
    outputs_fake_d_list = []
    outputs_g_list = []

    # Generator output snapshots taken during training (detached, on the CPU)
    g_gen_images = []

    for epoch in range(epochs):

        for real_x, _ in dataloader:
            real_x = real_x.to(device)
            # size labels and noise by the actual batch instead of assuming batch_size
            n = real_x.size(0)

            # ---- discriminator step: real images -> 1, generated images -> 0 ----
            D_opt.zero_grad()
            real_label = torch.ones(n, device=device)
            outputs_real_d = D(real_x)
            # D returns (n, 1, 1, 1); view(-1) flattens to (n,) and, unlike squeeze(),
            # keeps the batch axis even when n == 1
            loss_real_D = loss(outputs_real_d.view(-1), real_label)
            loss_real_D.backward()

            fake_x = G(torch.randn(n, noise_size, 1, 1, device=device))
            fake_label = torch.zeros(n, device=device)
            # detach: this step must only update D, not G
            outputs_fake_d = D(fake_x.detach())
            loss_fake_D = loss(outputs_fake_d.view(-1), fake_label)
            loss_fake_D.backward()
            # total discriminator loss (its gradients were accumulated by the two backward calls)
            D_loss = loss_real_D + loss_fake_D

            D_opt.step()

            # ---- generator step: push D to label fresh fakes as real (1) ----
            G_opt.zero_grad()

            fake_x = G(torch.randn(n, noise_size, 1, 1, device=device))
            fake_label = torch.ones(n, device=device)
            outputs_g = D(fake_x)
            loss_fake_G = loss(outputs_g.view(-1), fake_label)
            loss_fake_G.backward()

            G_opt.step()

        # every 5 epochs, record statistics of the epoch's last batch
        if epoch % 5 == 0:
            loss_list_d.append(D_loss.item())
            loss_list_g.append(loss_fake_G.item())

            outputs_real_d_list.append(outputs_real_d.mean().item())
            outputs_fake_d_list.append(outputs_fake_d.mean().item())
            outputs_g_list.append(outputs_g.mean().item())

            # print plain numbers rather than tensor reprs (device / grad_fn noise)
            print(f'第{epoch}次Dloss: {D_loss.item()}')
            print(f'第{epoch}次Gloss: {loss_fake_G.item()}')

            # Store plain image data: the raw tensor still requires grad (so it
            # cannot be plotted directly) and would pin device memory.
            g_gen_images.append(fake_x.detach().cpu())

    # torch.save(g_gen_images,'./g_gen_images.pkl')
    
  5. 训练结果展示

    # Compare real images with images produced by the trained generator
    real_batch = next(iter(dataloader))
    # Sample noise from the standard normal (torch.randn) that G was trained on,
    # not the uniform torch.rand. Run without autograd: the result is only plotted,
    # and matplotlib cannot convert a tensor that requires grad.
    with torch.no_grad():
        fake_x = G(torch.randn(size=(batch_size, noise_size, 1, 1)).to(device)).cpu()
    fig, axs = plt.subplots(1, 2)
    fig.set_size_inches(w=12, h=8)
    # make_grid returns (C, H, W); imshow expects (H, W, C)
    axs[0].imshow(vutils.make_grid(tensor=real_batch[0][:64], nrow=8, padding=2).permute(1, 2, 0))
    axs[1].imshow(vutils.make_grid(tensor=fake_x[:64], nrow=8, padding=2).permute(1, 2, 0))
    for ax, title in zip(axs, ('real_images', 'fake_images')):
        ax.axis('off')
        ax.set_title(title)
    

基于pytorch的dcgan代码实现,进行简易图像数据生成_第3张图片

你可能感兴趣的:(深度学习,pytorch,深度学习,python)