Table of Contents
- 0. Preface
- 1. Distributed Reinforcement Learning using RPC and RRef
- 2. Distributed RNN using Distributed Autograd and Distributed Optimizer
0. Preface
- Notes on the official PyTorch tutorial (translated).
- Goal: present two examples that show how to implement distributed training with `torch.distributed.rpc`.
- Cases where `DistributedDataParallel` does not fit:
  - Reinforcement learning: the model itself is small, but collecting training data from the environment is expensive, so we want multiple observers and a single agent.
  - A single model that is too large for one GPU, i.e., model parallelism.
- I don't fully understand all of this yet; I'll study it more carefully when I actually need it.
1. Distributed Reinforcement Learning using RPC and RRef
- Uses CartPole-v1 as an example to show how to implement a reinforcement learning application with RPC and RRef.
- Main points demonstrated:
  - How to pass data between multiple workers via RPC.
  - How to use `RRef` to refer to remote objects.
  - How these patterns are natively supported and optimized by `torch.distributed.rpc` (a short sketch of the core primitives follows this list).
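- As a quick orientation, here is a minimal sketch of my own (not from the tutorial) of the three RPC primitives the example relies on. It assumes `rpc.init_rpc` has already been called on every worker; the worker name and the toy function are placeholders.

```python
import torch
import torch.distributed.rpc as rpc


def add_one(t):  # toy function, only for illustration
    return t + 1


def rpc_primitives_demo(to_worker="worker1"):  # placeholder worker name
    x = torch.ones(2)

    # rpc_sync: blocking call, returns the result directly.
    y = rpc.rpc_sync(to_worker, add_one, args=(x,))

    # rpc_async: non-blocking call, returns a Future.
    fut = rpc.rpc_async(to_worker, add_one, args=(x,))
    y_async = fut.wait()

    # remote: runs add_one on the callee and returns an RRef to the result,
    # without copying the value back until it is actually needed.
    rref = rpc.remote(to_worker, add_one, args=(x,))
    y_remote = rref.to_here()

    return y, y_async, y_remote
```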
- The policy model code is as follows:
```python
import torch.nn as nn
import torch.nn.functional as F


class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)

        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)
```
- Build helper functions that run a method on the worker that owns an RRef via a remote call:
```python
from torch.distributed.rpc import rpc_sync


def _call_method(method, rref, *args, **kwargs):
    # Runs on the owner of `rref`: unwrap the RRef and call the method locally.
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    # Ship (method, rref, args) to the RRef's owner and block for the result.
    args = [method, rref] + list(args)
    return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)
```
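- A hypothetical usage example (the `Counter` class and worker name are my own, not from the tutorial): create an object on another worker with `rpc.remote`, then call its methods through the helpers above. The RL code below uses exactly this pattern to call `Agent.select_action` and `Agent.report_reward` from the observers.

```python
import torch.distributed.rpc as rpc


class Counter:  # placeholder class, only for illustration
    def __init__(self):
        self.value = 0

    def increment(self, by):
        self.value += by
        return self.value


def counter_demo(owner="worker1"):  # assumes rpc.init_rpc has been called on all workers
    # The Counter instance is created on `owner`; we only hold an RRef to it.
    counter_rref = rpc.remote(owner, Counter)
    # _remote_method ships (Counter.increment, counter_rref, 3) to the owner,
    # where _call_method unwraps the RRef and calls increment on the real object.
    return _remote_method(Counter.increment, counter_rref, 3)
```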
- The observer:
  - Each observer creates its own environment and waits for the agent's command to run an episode.
  - In each episode, the observer loops for `n_steps` iterations.
  - In each iteration it sends the current environment state to the agent via RPC and gets an action back.
  - It applies that action to its environment and obtains the reward and the next state.
  - The observer then uses another RPC to report the reward to the agent.
```python
import argparse

import gym
import torch.distributed.rpc as rpc

parser = argparse.ArgumentParser(
    description="RPC Reinforcement Learning Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument('--world_size', type=int, default=2, help='Number of workers')
parser.add_argument('--log_interval', type=int, default=1, help='Log every log_interval episodes')
parser.add_argument('--gamma', type=float, default=0.1, help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, help='random seed for reproducibility')
args = parser.parse_args()


class Observer:
    def __init__(self):
        # Each observer runs in its own process and owns its own environment.
        self.id = rpc.get_worker_info().id
        self.env = gym.make('CartPole-v1')
        self.env.seed(args.seed)

    def run_episode(self, agent_rref, n_steps):
        state, ep_reward = self.env.reset(), 0
        for step in range(n_steps):
            # Ask the agent (on another worker) for an action for the current state.
            action = _remote_method(Agent.select_action, agent_rref, self.id, state)

            # Apply the action locally and report the reward back to the agent.
            state, reward, done, _ = self.env.step(action)
            _remote_method(Agent.report_reward, agent_rref, self.id, reward)

            if done:
                break
```
- The agent:
  - The agent acts as both trainer and master: it sends commands to the observers to run episodes, and it records all actions and rewards along the way.
```python
import gym
import numpy as np

import torch
import torch.distributed.rpc as rpc
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributions import Categorical


class Agent:
    def __init__(self, world_size):
        self.ob_rrefs = []
        self.agent_rref = RRef(self)
        self.rewards = {}
        self.saved_log_probs = {}
        self.policy = Policy()
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
        self.eps = np.finfo(np.float32).eps.item()
        self.running_reward = 0
        self.reward_threshold = gym.make('CartPole-v1').spec.reward_threshold
        for ob_rank in range(1, world_size):
            # Create one remote Observer instance per observer worker and keep an RRef to it.
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
            self.ob_rrefs.append(remote(ob_info, Observer))
            self.rewards[ob_info.id] = []
            self.saved_log_probs[ob_info.id] = []

    def select_action(self, ob_id, state):
        # Called over RPC by observers: sample an action from the policy.
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[ob_id].append(m.log_prob(action))
        return action.item()

    def report_reward(self, ob_id, reward):
        # Called over RPC by observers after each environment step.
        self.rewards[ob_id].append(reward)

    def run_episode(self, n_steps=0):
        # Kick off one episode on every observer asynchronously, then wait for all of them.
        futs = []
        for ob_rref in self.ob_rrefs:
            futs.append(
                rpc_async(
                    ob_rref.owner(),
                    _call_method,
                    args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps)
                )
            )
        for fut in futs:
            fut.wait()

    def finish_episode(self):
        # Merge log-probs and rewards from all observers into one policy update.
        R, probs, rewards = 0, [], []
        for ob_id in self.rewards:
            probs.extend(self.saved_log_probs[ob_id])
            rewards.extend(self.rewards[ob_id])

        # Use the minimum observer reward to track training progress.
        min_reward = min([sum(self.rewards[ob_id]) for ob_id in self.rewards])
        self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward

        # Clear saved probs and rewards for the next episode.
        for ob_id in self.rewards:
            self.rewards[ob_id] = []
            self.saved_log_probs[ob_id] = []

        policy_loss, returns = [], []
        for r in rewards[::-1]:
            R = r + args.gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + self.eps)
        for log_prob, R in zip(probs, returns):
            policy_loss.append(-log_prob * R)
        self.optimizer.zero_grad()
        policy_loss = torch.cat(policy_loss).sum()
        policy_loss.backward()
        self.optimizer.step()
        return min_reward
```
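- For reference (my own summary, not part of the tutorial text), `finish_episode` is a plain REINFORCE update over the rewards merged from all observers: discounted returns are computed backwards with `args.gamma`, normalized, and used to weight the saved log-probabilities:

$$
R_t = r_t + \gamma R_{t+1}, \qquad
\hat{R}_t = \frac{R_t - \bar{R}}{\sigma_R + \epsilon}, \qquad
\text{loss} = -\sum_t \log \pi_\theta(a_t \mid s_t)\,\hat{R}_t
$$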
- Finally, the main entry point spawns one agent process (rank 0) and `world_size - 1` observer processes:

```python
import os
from itertools import count

import torch.multiprocessing as mp

AGENT_NAME = "agent"
OBSERVER_NAME = "obs{}"  # formatted with the observer's rank, e.g. "obs1"

TOTAL_EPISODE_STEP = 100


def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 0:
        # Rank 0 is the agent: it drives the training loop.
        rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)

        agent = Agent(world_size)
        for i_episode in count(1):
            n_steps = int(TOTAL_EPISODE_STEP / (args.world_size - 1))
            agent.run_episode(n_steps=n_steps)
            last_reward = agent.finish_episode()

            if i_episode % args.log_interval == 0:
                print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
                    i_episode, last_reward, agent.running_reward))

            if agent.running_reward > agent.reward_threshold:
                print("Solved! Running reward is now {}!".format(agent.running_reward))
                break
    else:
        # Other ranks are observers: they passively wait for RPCs from the agent.
        rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)

    # Block until all RPC activity on every worker is done, then shut down.
    rpc.shutdown()


if __name__ == '__main__':
    mp.spawn(
        run_worker,
        args=(args.world_size, ),
        nprocs=args.world_size,
        join=True
    )
```
2. Distributed RNN using Distributed Autograd and Distributed Optimizer
- Goal: take an RNN model as an example and use RPC to implement model parallelism.
- The RNN model in the example is small and would easily fit on a single GPU; it is split across two workers purely for demonstration.
- The model has three parts: the embedding table, the LSTM layer, and the decoder. The embedding table and the decoder are placed on a remote parameter server, while the LSTM stays with the trainer; see the stripped-down sketch right after this list, followed by the full code.
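- A stripped-down sketch of just that pattern, before the full code (the `Toy` module and the worker name "ps" are my own placeholders; `_remote_method` is the helper from section 1, and RPC is assumed to be initialized):

```python
import torch
import torch.nn as nn
import torch.distributed.rpc as rpc


class Toy(nn.Module):  # placeholder module, only for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)


def remote_module_demo(ps="ps"):
    # Build the module on the remote worker; the caller only holds an RRef to it.
    toy_rref = rpc.remote(ps, Toy)
    # Run its forward pass on the owner and bring only the output back.
    return _remote_method(Toy.forward, toy_rref, torch.randn(2, 4))
```

The complete model code from the tutorial follows.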
```python
class EmbeddingTable(nn.Module):
    r"""
    Encoding layers of the RNNModel
    """
    def __init__(self, ntoken, ninp, dropout):
        super(EmbeddingTable, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp).cuda()
        self.encoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        # Runs on the parameter server; move the input to its GPU and the
        # result back to CPU before it is sent over RPC.
        return self.drop(self.encoder(input.cuda())).cpu()


class Decoder(nn.Module):
    def __init__(self, ntoken, nhid, dropout):
        super(Decoder, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, output):
        return self.decoder(self.drop(output))


class RNNModel(nn.Module):
    def __init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()

        # Embedding table and decoder live on the parameter server `ps`;
        # we only hold RRefs to them. The LSTM stays local on the trainer.
        self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout))
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout))

    def forward(self, input, hidden):
        # Remote embedding lookup -> local LSTM -> remote decoder.
        emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
        output, hidden = self.rnn(emb, hidden)
        decoded = _remote_method(Decoder.forward, self.decoder_rref, output)
        return decoded, hidden

    def parameter_rrefs(self):
        # Collect RRefs to all parameters, remote and local, for DistributedOptimizer.
        remote_params = []
        remote_params.extend(_remote_method(_parameter_rrefs, self.emb_table_rref))
        remote_params.extend(_parameter_rrefs(self.rnn))
        remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref))
        return remote_params


def _parameter_rrefs(module):
    # Wrap every parameter of `module` in an RRef.
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs
```
- The trainer builds the model and a `DistributedOptimizer` over all parameter RRefs, and wraps each training iteration in a distributed autograd context:

```python
import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer


def _run_trainer():
    batch = 5
    ntoken = 10
    ninp = 2

    nhid = 3
    nindices = 3
    nlayers = 4
    hidden = (
        torch.randn(nlayers, nindices, nhid),
        torch.randn(nlayers, nindices, nhid)
    )

    model = RNNModel('ps', ntoken, ninp, nhid, nlayers)

    # DistributedOptimizer takes parameter RRefs and creates one local
    # optimizer on each worker that owns some of the parameters.
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()

    def get_next_batch():
        # Dummy data, just to exercise the training loop.
        for _ in range(5):
            data = torch.LongTensor(batch, nindices) % ntoken
            target = torch.LongTensor(batch, ntoken) % nindices
            yield data, target

    for epoch in range(10):
        for data, target in get_next_batch():
            # Each iteration gets its own distributed autograd context.
            with dist_autograd.context() as context_id:
                hidden[0].detach_()
                hidden[1].detach_()
                output, hidden = model(data, hidden)
                loss = criterion(output, target)
                # Distributed backward pass, then a distributed optimizer step.
                dist_autograd.backward(context_id, [loss])
                opt.step(context_id)
        print("Training epoch {}".format(epoch))
```
- Finally, spawn two workers: rank 1 runs the trainer, and rank 0 acts as the parameter server, which only responds to RPCs:

```python
def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 1:
        rpc.init_rpc("trainer", rank=rank, world_size=world_size)
        _run_trainer()
    else:
        rpc.init_rpc("ps", rank=rank, world_size=world_size)
        # The parameter server does nothing actively; it only serves RPCs.
        pass

    # Block until all RPCs finish, then shut down.
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
```