pip install torch torchvision matplotlib numpy
A GAN is a game between an art forger and an art appraiser:
Generator: the forger
Discriminator: the appraiser
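Formally, the two networks play the minimax game introduced by Goodfellow et al. (2014): the discriminator D tries to maximize the value function below while the generator G tries to minimize it:

$$ \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] $$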
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Hyperparameters
lr = 0.0002
batch_size = 64
image_size = 64
channels_img = 1
z_dim = 100  # dimensionality of the noise vector
num_epochs = 20
# Data preprocessing
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # normalize pixel values to [-1, 1]
])
# Download and load the MNIST dataset
dataset = datasets.MNIST(
    root="dataset/",
    train=True,
    transform=transform,
    download=True
)
# Create the data loader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
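# (Optional sanity check, not required for training) Inspect one batch to
# confirm the shapes and the value range produced by the transforms above:
real_batch, labels = next(iter(dataloader))
print(real_batch.shape)                                   # torch.Size([64, 1, 64, 64])
print(real_batch.min().item(), real_batch.max().item())   # approximately -1.0 and 1.0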
class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # Input: z_dim x 1 x 1
            self._block(z_dim, features_g * 16, 4, 1, 0),               # 4x4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),      # 8x8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),       # 16x16
            self._block(features_g * 4, features_g * 2, 4, 2, 1),       # 32x32
            nn.ConvTranspose2d(features_g * 2, channels_img, 4, 2, 1),  # 64x64
            nn.Tanh(),  # output values in [-1, 1]
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.gen(x)
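# (Illustrative smoke test, not part of the training script) Confirm that the
# generator maps a z_dim x 1 x 1 noise vector up to a channels_img x 64 x 64 image:
_g = Generator(z_dim, channels_img, 64)
_z = torch.randn(2, z_dim, 1, 1)
print(_g(_z).shape)   # torch.Size([2, 1, 64, 64])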
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # Input: channels_img x 64 x 64
            nn.Conv2d(channels_img, features_d, 4, 2, 1),            # 32x32
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d * 2, 4, 2, 1),        # 16x16
            self._block(features_d * 2, features_d * 4, 4, 2, 1),    # 8x8
            self._block(features_d * 4, features_d * 8, 4, 2, 1),    # 4x4
            nn.Conv2d(features_d * 8, 1, 4, 2, 0),                   # 1x1
            nn.Sigmoid(),  # output a probability
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)
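# (Illustrative smoke test) The discriminator collapses each 64x64 image to a
# single probability:
_d = Discriminator(channels_img, 64)
_imgs = torch.randn(2, channels_img, 64, 64)
print(_d(_imgs).shape)   # torch.Size([2, 1, 1, 1])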
# Initialize the models
gen = Generator(z_dim, channels_img, 64).to(device)
disc = Discriminator(channels_img, 64).to(device)

# Initialize weights as in the DCGAN paper: N(0, 0.02) for conv layers and
# N(1, 0.02) for BatchNorm scale parameters (a zero-mean init would zero out BatchNorm)
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

gen.apply(weights_init)
disc.apply(weights_init)
# Optimizers
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))

# Loss function
criterion = nn.BCELoss()
# Fixed noise for visualizing training progress
fixed_noise = torch.randn(64, z_dim, 1, 1).to(device)

# Make sure the directory for sample images exists before saving into it
os.makedirs("output", exist_ok=True)
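With BCELoss and the target labels used below, the loop implements the standard non-saturating DCGAN losses. The discriminator minimizes

$$ L_D = -\tfrac{1}{2}\big[\log D(x) + \log(1 - D(G(z)))\big] $$

(the factor 1/2 matches the division by 2 in the code), while the generator minimizes

$$ L_G = -\log D(G(z)) $$

that is, the generator is trained against "real" targets rather than by maximizing $\log(1 - D(G(z)))$, which gives it stronger gradients early in training.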
# Training loop
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        batch_size = real.shape[0]

        # Train the discriminator
        noise = torch.randn(batch_size, z_dim, 1, 1).to(device)
        fake = gen(noise)

        # Discriminator predictions on real images
        disc_real = disc(real).view(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))

        # Discriminator predictions on generated images
        disc_fake = disc(fake.detach()).view(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        # Total discriminator loss
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Train the generator
        output = disc(fake).view(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print every 100 batches
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} "
                f"Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}"
            )

    # Save a grid of generated samples at the end of each epoch
    with torch.no_grad():
        fake = gen(fixed_noise)
        img_grid = torchvision.utils.make_grid(fake[:32], normalize=True)
        plt.imshow(np.transpose(img_grid.cpu().numpy(), (1, 2, 0)))
        plt.axis("off")
        plt.savefig(f"output/epoch_{epoch}.png")
        plt.close()
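The script above only writes image grids to output/. As a minimal sketch (the file name generator.pth is an assumption, not part of the script), the trained generator can be saved and later reloaded to sample new digits:

# Save the trained generator's weights (file name is an arbitrary choice)
torch.save(gen.state_dict(), "generator.pth")

# Later: rebuild the architecture, load the weights, and sample new digits
gen_loaded = Generator(z_dim, channels_img, 64).to(device)
gen_loaded.load_state_dict(torch.load("generator.pth", map_location=device))
gen_loaded.eval()
with torch.no_grad():
    samples = gen_loaded(torch.randn(16, z_dim, 1, 1, device=device))
print(samples.shape)   # torch.Size([16, 1, 64, 64])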
Project structure:
dcgan_mnist/
├── train.py           # training script
├── models.py          # model definitions
├── utils.py           # helper functions
├── requirements.txt
├── output/            # output directory
└── README.md
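utils.py is listed above but its contents are not shown; a hypothetical version, chosen here to match the per-epoch visualization code, could factor out a sample-saving helper:

# utils.py -- hypothetical helper module (an assumption, not shown in the project)
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt


def save_samples(generator, fixed_noise, path):
    """Generate images from fixed_noise and save them as an image grid at `path`."""
    with torch.no_grad():
        fake = generator(fixed_noise)
    grid = torchvision.utils.make_grid(fake[:32], normalize=True)
    plt.imshow(np.transpose(grid.cpu().numpy(), (1, 2, 0)))
    plt.axis("off")
    plt.savefig(path)
    plt.close()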
Contents of requirements.txt:
torch>=2.0.0
torchvision>=0.15.0
matplotlib>=3.7.0
numpy>=1.24.0
Example training output:
Epoch [0/20] Batch 0/938 Loss D: 0.6931, Loss G: 0.6978
Epoch [0/20] Batch 100/938 Loss D: 0.0123, Loss G: 5.4321
...
Epoch [19/20] Batch 800/938 Loss D: 0.5123, Loss G: 1.2314
Solutions to common training problems: when the generator and discriminator losses get badly out of balance, typical remedies include adjusting batch_size, applying tricks that keep the two networks balanced, and using methods that stabilize training (one example is sketched below).
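One widely used stabilization trick, given here purely as an illustrative example, is one-sided label smoothing: train the discriminator against a target of 0.9 instead of 1.0 for real images so that it never becomes perfectly confident. In the training loop above, only the discriminator's real-image loss changes:

# One-sided label smoothing (illustrative example; 0.9 is a common choice)
loss_disc_real = criterion(disc_real, 0.9 * torch.ones_like(disc_real))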