Generative Adversarial Networks (GANs) are a class of deep learning models proposed by Ian Goodfellow et al. in 2014. A GAN consists of two neural networks, a generator and a discriminator, which are optimized against each other through adversarial training. The details are as follows:
Generator (G): the generator tries to produce realistic fake samples (e.g., images) from a random noise input (such as Gaussian or uniform noise). Its goal is to learn the distribution of the training data and generate samples that are as similar to real data as possible.
Discriminator (D): the discriminator's task is to distinguish real data from the generator's fake data. It takes a sample as input and outputs the probability that the sample is real.
GAN training can be viewed as a two-player zero-sum game:
Generator's goal: G tries to produce samples that fool the discriminator D into classifying them as real.
Discriminator's goal: D tries to accurately separate real samples from generated ones, assigning high probability to real samples and low probability to fakes.
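The two objectives above combine into the minimax value function from the original GAN paper, which G minimizes and D maximizes:

```latex
\min_G \max_D V(D, G) =
\mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
+ \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

At the optimum, the generator's distribution matches the data distribution and D outputs 1/2 everywhere.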
The generator and discriminator improve together through this adversarial back-and-forth. Concretely:
The GAN loss has two parts: the discriminator's loss and the generator's loss.
Discriminator loss: L_D = -[log D(x) + log(1 - D(G(z)))] (in expectation over real samples x and noise z), a binary cross-entropy that pushes D(x) toward 1 and D(G(z)) toward 0.
Generator loss: the original minimax form is L_G = log(1 - D(G(z))); in practice the non-saturating form L_G = -log D(G(z)) is typically used, since it provides stronger gradients early in training. This is what the code below implements, by computing BCE against "real" labels.
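As a sanity check, `nn.BCELoss` reproduces these terms exactly; the probabilities 0.9 and 0.2 below are hypothetical discriminator outputs chosen for illustration:

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical discriminator outputs on one real and one fake sample
d_real = torch.tensor([0.9])  # D(x)
d_fake = torch.tensor([0.2])  # D(G(z))

# Discriminator loss: -log D(x) - log(1 - D(G(z)))
loss_real = criterion(d_real, torch.ones(1))   # -log(0.9)
loss_fake = criterion(d_fake, torch.zeros(1))  # -log(0.8)
d_loss = loss_real + loss_fake

# Non-saturating generator loss: -log D(G(z))
g_loss = criterion(d_fake, torch.ones(1))      # -log(0.2)

print(d_loss.item(), g_loss.item())
```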
Below is an example of a simple GAN implemented in PyTorch:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Generator: maps a noise vector to a flattened 28x28 image
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(True),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(True),
            nn.Linear(hidden_size, output_size),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized images
        )

    def forward(self, x):
        return self.main(x)
# Discriminator: maps a flattened image to the probability it is real
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, output_size),
            nn.Sigmoid()  # probability in [0, 1]
        )

    def forward(self, x):
        return self.main(x)
# Hyperparameters
batch_size = 64
learning_rate = 0.0002
num_epochs = 200
input_size = 100   # dimension of the noise vector z
hidden_size = 256
image_size = 784   # 28x28 MNIST images, flattened
output_size = 1
# Load MNIST; Normalize maps pixel values from [0, 1] to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# Instantiate the generator and discriminator
G = Generator(input_size, hidden_size, image_size)
D = Discriminator(image_size, hidden_size, output_size)
# Loss function and optimizers
criterion = nn.BCELoss()
G_optimizer = optim.Adam(G.parameters(), lr=learning_rate)
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
# Train the GAN
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Flatten the images; use the actual batch size, since the
        # last batch of an epoch may be smaller than batch_size
        images = images.view(-1, image_size)
        bs = images.size(0)
        real_labels = torch.ones(bs, 1)
        fake_labels = torch.zeros(bs, 1)

        # Train the discriminator on real and fake samples
        outputs = D(images)
        D_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = torch.randn(bs, input_size)
        fake_images = G(z)
        # detach() so the discriminator update does not backpropagate into G
        outputs = D(fake_images.detach())
        D_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        D_loss = D_loss_real + D_loss_fake
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Train the generator: non-saturating loss, i.e. push D's output
        # on fresh fake samples toward the "real" label
        z = torch.randn(bs, input_size)
        fake_images = G(z)
        outputs = D(fake_images)
        G_loss = criterion(outputs, real_labels)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if (i+1) % 200 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], '
                  f'D_loss: {D_loss.item():.4f}, G_loss: {G_loss.item():.4f}, '
                  f'D(x): {real_score.mean().item():.4f}, D(G(z)): {fake_score.mean().item():.4f}')
# Visualize generated images after training
import matplotlib.pyplot as plt

with torch.no_grad():
    z = torch.randn(64, input_size)  # 64 samples for an 8x8 grid
    fake_images = G(z).view(-1, 1, 28, 28)

fig, ax = plt.subplots(8, 8, figsize=(8, 8))
for i in range(8):
    for j in range(8):
        ax[i, j].imshow(fake_images[i*8 + j, 0].cpu().numpy(), cmap='gray')
        ax[i, j].axis('off')
plt.show()
GANs have many practical applications, including but not limited to:
Image generation: synthesizing realistic faces, scenes, and artwork.
Image-to-image translation and style transfer.
Super-resolution: reconstructing high-resolution images from low-resolution inputs.
Data augmentation: generating additional training samples for downstream models.
Advantages:
GANs can produce sharp, realistic samples without requiring an explicit density model.
Training is unsupervised: no labels are needed beyond the data itself.
Challenges:
Training instability: the generator and discriminator must stay balanced, and results are sensitive to hyperparameters.
Mode collapse: the generator may cover only a narrow subset of the data distribution.
Evaluation is difficult: there is no single agreed-upon metric for sample quality and diversity.