前言:个人电脑算力有限(仅有2G显存的GPU),因此使用numpy自行生成单通道(灰度)图像数据,用DCGAN生成对抗网络训练50轮(耗时不到20 min),生成的fake image已难以用肉眼与真实图像区分。
生成结果如下:
导入
import torchvision
import torchvision.transforms as transform
from torchvision.datasets import ImageFolder
import torchvision.utils as vutils
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import BCELoss
from torch.optim import Adam,SGD
import matplotlib.pyplot as plt
from matplotlib import animation # 用于生成gif图像
import numpy as np
训练参数设置
# Hyper-parameters
images_num = 3000  # number of training images to synthesise
image_size = 32  # side length, in pixels, of each square image
image_channel = 1  # channels per image (single-channel / grayscale)
batch_size = 128  # samples per training batch
noise_size = 100  # length of the random latent vector used to generate fake images
lr = 2e-4  # learning rate
epochs = 50  # number of training epochs
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
批量生成图像训练集
# 1 Build a simple synthetic image dataset
def gen_images(size=None):
    """Synthesise one grayscale training image: a dark parabola on a light, noisy background.

    Args:
        size: side length of the square image. Defaults to the module-level
            ``image_size``. The curve coordinates are hard-coded (columns 3..28,
            rows 3..28), so ``size`` must be at least 29.

    Returns:
        A ``(size, size)`` float64 numpy array. Background pixels lie in
        (0.9, 1]; the curve is drawn 3 pixels thick with one random
        intensity in [0, 0.1).

    Raises:
        ValueError: if ``size`` is too small to contain the curve.
    """
    if size is None:
        size = image_size
    if size < 29:
        # Without this check the stroke slice below would silently clip.
        raise ValueError(f'size must be >= 29 to hold the curve, got {size}')
    # Light background with a little uniform noise: values in (0.9, 1].
    image = 1.0 - np.random.rand(size, size) / 10
    # Horizontal extent of the curve along the second array axis.
    min_x = np.random.randint(3, 8)
    max_x = np.random.randint(25, 30)
    # One dark intensity for the whole curve, as a Python float.
    # The original used a shape-(1,) array, which NumPy deprecates as a scalar.
    value = np.random.rand() / 10
    for x in range(min_x, max_x):
        # Parabola in normalised x with its vertex at x = 0.55 * max_x, offset 3 rows from the top.
        y = int(((x / max_x - 0.55) * 2) ** 2 * max_x) + 3
        # Stroke three consecutive rows (y, y+1, y+2). The original wrote row y+2
        # twice and skipped y+1.
        image[y:y + 3, x] = value
    return image
# Image dataset
class Images:
    """Map-style dataset that synthesises a fresh random image on every access.

    ``__getitem__`` ignores its index and calls ``gen_images()`` each time.
    The same index can therefore yield different images.
    ``len()`` is taken from the module-level ``images_num`` setting.
    """

    def __init__(self):
        # Number of samples reported to the DataLoader.
        self.len_ = images_num

    def __len__(self):
        return self.len_

    def __getitem__(self, index):
        array = gen_images()
        # numpy array -> float32 tensor with a leading channel axis of size 1.
        sample = torch.from_numpy(array).float().unsqueeze(0)
        # The empty string is a placeholder label; training ignores it.
        return sample, ''
