PyTorch Distributed Training
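
The script below spawns two processes on a single machine: rank 0 runs a small Conv1d "generator" over random data and broadcasts each batch of its outputs through torch.distributed, while rank 1 pre-allocates a matching buffer, receives the broadcast, and trains an LSTM on it.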

from torch.multiprocessing import Process
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Random_Dataset(Dataset):

    def __init__(self, num=1200, dim=300):
        self.num = num
        self.dim = dim
        self.data = torch.rand(num, dim)

    def __len__(self):
        return self.num

    def __getitem__(self, idx):
        return self.data[idx]
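
Random_Dataset just serves 1200 random 300-dimensional vectors. In this script only the generator process (rank 0) iterates over it, so a plain DataLoader is enough; if instead every rank read its own shard of the data, the usual companion would be a DistributedSampler. A minimal sketch, not part of the original script (make_loader is a hypothetical helper):

from torch.utils.data.distributed import DistributedSampler

def make_loader(dataset, rank, world_size, batch_size=12):
    # each rank sees a disjoint 1/world_size slice of the dataset
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=1)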


class Model(nn.Module):

    def __init__(self, input_size, hidden_size, num_layers):
        super(Model, self).__init__()
        self.rnn = nn.LSTM(
                input_size=input_size, hidden_size=hidden_size,
                num_layers=num_layers, batch_first=True
                )
    
    def forward(self, x, lengths):
        total_length = x.size(1)
        self.rnn.flatten_parameters()
        x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
        outputs, (h_n, c_n) = self.rnn(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
                outputs, batch_first=True, total_length=total_length
                )
        return outputs
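
Model wraps an LSTM and hides the pack/pad round trip: it packs the padded batch with the given lengths, runs the LSTM, and pads the output back to the original sequence length. A quick standalone shape check on the CPU, using made-up lengths (the sizes match the ones used later):

model = Model(input_size=300, hidden_size=300, num_layers=3)
x = torch.rand(4, 50, 300)                 # (batch, seq_len, input_size)
lengths = torch.tensor([50, 30, 20, 3])    # must be sorted in descending order
out = model(x, lengths)
print(out.shape)                           # torch.Size([4, 50, 300])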


class Generator(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size):
        super(Generator, self).__init__()
        padding = int((kernel_size - 1) / 2)
        self.conv = nn.Conv1d(
                in_channels, out_channels, kernel_size,
                padding=padding
                )

    def forward(self, x):
        x = torch.unsqueeze(x, 1)
        x = self.conv(x)
        x = torch.sigmoid(x)
        return x
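
Generator expands a (batch, dim) batch into a (batch, out_channels, dim) feature map; with an odd kernel_size the padding keeps the length unchanged, and the sigmoid squashes the values into (0, 1). A quick check with the sizes used below:

gen = Generator(in_channels=1, out_channels=50, kernel_size=5)
batch = torch.rand(12, 300)    # one batch from Random_Dataset
y = gen(batch)                 # unsqueeze -> (12, 1, 300), Conv1d -> (12, 50, 300)
print(y.shape)                 # torch.Size([12, 50, 300])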


def init_process(rank, world_size, backend, func, params):
    # rendezvous address shared by every process (env:// initialization)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    group = dist.new_group([0, 1])
    func(*params, group)
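
init_process is the per-process entry point: it sets the rendezvous address, joins the default process group, builds a two-rank group, and hands control to the worker function. The same rendezvous pattern can be smoke-tested in isolation; a minimal sketch (sanity_check is a hypothetical helper launched once per rank, on a separate port so it does not collide with the main script):

def sanity_check(rank, world_size, backend='gloo'):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'    # hypothetical port
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    t = torch.tensor([float(rank)])
    dist.all_reduce(t, op=dist.ReduceOp.SUM)    # every rank ends up with 0 + 1 + ... + (world_size - 1)
    print("rank {} sees {}".format(rank, t.item()))
    dist.destroy_process_group()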


def generate_data(dataloader, num_epochs, rank, device_ids, network_params,
                  use_cuda, group):
    generator = Generator(*network_params)
    if use_cuda:
        generator = generator.cuda()
    print("Network initialized!")
    generator = nn.parallel.DistributedDataParallel(
            generator, device_ids=device_ids
            )
    for epoch in range(num_epochs):
        for i_batch, batch in enumerate(dataloader):
            batch = batch.cuda()
            rnn_input = generator(batch)
            # send this batch of generator outputs to the other rank in the group
            dist.broadcast(rnn_input, src=rank, group=group)
            print("epoch:{}, batch_num:{}, broadcast finished!".format(
                epoch, i_batch
                )
                )
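
dist.broadcast copies the tensor from the source rank into the same-shaped tensor every other rank passes in, so the receiving side must pre-allocate a buffer of the right shape and dtype; that is why run_rnn below builds its buffer from input_size before the call. A self-contained sketch of the pattern (assuming a two-rank group is already initialized):

buf = torch.zeros(12, 50, 300)
if dist.get_rank() == 0:
    buf = torch.rand(12, 50, 300)    # the source rank fills the tensor
dist.broadcast(buf, src=0)           # afterwards every rank holds rank 0's data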

def run_rnn(num_batchs, num_epochs, src_rank, device_ids, network_params,
            input_size, use_cuda, group):
    rnn = Model(*network_params)
    if use_cuda:
        # keep the model on its first assigned device, matching device_ids[0]
        rnn = rnn.cuda(device_ids[0])
    print("Network initialized!")
    rnn = nn.parallel.DistributedDataParallel(
            rnn, device_ids=device_ids
            )
    optimizer = optim.Adam(rnn.parameters(), lr=1e-4)
    for epoch in range(num_epochs):
        for i_batch in range(num_batchs):
            optimizer.zero_grad()
            # pre-allocate a receive buffer matching the broadcast shape and dtype
            rnn_input = torch.empty(input_size).cuda(device_ids[0])
            dist.broadcast(rnn_input, src=src_rank, group=group)
            print("epoch:{}, batch_num:{}, receive finished!".format(
                epoch, i_batch
                )
                )
            batch_size = rnn_input.size(0)
            # random valid lengths, sorted in descending order as packing requires;
            # pack_padded_sequence expects the lengths on the CPU
            lengths = np.random.randint(low=3, high=30, size=batch_size)
            lengths = -np.sort(-lengths)
            lengths = torch.from_numpy(lengths).long()
            out = rnn(rnn_input, lengths)
            out = torch.sum(out)
            out.backward()
            optimizer.step()
            print("out:{}".format(out.item()))
    rnn = rnn.cpu()
    torch.save(rnn.state_dict(), 'rnn.net')
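
One detail when checkpointing a DistributedDataParallel model: the wrapper prefixes every parameter name with "module.", so saving the unwrapped module keeps the checkpoint loadable into a plain Model later. A small sketch (save_ddp and load_plain are hypothetical helpers, not in the original script):

def save_ddp(ddp_model, path='rnn.net'):
    # .module is the original model inside the DDP wrapper
    torch.save(ddp_model.module.state_dict(), path)

def load_plain(path='rnn.net'):
    model = Model(300, 300, 3)
    model.load_state_dict(torch.load(path))
    return model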


def main(use_cuda=False):
    world_size = 2
    processes = []
    dataset = Random_Dataset()
    dataloader = DataLoader(
            dataset, batch_size=12, shuffle=True, num_workers=1
            )
    num_epochs = 2
    num_batchs = 100
    generator_device_ids = [0, 1]   # together the two processes assume four visible GPUs
    rnn_device_ids = [2, 3]
    generator_params = (1, 50, 5)
    rnn_params = (300, 300, 3)

    # both processes must join the default process group with the same backend
    p1 = Process(
            target=init_process,
            args=(0, world_size, 'gloo', generate_data,
                  (dataloader, num_epochs, 0, generator_device_ids,
                   generator_params, use_cuda)
                  )
            )
    p1.start()
    processes.append(p1)

    p2 = Process(
            target=init_process,
            args=(1, world_size, 'gloo', run_rnn, 
                (num_batchs, num_epochs, 0,
                 rnn_device_ids, rnn_params, torch.Size((12, 50, 300)),
                 use_cuda)
                )
            )
    p2.start()
    processes.append(p2)

    for p in processes:
        p.join()


if __name__ == '__main__':
    main(True)
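
As an aside (an alternative, not what this post does): torch.multiprocessing.spawn can replace the hand-rolled Process bookkeeping above; it supplies the rank, uses the 'spawn' start method, and joins the workers. A minimal sketch with a hypothetical worker:

import torch.multiprocessing as mp

def worker(rank, world_size):
    # plays the role of init_process above; mp.spawn passes in the rank
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend='gloo', rank=rank, world_size=world_size)
    print("rank {} joined".format(dist.get_rank()))
    dist.destroy_process_group()

# run under an if __name__ == '__main__' guard:
mp.spawn(worker, args=(2,), nprocs=2)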
