torch.nn.parallel.DistributedDataParallel
Distributed Data Parallel (DDP) is the distributed-training interface provided by PyTorch. Compared with DataParallel, DistributedDataParallel has the following advantages:
- It runs one process per GPU instead of multiple threads, so it is not limited by the Python GIL.
- Each process keeps its own model replica and optimizer state; only gradients are synchronized, and this communication overlaps with the backward pass.
- It scales from a single machine with several GPUs to multiple machines, whereas DataParallel is single-machine only.
ddp_demo.py
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torch.cuda.amp import GradScaler
def setup(rank, world_size):
    """Initialize the distributed process group."""
    # Respect the values already exported by torchrun; fall back to localhost for ad-hoc runs
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
    """Tear down the distributed process group."""
    dist.destroy_process_group()
class SimpleModel(nn.Module):
    """A small CNN for MNIST."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # After two conv + 2x2 max-pool stages, a 28x28 input becomes 64 x 5 x 5 = 1600 features
        self.fc = nn.Linear(1600, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        return self.fc(x)
def prepare_dataloader(rank, world_size, batch_size=32):
    """Build a DataLoader whose DistributedSampler shards the dataset across processes."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return loader
def train(rank, world_size, epochs=2):
    """Training loop run by every process."""
    setup(rank, world_size)
    # Bind this process to its GPU
    torch.cuda.set_device(rank)
    # Build the model, wrap it in DDP, and set up the optimizer
    model = SimpleModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = optim.Adam(ddp_model.parameters())
    scaler = GradScaler()  # for mixed-precision training
    criterion = nn.CrossEntropyLoss()
    train_loader = prepare_dataloader(rank, world_size)
    for epoch in range(epochs):
        ddp_model.train()
        train_loader.sampler.set_epoch(epoch)  # ensure a different shuffle every epoch
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()
            # Mixed-precision forward pass
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                output = ddp_model(data)
                loss = criterion(output, target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            if batch_idx % 100 == 0:
                print(f"Rank {rank}, Epoch {epoch}, Batch {batch_idx}, Loss {loss.item():.4f}")
    cleanup()
if __name__ == "__main__":
    # For single-machine multi-GPU launches, torchrun sets these environment variables automatically
    rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    train(rank, world_size)
Launch distributed training with torchrun (recommended for PyTorch 1.9+):
# Single machine, 4 GPUs
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12355 ddp_demo.py
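If torchrun is not available, the same train(rank, world_size) function can also be driven with torch.multiprocessing.spawn. A minimal sketch, assuming the demo above was saved as ddp_demo.py (this launcher script is not part of the original demo):

# launch_spawn.py -- alternative single-machine launcher (sketch)
import torch
import torch.multiprocessing as mp

from ddp_demo import train  # assumes the demo above is saved as ddp_demo.py

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # mp.spawn calls train(process_index, world_size); the demo uses that index
    # as both the process-group rank and the GPU index.
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)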
The key pieces of the script:
- Distributed sampling: sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
- Model wrapping: ddp_model = DDP(model, device_ids=[rank])
- Mixed precision: create scaler = GradScaler(), run the forward pass inside
  with torch.autocast(device_type='cuda', dtype=torch.float16), then call
  scaler.scale(loss).backward(), scaler.step(optimizer), scaler.update()
- Data loading: DistributedSampler shards the dataset across processes; call sampler.set_epoch(epoch) every epoch to keep shuffling correct
- Model saving: save from the main process only:
  if rank == 0:  # save only on the main process
      torch.save(model.state_dict(), "model.pth")
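Because every process holds an identical replica, a common pattern is to checkpoint from rank 0 only and have the other ranks wait at a barrier before loading. A minimal sketch, taking the DDP-wrapped model and rank from the demo (the function name and file path are illustrative):

import torch
import torch.distributed as dist

def save_and_load_checkpoint(ddp_model, rank, path="ddp_checkpoint.pth"):
    """Checkpoint from rank 0 only, then load the same weights on every rank."""
    if rank == 0:
        # Unwrap the DDP module so the saved keys carry no "module." prefix
        torch.save(ddp_model.module.state_dict(), path)
    dist.barrier()  # make sure the file is fully written before other ranks read it
    # Map tensors saved from GPU 0 onto this process's own GPU
    map_location = {"cuda:0": f"cuda:{rank}"}
    ddp_model.module.load_state_dict(torch.load(path, map_location=map_location))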
Multi-machine training
# Machine 1 (master node)
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=IP1 --master_port=12355 ddp_demo.py
# Machine 2
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=IP1 --master_port=12355 ddp_demo.py
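In a multi-node job the process-group rank must be the global RANK exported by torchrun, while LOCAL_RANK selects the GPU on each machine; the demo above uses LOCAL_RANK for both, which is only correct on a single node. A minimal sketch of multi-node-aware initialization (an adaptation of the demo's setup, not part of the original script):

import os
import torch
import torch.distributed as dist

def setup_multinode():
    # torchrun exports RANK (global), LOCAL_RANK (per machine) and WORLD_SIZE;
    # MASTER_ADDR / MASTER_PORT are also set from the command-line flags.
    global_rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    dist.init_process_group("nccl", rank=global_rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    return global_rank, local_rank, world_size

The model would then be built on local_rank and wrapped with DDP(model, device_ids=[local_rank]).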
Performance tuning and common issues
- Load balancing: adjust batch_size so that each GPU is kept equally busy.
- Data loading: set pin_memory=True on the DataLoader to speed up host-to-GPU transfers.
- CUDA out of memory: reduce batch_size, or accumulate gradients over several batches (a DDP-aware variant is sketched after this list), e.g.:
  for i, (data, target) in enumerate(train_loader):
      if i % 2 == 0:
          optimizer.zero_grad()
      # forward and backward pass ...
      if i % 2 == 1:
          optimizer.step()
- Process synchronization failures: make sure every process sees the same MASTER_ADDR and MASTER_PORT.
- Precision problems with mixed precision: tune the GradScaler parameters, e.g. scaler = GradScaler(init_scale=1024, growth_factor=2.0)
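Under DDP, plain gradient accumulation still triggers a gradient all-reduce on every backward pass; wrapping the non-final micro-batches in ddp_model.no_sync() skips that redundant communication. A minimal sketch reusing the names from the demo (accum_steps is an illustrative parameter, not part of the original code):

from contextlib import nullcontext

accum_steps = 2  # illustrative: micro-batches accumulated per optimizer step

for i, (data, target) in enumerate(train_loader):
    data, target = data.to(rank), target.to(rank)
    last_micro_batch = (i + 1) % accum_steps == 0

    # ddp_model.no_sync() suppresses the gradient all-reduce; only the final
    # micro-batch of each group triggers communication.
    ctx = nullcontext() if last_micro_batch else ddp_model.no_sync()
    with ctx:
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            loss = criterion(ddp_model(data), target) / accum_steps
        scaler.scale(loss).backward()

    if last_micro_batch:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()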