# Build the batched training data
images = Images()
# drop_last=True guarantees every batch has exactly batch_size samples.
dataloader = DataLoader(dataset=images, batch_size=batch_size, shuffle=True, drop_last=True)
# Preview up to 64 real images from one batch as an 8-wide grid.
real_batch = next(iter(dataloader))
fig, axs = plt.subplots()
fig.set_size_inches(w=12, h=8)
# make_grid returns a (C, H, W) tensor; imshow expects (H, W, C).
axs.imshow(vutils.make_grid(tensor=real_batch[0][:64], nrow=8, padding=2).permute(1, 2, 0))
# A single call is enough; the original repeated this call verbatim.
axs.axis('off')
构建网络
# Generator network
class G_net(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    The first layer takes ``noise_size`` input channels and the last emits
    ``image_channel`` channels through a Sigmoid, so outputs lie in (0, 1).
    For a 1x1 spatial input the size grows 1 -> 4 -> 8 -> 16 -> 32.
    The caller presumably feeds latents shaped (N, noise_size, 1, 1).
    """

    def __init__(self):
        super().__init__()
        base = image_size  # also used as the channel-width multiplier
        # Project the latent up to a 4x4 feature map (kernel 4, stride 1, no padding).
        layers = [
            nn.ConvTranspose2d(noise_size, base * 16, 4, 1, 0, bias=False),
            nn.BatchNorm2d(base * 16),
            nn.ReLU(inplace=True),
        ]
        # Two upsampling stages, each doubling the spatial size and halving the width.
        for c_in, c_out in ((base * 16, base * 8), (base * 8, base * 4)):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        # Final upsampling to image channels, squashed into (0, 1).
        layers += [
            nn.ConvTranspose2d(base * 4, image_channel, 4, 2, 1, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)
# Discriminator network, laid out as the reverse of the generator's convolutions
class D_net(nn.Module):
    """DCGAN-style discriminator: image in, probability of "real" out.

    For a 32x32 input the spatial size shrinks 32 -> 15 -> 6 -> 2 -> 1.
    The output therefore has shape (N, 1, 1, 1), squashed to (0, 1) by the final Sigmoid.
    """

    def __init__(self):
        super().__init__()
        base = image_size  # also used as the channel-width multiplier
        # First stage has no BatchNorm.
        layers = [
            nn.Conv2d(image_channel, base * 2, 4, 2, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two downsampling stages, each widening the feature maps.
        for c_in, c_out in ((base * 2, base * 4), (base * 4, base * 8)):
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 0, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse the remaining 2x2 map to a single score.
        layers += [
            nn.Conv2d(base * 8, 1, 2, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)
网络参数初始化
# Weight initialisation
def weight_init(m):
    """Initialise one module in place; intended for use with ``nn.Module.apply``.

    Modules are selected by class name, not by type:
    - a name containing 'Conv' gets weights drawn from N(0, 0.02); its bias
      (if any) is left untouched;
    - otherwise, a name containing 'Batch' gets weights from N(1, 0.02) and
      a zero bias;
    - every other module is left unchanged.
    """
    name = type(m).__name__
    if 'Conv' in name:
        nn.init.normal_(m.weight.data, mean=0, std=0.02)
        return
    if 'Batch' in name:
        nn.init.normal_(m.weight.data, mean=1, std=0.02)
        nn.init.constant_(m.bias.data, val=0)
设置优化器和损失函数
# Networks, loss function and optimisers
# Construct G, then D, on the target device, then initialise G, then D.
G, D = G_net().to(device), D_net().to(device)
G.apply(weight_init)
D.apply(weight_init)
# Binary cross-entropy on the discriminator's probability outputs.
loss = BCELoss()
# Adam with betas=(0.5, 0.999) for both networks.
G_opt, D_opt = (
    Adam(params=net.parameters(), lr=lr, betas=(0.5, 0.999)) for net in (G, D)
)
# Plain-SGD alternative, kept disabled:
# G_opt = SGD(params=G.parameters(), lr=lr)
# D_opt = SGD(params=D.parameters(), lr=lr)
训练
# Training
# Loss histories for the discriminator (D) and the generator (G).
loss_list_d = []
loss_list_g = []
# Mean D output on real images, on fakes in the D step, and on fakes in the G step.
# Real and fake images become hard to tell apart as training progresses,
# so these are expected to drift towards 0.5.
outputs_real_d_list = []
outputs_fake_d_list = []
outputs_g_list = []
# Generator output snapshots taken during training (detached, on the CPU).
g_gen_images = []
# Target labels never change: the DataLoader uses drop_last=True, so every
# batch has batch_size samples. Build them once instead of every iteration.
real_label = torch.ones(batch_size, device=device)
fake_label = torch.zeros(batch_size, device=device)
for epoch in range(epochs):
    for real_x, _ in dataloader:
        # Discriminator step: push D(real) towards 1 and D(G(z)) towards 0.
        D_opt.zero_grad()
        outputs_real_d = D(real_x.to(device))
        loss_real_D = loss(outputs_real_d.view(-1), real_label)
        loss_real_D.backward()
        fake_x = G(torch.randn(size=(batch_size, noise_size, 1, 1), device=device))
        # detach() so the discriminator update does not backpropagate into G.
        outputs_fake_d = D(fake_x.detach())
        loss_fake_D = loss(outputs_fake_d.view(-1), fake_label)
        loss_fake_D.backward()
        # Total discriminator loss, kept for logging only (gradients were accumulated above).
        D_loss = loss_real_D.detach() + loss_fake_D.detach()
        D_opt.step()

        # Generator step: train G so that D classifies fresh fakes as real (target 1).
        G_opt.zero_grad()
        fake_x = G(torch.randn(size=(batch_size, noise_size, 1, 1), device=device))
        outputs_g = D(fake_x)
        loss_fake_G = loss(outputs_g.view(-1), real_label)
        loss_fake_G.backward()
        G_opt.step()

    # Every 5 epochs, record statistics from the epoch's last batch.
    if epoch % 5 == 0:
        loss_list_d.append(D_loss.item())
        loss_list_g.append(loss_fake_G.item())
        outputs_real_d_list.append(outputs_real_d.mean().item())
        outputs_fake_d_list.append(outputs_fake_d.mean().item())
        outputs_g_list.append(outputs_g.mean().item())
        # .item() prints the number instead of the full tensor repr (device, grad_fn).
        print(f'第{epoch}次Dloss: {D_loss.item()}')
        print(f'第{epoch}次Gloss: {loss_fake_G.item()}')
        # Store a detached CPU copy. Appending the live tensor would pin GPU
        # memory and keep a reference to its autograd history.
        g_gen_images.append(fake_x.detach().cpu())
# torch.save(g_gen_images, './g_gen_images.pkl')
训练结果展示
# Inspect the final result: real images next to generator samples
real_batch = next(iter(dataloader))
# Sample the latent from the standard normal used in training (torch.randn).
# The original drew uniform [0, 1) noise with torch.rand, which the generator
# never saw. no_grad skips building an autograd graph just for plotting.
with torch.no_grad():
    fake_x = G(torch.randn(size=(batch_size, noise_size, 1, 1), device=device)).cpu()
fig, axs = plt.subplots(1, 2)
fig.set_size_inches(w=12, h=8)
panels = ((axs[0], real_batch[0], 'real_images'), (axs[1], fake_x, 'fake_images'))
for ax, batch, title in panels:
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    ax.imshow(vutils.make_grid(tensor=batch[:64], nrow=8, padding=2).permute(1, 2, 0))
    ax.axis('off')
    ax.set_title(title)