PyTorch Dataloader读取时,如何在进程之间传输数据?

点击上方“机器学习与生成对抗网络”,关注星标

获取有趣、好玩的前沿干货!

来源|知乎  作者|Envy

链接|https://zhuanlan.zhihu.com/p/409629586

编辑|人工智能前沿讲习

最近我在做PyTorch的Dataloader相关的开发,有一个问题让我比较在意:PyTorch的Dataloader在启动多个进程读取样本的时候,这些数据是怎么在进程之间进行传输的?会不会引入多余的内存拷贝?

01

Dataloader使用multiprocess.Queue来传输数据

先简单介绍一下Dataloader的多进程模式,Dataloader在构造的时候,若num_workers不为0,就会启动num_workers个worker进程,然后主进程会向worker进程分发读取任务,worker进程读到数据之后,再把数据放到队列中供主进程取用。Worker进程所执行的代码片段如下:

while watchdog.is_alive():
    try:
        r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
    except queue.Empty:
        continue
    if isinstance(r, _ResumeIteration):
        # Acknowledge the main process
        data_queue.put((r, None))
        iteration_end = False
        # Recreate the fetcher for worker-reuse policy
        fetcher = _DatasetKind.create_fetcher(
            dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        continue
    elif r is None:
        # Received the final signal
        assert done_event.is_set() or iteration_end
        break
    elif done_event.is_set() or iteration_end:
        # `done_event` is set. But I haven't received the final signal
        # (None) yet. I will keep continuing until get it, and skip the
        # processing steps.
        continue
    idx, index = r
    data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
    if init_exception is not None:
        data = init_exception
        init_exception = None
    else:
        try:
            data = fetcher.fetch(index)
        except Exception as e:
            if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
                data = _IterableDatasetStopIteration(worker_id)
                # Set `iteration_end`
                #   (1) to save future `next(...)` calls, and
                #   (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                iteration_end = True
            else:
                # It is important that we don't store exc_info in a variable.
                # `ExceptionWrapper` does the correct thing.
                # See NOTE [ Python Traceback Reference Cycle Problem ]
                data = ExceptionWrapper(
                    where="in DataLoader worker process {}".format(worker_id))
    data_queue.put((idx, data))
    del data, idx, index, r  # save memory

其中,data_queue通常是一个torch.multiprocessing.Queue的实例。

看到这里,似乎Dataloader只是平平无奇地使用了torch.multiprocessing.Queue的接口,难道torch.multiprocessing.Queue有一些高级的技巧?

继续看torch.multiprocessing.Queue的代码,发现它只是简单地把multiprocessing.Queue包了一下。众所周知, multiprocessing.Queue在Linux使用socket来实现,那难道读上来的数据需要在socket之间传来传去吗,效率也太低了吧?!不对,PyTorch一定还有其他骚操作。

02

Tensor(CPU Tensor)在multiprocessing.Queue中的序列化和反序列化

在通常的用法里,Dataloader从Dataset里读出来的数据都会被collate_fn转成CPU Tensor,那我们就继续看看,Tensor是怎么在队列中序列化和反序列化的。

可以看到,torch.Tensor重载了__reduce_ex__()函数,序列化的时候只会用到 Tensor.storage, Tensor.storage_offset, size, stride, requires_grad和backward_hooks;而反序列化的torch._utils._rebuild_tensor_v2()也只会用到以上信息。

multiprocessing.Queue是使用pickle来做序列化和反序列化的,而重载__reduce_ex__()正是自定义序列化反序列化方式的方法之一。

那看起来,CPU Tensor在进程中传输时,是在接收进程中把Tensor重新构建了一遍,而构建Tensor时候用到的信息, Tensor.storage_offset, size, stride, requires_grad和backward_hooks,都只是用于描述Tensor的meta信息,实际和数据相关的,就只有Tensor.storage了。

03

Tensor.Storage的序列化与反序列化

Tensor.Storage同样重载了pickle的序列化与反序列化过程,在torch/multiprocessing/reduction.py中,给Tensor.Storage 注册了reduce函数reduce_storage.

这里为什么使用copyreg库而不是重载__reduce__(), 在copyreg的注释里说copyreg是专用于C extension的.按照这个说法,在reductions.py里为Tensor注册的reduce function应该是没有起效的。

def reduce_storage(storage):
    from . import get_sharing_strategy
    if storage.is_cuda:
        raise RuntimeError("Cannot pickle CUDA storage; try pickling a CUDA tensor instead")
    elif get_sharing_strategy() == 'file_system':
        metadata = storage._share_filename_()
        cache_key = metadata[1]
        rebuild = rebuild_storage_filename
        storage._shared_incref()
    elif storage.size() == 0:
        # This is special cased because Empty tensors
        # (with size 0) cannot be mmapped.
        return (rebuild_storage_empty, (type(storage),))
    else:
        fd, size = storage._share_fd_()
        df = multiprocessing.reduction.DupFd(fd)
        cache_key = fd_id(fd)
        metadata = (df, size)
        rebuild = rebuild_storage_fd  # type: ignore[assignment]


    shared_cache[cache_key] = StorageWeakRef(storage)
    return (rebuild, (type(storage),) + metadata)

这个函数首先根据环境中的sharing strategy来决定共享内存的使用方式,然后若storage原本不在共享内存中的话,就把它拷到共享内存中去,比如_share_fd_()的实现。


04

小结

看到这里,我们可以大概得出一个结论了。

Worker进程从Dataset中读出来的Tensor本身是普通的CPU Tensor,但当把它放到multiprocessing.Queue中去的时候,这个Tensor的数据会被拷到共享内存中,Queue只会发送这个Tensor所具有的meta信息,主进程接到这些meta信息之后,就可以从共享内存中的数据重新构建Tensor。

这里有一点值得注意,如果你想要验证这个过程,在发送进程调用multiprocessing.Queue.put()之后,立即调用Tensor.is_shared()并不会返回True,因为put()是非阻塞的,只有当Tensor被QueueFeedThread序列化完成之后再调用is_shared(),才会得到预期中的结果。

05

Dataloader的小心机

在default_collate中,有这样一段小代码:

if isinstance(elem, torch.Tensor):
    out = None
    if torch.utils.data.get_worker_info() is not None:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum([x.numel() for x in batch])
        storage = elem.storage()._new_shared(numel)
        out = elem.new(storage)
    return torch.stack(batch, 0, out=out)

含义是,如果batch中的数据已经是Tensor了,那么,如果这是一个Worker进程,就开一段共享内存把这个batch放进去。因为collate的时候无论如何都会有一次内存拷贝(除非底层的Dataset有其他保证),那么这个操作就省掉了之后放进队列中的那一次内存拷贝。

不过我随便找了几个自定义了collate_fn的模型看了一下他们写的collate过程,是没有把这一点考虑进去的。这也算是Dataloader的一个小心机吧,有缘人就用得上。

猜您喜欢:

等你着陆!【GAN生成对抗网络】知识星球!

CVPR 2021专题1:GAN的改进

CVPR 2021 | GAN的说话人驱动、3D人脸论文汇总

CVPR 2021 | 图像转换 今如何?几篇GAN论文

【CVPR 2021】通过GAN提升人脸识别的遗留难题

CVPR 2021生成对抗网络GAN部分论文汇总

经典GAN不得不读:StyleGAN

最新最全20篇!基于 StyleGAN 改进或应用相关论文

超100篇!CVPR 2020最全GAN论文梳理汇总!

附下载 | 《Python进阶》中文版

附下载 | 经典《Think Python》中文版

附下载 | 《Pytorch模型训练实用教程》

附下载 | 最新2020李沐《动手学深度学习》

附下载 | 《可解释的机器学习》中文版

附下载 |《TensorFlow 2.0 深度学习算法实战》

附下载 | 超100篇!CVPR 2020最全GAN论文梳理汇总!

附下载 |《计算机视觉中的数学方法》分享

PyTorch Dataloader读取时,如何在进程之间传输数据?_第1张图片

你可能感兴趣的:(人工智能,深度学习,java,编程语言,大数据)