Launching distributed TensorFlow jobs with Slurm

Motivation

TensorFlow's single-machine multi-GPU mode is easy to use by following the official tutorials, but there is no good official example for deploying a distributed training job across multiple nodes of a cluster. The naive approach is to designate ps nodes and worker nodes by hand and start the corresponding program on each node separately, which is very inconvenient in a cluster environment with many nodes. This post shows how to use the Slurm cluster resource manager to distribute a multi-node TensorFlow training job automatically.

Implementation

import os
import re

# This function reads the compute resources that Slurm allocated to the job and turns them
# into a TensorFlow cluster configuration.
# It takes two parameters: ps_number is the number of parameter-server nodes; all remaining
# nodes become worker nodes by default.
# A ps node could in principle also host a worker, but we avoid this to prevent port conflicts.
# port_number is the port used for inter-node communication in this job. If a ps node also ran
# a worker process, the different processes on that node would each need their own port; to keep
# things simple we assume num_nodes > 1 and never place a worker and a ps on the same node.

def tf_config_from_slurm(ps_number, port_number=2222):
    """
    Creates configuration for a distributed tensorflow session 
    from environment variables  provided by the Slurm cluster
    management system.

    @param: ps_number number of parameter servers to run
    @param: port_number port number to be used for communication
    @return: a tuple containing cluster with fields cluster_spec,
             task_name and task_id 
    """

    nodelist = os.environ["SLURM_JOB_NODELIST"]
    nodename = os.environ["SLURMD_NODENAME"]
    nodelist = _expand_nodelist(nodelist)
    num_nodes = int(os.getenv("SLURM_JOB_NUM_NODES"))

    if len(nodelist) != num_nodes:
        raise ValueError("Number of slurm nodes {} not equal to {}".format(len(nodelist), num_nodes))

    if nodename not in nodelist:
        raise ValueError("Nodename ({}) not in nodelist ({}). This should not happen!".format(nodename, nodelist))

    if ps_number > num_nodes:
        raise ValueError("Number of ps nodes ({}) is larger than the number of nodes allocated by Slurm ({})!".format(ps_number, num_nodes))

    ps_nodes = [node for i, node in enumerate(nodelist) if i < ps_number]
    worker_nodes = [node for i, node in enumerate(nodelist) if i >= ps_number]

    if nodename in ps_nodes:
        my_job_name = "ps"
        my_task_index = ps_nodes.index(nodename)
    else:
        my_job_name = "worker"
        my_task_index = worker_nodes.index(nodename)

    worker_sockets = [":".join([node, str(port_number)]) for node in worker_nodes]
    ps_sockets = [":".join([node, str(port_number)]) for node in ps_nodes]
    cluster = {"worker": worker_sockets, "ps" : ps_sockets}

    return cluster, my_job_name, my_task_index

def _pad_zeros(iterable, length):
    # Left-pad each id with zeros so that e.g. 1 becomes "01" when length is 2.
    return (str(t).rjust(length, '0') for t in iterable)

def _expand_ids(ids):
    # Expand a Slurm id expression such as "01-03,05" into ["01", "02", "03", "05"].
    ids = ids.split(',')
    result = []
    for id in ids:
        if '-' in id:
            begin_str, end_str = id.split('-')
            begin, end = int(begin_str), int(end_str)
            result.extend(_pad_zeros(range(begin, end + 1), len(begin_str)))
        else:
            result.append(id)
    return result

def _expand_nodelist(nodelist):
    # Expand a compressed Slurm nodelist such as "node[01-03,05]" into individual node names.
    # Note: this simple version only handles a single bracketed group; a nodelist without
    # brackets (a single node) or with mixed prefixes is not covered.
    prefix, ids = re.findall(r"(.*)\[(.*)\]", nodelist)[0]
    ids = _expand_ids(ids)
    result = [prefix + str(id) for id in ids]
    return result

def _worker_task_id(nodelist, nodename):
    return nodelist.index(nodename)
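
To make the behavior concrete, here is a small check of tf_config_from_slurm. The node names and the hand-set Slurm environment variables are hypothetical and only illustrate the expected input and output; inside a real job Slurm sets these variables itself, and this block should not appear in the submitted script.

# Hypothetical allocation: three nodes, this process running on node02.
os.environ["SLURM_JOB_NODELIST"] = "node[01-03]"
os.environ["SLURMD_NODENAME"] = "node02"
os.environ["SLURM_JOB_NUM_NODES"] = "3"

cluster, job_name, task_index = tf_config_from_slurm(ps_number=1)
print(cluster)
# {'worker': ['node02:2222', 'node03:2222'], 'ps': ['node01:2222']}
print(job_name, task_index)
# worker 0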

Building the model with TensorFlow

import sys
import tensorflow as tf

# Get the cluster resources allocated by Slurm together with this node's job name and
# task index, build the ClusterSpec, and start the server.
# Note that a ps node must keep running so it can receive updates from the workers and
# apply them to the shared parameters, so the ps process calls join() and never returns
# instead of simply exiting.
cluster, my_job_name, my_task_index = tf_config_from_slurm(ps_number=3)
cluster_spec = tf.train.ClusterSpec(cluster)
server = tf.train.Server(server_or_cluster_def=cluster_spec,
                         job_name=my_job_name,
                         task_index=my_task_index)

if my_job_name == 'ps':
    server.join()
    sys.exit(0)
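
For completeness, below is a minimal sketch of what the worker branch might look like, assuming the same TF 1.x distributed API as above. The toy linear-regression graph, the variable names, and the step limit are placeholders, not part of the original setup; substitute your real model and input pipeline.

import numpy as np  # only needed for the toy data below

if my_job_name == 'worker':
    # Variables are placed on the ps tasks, ops stay on this worker by default.
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % my_task_index,
            cluster=cluster_spec)):
        x = tf.placeholder(tf.float32, shape=[None, 1])
        y = tf.placeholder(tf.float32, shape=[None, 1])
        w = tf.get_variable("w", shape=[1, 1])
        b = tf.get_variable("b", shape=[1], initializer=tf.zeros_initializer())
        loss = tf.reduce_mean(tf.square(tf.matmul(x, w) + b - y))
        global_step = tf.train.get_or_create_global_step()
        train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
            loss, global_step=global_step)

    # Worker 0 acts as the chief; MonitoredTrainingSession takes care of variable
    # initialization and session recovery across the cluster.
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(my_task_index == 0)) as sess:
        while not sess.should_stop():
            xs = np.random.rand(32, 1).astype(np.float32)  # toy data
            ys = 2.0 * xs + 1.0
            _, step, cur_loss = sess.run([train_op, global_step, loss],
                                         feed_dict={x: xs, y: ys})
            if step % 100 == 0:
                print("worker {}: step {} loss {:.4f}".format(my_task_index, step, cur_loss))
            if step >= 1000:
                break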
This post will be updated as the setup is further refined.
