MNIST Data Parallel (DP)

Data Parallel

Reposted from my personal blog: https://shar-pen.github.io/2025/05/04/torch-distributed-series/2.MNIST_DP/

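Data parallelism vs. model parallelism

  • Data parallelism: the model is copied to every device, and the data is split/chunked along the batch dimension.

    • the module is replicated on each device, and each replica handles a portion of the input.
    • During the backwards pass, gradients from each replica are summed into the original module.
  • Model parallelism: the data is copied to every device, and the model is split/chunked (needed when the model does not fit on a single GPU).

Put plainly, DP splits what is normally a single batch across the GPUs and runs the forward and backward passes under the control of the primary GPU.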

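Official documentation: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html

Parameters:

  • module: the torch.nn.Module to parallelize, i.e., your model
  • device_ids=None: the GPUs that take part in training, e.g. device_ids=gpus
  • output_device=None: the GPU on which the outputs are gathered (defaults to device_ids[0]), e.g. output_device=gpus[0]
  • dim=0: the dimension along which the input is split, normally the first (batch) dimension; with dim=0, [30, xxx] -> [10, …], [10, …], [10, …] on 3 GPUs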
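
As a quick check of the dim=0 split, the sketch below uses torch.chunk on the CPU as a stand-in for the split that DataParallel performs across GPUs. The tensor shape and the 3-way split are illustrative assumptions, not values taken from the training script.

```python
import torch

# A fake batch of 30 MNIST-sized images: [batch, channel, height, width]
batch = torch.zeros(30, 1, 28, 28)

# Split along dim=0 (the batch dimension) into 3 pieces,
# one per hypothetical GPU
chunks = torch.chunk(batch, 3, dim=0)

for i, chunk in enumerate(chunks):
    print(f"replica {i}: {tuple(chunk.shape)}")
```

Only the batch dimension shrinks; the per-sample shape [1, 28, 28] is untouched.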

The parallelized module must have its parameters and buffers on device_ids[0] before running (forward/backward) this DataParallel module. In other words, the model parameters must already be on the first of the given GPUs.

If you only call nn.DataParallel and skip model.to(device), where device must be the first selected GPU, the forward pass fails with RuntimeError: module must have its parameters and buffers on device cuda:4 (device_ids[0]) but found one of them on device: cpu. That is, the model was not moved to the correct GPU first.
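
One way to avoid that error is to bundle the two steps so the order cannot be mixed up. The helper below is a minimal sketch: wrap_for_dp is a hypothetical name introduced here, and the CPU fallback is an assumption so the snippet also runs on a machine without GPUs.

```python
import torch
import torch.nn as nn


def wrap_for_dp(model, device_ids):
    """Hypothetical helper: move the model to device_ids[0], then wrap it."""
    if not torch.cuda.is_available() or not device_ids:
        # No GPU available: return the plain model on the CPU (assumed fallback)
        return model, torch.device("cpu")
    primary = torch.device(f"cuda:{device_ids[0]}")
    model = model.to(primary)  # parameters and buffers must be here first
    return nn.DataParallel(model, device_ids=device_ids), primary


model, device = wrap_for_dp(nn.Linear(4, 2), device_ids=[0])
out = model(torch.randn(8, 4).to(device))
print(out.shape)
```

The input batch is moved to the same primary device that the helper returns, which mirrors what the training loop does with data.to(device).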

class ConvNet(nn.Module):
    
    def __init__(self):
        super(ConvNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25)
        )
        self.classifier = nn.Sequential(
            nn.Linear(9216, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

    def forward(self, input):
        x = self.features(input)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        output = F.log_softmax(x, dim=1)
        # Print the tensor shapes seen inside forward; focus on the batch size ("label shape" here is actually the output shape)
        print(f"[Inside]: input shape: {input.size()}, label shape: {output.size()}")
        return output

Single-GPU forward

device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
model = ConvNet()
model = model.to(device)
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    print(f"[Outside]: input shape: {data.size()}, label shape: {target.size()}")
    output = model(data)
    break

DP forward

CUDA_DEVICE_IDS = [4,5]
device = torch.device(f"cuda:{CUDA_DEVICE_IDS[0]}" if torch.cuda.is_available() else "cpu")
model = ConvNet()
model = model.to(device)
model = nn.DataParallel(model, device_ids=CUDA_DEVICE_IDS)
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    output = model(data)
    print(f"[Outside]: input shape: {data.size()}, label shape: {target.size()}, output shape: {output.size()}")
    break

Forward shape comparison (single GPU vs. DP on 2 GPUs)

Single GPU
[Outside]: input shape: torch.Size([512, 1, 28, 28]), label shape: torch.Size([512])
[Inside]: input shape: torch.Size([512, 1, 28, 28]), label shape: torch.Size([512, 10])

DP on 2 GPUs
[Inside]: input shape: torch.Size([256, 1, 28, 28]), label shape: torch.Size([256, 10])
[Inside]: input shape: torch.Size([256, 1, 28, 28]), label shape: torch.Size([256, 10])
[Outside]: input shape: torch.Size([512, 1, 28, 28]), label shape: torch.Size([512]), output shape: torch.Size([512, 10])

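With DP, the batch size seen in the loop is unchanged, but inside the model's forward the batch is divided by len(device_ids): the original batch of 512 becomes 256, and each of the two GPUs holds a model replica that processes 1 / len(device_ids) of the data.

Single GPU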
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA GeForce RTX 4090        On  |   00000000:81:00.0 Off |                  Off |
| 46%   33C    P8             14W /  450W |     719MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA GeForce RTX 4090        On  |   00000000:A1:00.0 Off |                  Off |
| 43%   32C    P8             20W /  450W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

DP on 2 GPUs
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA GeForce RTX 4090        On  |   00000000:81:00.0 Off |                  Off |
| 45%   34C    P2             55W /  450W |     717MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA GeForce RTX 4090        On  |   00000000:A1:00.0 Off |                  Off |
| 43%   33C    P2             57W /  450W |     719MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

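Each GPU holds a copy of the model parameters and batch / len(device_ids) of the data.

DP training

The only real addition is model = nn.DataParallel(model, device_ids=CUDA_DEVICE_IDS); the forward and backpropagation code that follows does not need to change.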

CUDA_DEVICE_IDS = [4,5]
DEVICE = torch.device(f"cuda:{CUDA_DEVICE_IDS[0]}" if torch.cuda.is_available() else "cpu")


model = ConvNet()
model = model.to(DEVICE)
model = nn.DataParallel(model, device_ids=CUDA_DEVICE_IDS)
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_STEP_NUM, gamma=LR_DECAY_FACTOR)

start_time = time()  # Record the start time
for epoch in range(EPOCHS):

    epoch_start_time = time()  # Record the start time of the current epoch
    print(f'Epoch {epoch}/{EPOCHS}')
    print(f'Learning Rate: {scheduler.get_last_lr()[0]}')

    train(model, DEVICE, train_loader, optimizer)
    test(model, DEVICE, test_loader)
    scheduler.step()

    epoch_end_time = time()  # Record the end time of the current epoch
    print(f"Time for epoch {epoch}: {epoch_end_time - epoch_start_time:.2f} seconds")

end_time = time()  # Record the end time
print(f"Total training time: {end_time - start_time:.2f} seconds")
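
Although the training loop is unchanged, the wrapper does add one level of nesting: the original model stays reachable as model.module, and the wrapper's state_dict keys carry a "module." prefix. The sketch below illustrates this with a small nn.Linear as an assumed stand-in for ConvNet; constructing the wrapper does not require a GPU.

```python
import torch.nn as nn

model = nn.Linear(4, 2)
dp_model = nn.DataParallel(model)

# Keys as seen through the wrapper vs. through the wrapped module
wrapped_keys = list(dp_model.state_dict().keys())
plain_keys = list(dp_model.module.state_dict().keys())

print(wrapped_keys)
print(plain_keys)
```

Going through dp_model.module gives the same keys that a non-DP model would have.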

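Looking at nn.DataParallel's internal forward function, a few lines show the overall flow:

inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids) # scatter the input data
replicas = self.replicate(self.module, self.device_ids[: len(inputs)]) # replicate the model
outputs = self.parallel_apply(replicas, inputs, module_kwargs) # compute the outputs in parallel
return self.gather(outputs, self.output_device) # gather the results

The model outputs come from several GPUs but are gathered on the primary GPU, because torch.nn.DataParallel by default gathers the outputs of all replicas back to output_device, which defaults to device_ids[0].

Why gather? The labels (targets) usually live on the primary GPU, so the outputs have to be gathered there as well before the loss can be computed against them. PS: the loss computation requires its input tensors to be on the same device.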
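
The scatter, replicate, parallel_apply, gather flow can be imitated on the CPU to see that it reproduces the single-model output. This is only a sketch under assumptions: copy.deepcopy stands in for replicate, torch.chunk for scatter, torch.cat for gather, and the replicas run one after another instead of in parallel.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
module = nn.Linear(8, 3).eval()
inputs = torch.randn(6, 8)
num_replicas = 2  # stands in for len(device_ids)

with torch.no_grad():
    scattered = torch.chunk(inputs, num_replicas, dim=0)       # scatter
    replicas = [copy.deepcopy(module) for _ in scattered]      # replicate
    outputs = [rep(x) for rep, x in zip(replicas, scattered)]  # apply (sequentially here)
    gathered = torch.cat(outputs, dim=0)                       # gather
    reference = module(inputs)                                 # plain single-model forward

print(gathered.shape)
print(torch.allclose(gathered, reference, atol=1e-5))
```

Concatenating along dim=0 restores the original batch order, which is why the gathered outputs line up with the labels.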

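loss.backward() triggers autograd, which follows the structure of the gather to scatter grad_output back to the individual GPUs, and each replica runs backward on its own outputs. The per-GPU gradients then have to be combined into one update: the gradients are aggregated and the result is distributed again, which is all-reduce in spirit. In nn.DataParallel specifically, the replicas' gradients are summed into the original module on device_ids[0], and the updated parameters reach the other GPUs when the model is replicated again on the next forward.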
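
The claim that summing the replicas' gradients gives the same result as a full-batch backward can be checked on the CPU. The sketch below assumes two deep-copied replicas of a small nn.Linear in place of real GPU replicas; the loss is computed once on the concatenated outputs, as in the DP flow described above.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(8, 3)
data = torch.randn(6, 8)
target = torch.randint(0, 3, (6,))

# Copy the model before any backward pass so the replicas start without gradients
replicas = [copy.deepcopy(model) for _ in range(2)]

# Reference: ordinary full-batch forward and backward on one model
F.nll_loss(F.log_softmax(model(data), dim=1), target).backward()
reference_grad = model.weight.grad

# Simulated DP: each replica sees half the batch, the loss sees all outputs
chunks = torch.chunk(data, 2, dim=0)
outputs = torch.cat(
    [F.log_softmax(rep(x), dim=1) for rep, x in zip(replicas, chunks)], dim=0
)
F.nll_loss(outputs, target).backward()

# Sum the per-replica gradients, as DP does into the original module
summed_grad = torch.stack([rep.weight.grad for rep in replicas]).sum(dim=0)
print(torch.allclose(summed_grad, reference_grad, atol=1e-5))
```

Because the loss is averaged over the whole gathered batch, each replica's gradient is already its share of the full-batch gradient, so a plain sum is enough.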

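A classic programming framework for implementing DP is the parameter server. In this framework the compute GPUs are called Workers and the GPU that aggregates gradients is called the Server. In practice, to keep communication to a minimum, one Worker usually doubles as the Server; for example, all gradients can be sent to GPU0 for aggregation. The communication bottleneck of DP is the server's communication load: the server cannot receive all the data at once, so a worker that has finished computing but cannot yet transfer its data can only sit idle.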
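
The worker/server roles can be illustrated with a toy example in plain Python. Everything here is an assumption made for illustration: a one-parameter model y = w * x, a squared-error loss, and the function names worker_gradient and server_step; it is not how PyTorch implements DP.

```python
def worker_gradient(weight, samples):
    """Worker: gradient of 0.5 * (w*x - y)^2, summed over this worker's shard."""
    return sum((weight * x - y) * x for x, y in samples)


def server_step(weight, worker_shards, lr):
    """Server: collect the workers' gradients, aggregate, update the weight."""
    grads = [worker_gradient(weight, shard) for shard in worker_shards]
    total = sum(grads)                                    # aggregation on the server
    num_samples = sum(len(shard) for shard in worker_shards)
    return weight - lr * total / num_samples              # new weight goes back to the workers


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # samples of y = 2x
shards = [data[:2], data[2:]]                              # two workers, two samples each

w = 0.0
for _ in range(200):
    w = server_step(w, shards, lr=0.05)
print(round(w, 3))
```

Every step has to wait for all workers' gradients to arrive at the server, which is where the communication bottleneck described above comes from.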

!!! Note: DP is no longer recommended; the PyTorch documentation recommends DistributedDataParallel instead.

Full code

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import datasets, transforms
from time import time
import argparse

class ConvNet(nn.Module):

    def __init__(self):
        super(ConvNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25)
        )
        self.classifier = nn.Sequential(
            nn.Linear(9216, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        output = F.log_softmax(x, dim=1)
        return output


def arg_parser():
    parser = argparse.ArgumentParser(description="MNIST Training Script")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=0.0005, help="Learning rate")
    parser.add_argument("--lr_decay_step_num", type=int, default=1, help="Step size for learning rate decay")
    parser.add_argument("--lr_decay_factor", type=float, default=0.5, help="Factor by which learning rate is decayed")
    parser.add_argument("--cuda_ids", type=int, nargs='+', default=[0,1], help="List of CUDA device IDs to use")
    return parser.parse_args()


def prepare_data(batch_size):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_data = datasets.MNIST(
        root = './mnist',
        train=True,       # True for the training set, False for the test set
        transform = transform,
        # download=True  # set to True to download automatically; can be switched back to False once downloaded
    )

    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

    test_data = datasets.MNIST(
        root = './mnist',
        train=False,       # True for the training set, False for the test set
        transform = transform,
    )

    test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

    return train_loader, test_loader


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 30 == 0: 
            print('Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum the loss over the batch
            pred = output.max(1, keepdim=True)[1] # index of the highest log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))



def train_mnist_classification():
    args = arg_parser()
    print(args)

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size
    LR = args.lr
    LR_DECAY_STEP_NUM = args.lr_decay_step_num
    LR_DECAY_FACTOR = args.lr_decay_factor
    CUDA_DEVICE_IDS = args.cuda_ids
    DEVICE = torch.device(f"cuda:{CUDA_DEVICE_IDS[0]}")

    train_loader, test_loader = prepare_data(BATCH_SIZE)

    model = ConvNet().to(DEVICE)
    model = nn.DataParallel(model, device_ids=CUDA_DEVICE_IDS)
    optimizer = optim.Adam(model.parameters(), lr=LR)

    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_STEP_NUM, gamma=LR_DECAY_FACTOR)

    start_time = time()  # Record the start time
    for epoch in range(EPOCHS):

        epoch_start_time = time()  # Record the start time of the current epoch
        print(f'Epoch {epoch}/{EPOCHS}')
        print(f'Learning Rate: {scheduler.get_last_lr()[0]}')

        train(model, DEVICE, train_loader, optimizer)
        test(model, DEVICE, test_loader)
        scheduler.step()

        epoch_end_time = time()  # Record the end time of the current epoch
        print(f"Time for epoch {epoch}: {epoch_end_time - epoch_start_time:.2f} seconds")

    end_time = time()  # Record the end time
    print(f"Total training time: {end_time - start_time:.2f} seconds")


if __name__ == "__main__":
    train_mnist_classification()
