点击上方“机器学习与生成对抗网络”,关注星标
获取有趣、好玩的前沿干货!
来源|知乎 作者|Envy
链接|https://zhuanlan.zhihu.com/p/409629586
编辑|人工智能前沿讲习
最近我在做PyTorch的Dataloader相关的开发,有一个问题让我比较在意:PyTorch的Dataloader在启动多个进程读取样本的时候,这些数据是怎么在进程之间进行传输的?会不会引入多余的内存拷贝?
01
Dataloader使用multiprocess.Queue来传输数据
先简单介绍一下Dataloader的多进程模式,Dataloader在构造的时候,若num_workers不为0,就会启动num_workers个worker进程,然后主进程会向worker进程分发读取任务,worker进程读到数据之后,再把数据放到队列中供主进程取用。Worker进程所执行的代码片段如下:
while watchdog.is_alive():
try:
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
if isinstance(r, _ResumeIteration):
# Acknowledge the main process
data_queue.put((r, None))
iteration_end = False
# Recreate the fetcher for worker-reuse policy
fetcher = _DatasetKind.create_fetcher(
dataset_kind, dataset, auto_collation, collate_fn, drop_last)
continue
elif r is None:
# Received the final signal
assert done_event.is_set() or iteration_end
break
elif done_event.is_set() or iteration_end:
# `done_event` is set. But I haven't received the final signal
# (None) yet. I will keep continuing until get it, and skip the
# processing steps.
continue
idx, index = r
data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
if init_exception is not None:
data = init_exception
init_exception = None
else:
try:
data = fetcher.fetch(index)
except Exception as e:
if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
data = _IterableDatasetStopIteration(worker_id)
# Set `iteration_end`
# (1) to save future `next(...)` calls, and
# (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
iteration_end = True
else:
# It is important that we don't store exc_info in a variable.
# `ExceptionWrapper` does the correct thing.
# See NOTE [ Python Traceback Reference Cycle Problem ]
data = ExceptionWrapper(
where="in DataLoader worker process {}".format(worker_id))
data_queue.put((idx, data))
del data, idx, index, r # save memory
其中,data_queue通常是一个torch.multiprocessing.Queue的实例。
看到这里,似乎Dataloader只是平平无奇地使用了torch.multiprocessing.Queue的接口,难道torch.multiprocessing.Queue有一些高级的技巧?
继续看torch.multiprocessing.Queue的代码,发现它只是简单地把multiprocessing.Queue包了一下。众所周知, multiprocessing.Queue在Linux使用socket来实现,那难道读上来的数据需要在socket之间传来传去吗,效率也太低了吧?!不对,PyTorch一定还有其他骚操作。
02
Tensor(CPU Tensor)在multiprocessing.Queue中的序列化和反序列化
在通常的用法里,Dataloader从Dataset里读出来的数据都会被collate_fn转成CPU Tensor,那我们就继续看看,Tensor是怎么在队列中序列化和反序列化的。
可以看到,torch.Tensor重载了__reduce_ex__()函数,序列化的时候只会用到 Tensor.storage, Tensor.storage_offset, size, stride, requires_grad和backward_hooks;而反序列化的torch._utils._rebuild_tensor_v2()也只会用到以上信息。
multiprocessing.Queue是使用pickle来做序列化和反序列化的,而重载__reduce_ex__()正是自定义序列化反序列化方式的方法之一。
那看起来,CPU Tensor在进程中传输时,是在接收进程中把Tensor重新构建了一遍,而构建Tensor时候用到的信息, Tensor.storage_offset, size, stride, requires_grad和backward_hooks,都只是用于描述Tensor的meta信息,实际和数据相关的,就只有Tensor.storage了。
03
Tensor.Storage的序列化与反序列化
Tensor.Storage同样重载了pickle的序列化与反序列化过程,在torch/multiprocessing/reduction.py中,给Tensor.Storage 注册了reduce函数reduce_storage.
这里为什么使用copyreg库而不是重载__reduce__(), 在copyreg的注释里说copyreg是专用于C extension的.按照这个说法,在reductions.py里为Tensor注册的reduce function应该是没有起效的。
def reduce_storage(storage):
from . import get_sharing_strategy
if storage.is_cuda:
raise RuntimeError("Cannot pickle CUDA storage; try pickling a CUDA tensor instead")
elif get_sharing_strategy() == 'file_system':
metadata = storage._share_filename_()
cache_key = metadata[1]
rebuild = rebuild_storage_filename
storage._shared_incref()
elif storage.size() == 0:
# This is special cased because Empty tensors
# (with size 0) cannot be mmapped.
return (rebuild_storage_empty, (type(storage),))
else:
fd, size = storage._share_fd_()
df = multiprocessing.reduction.DupFd(fd)
cache_key = fd_id(fd)
metadata = (df, size)
rebuild = rebuild_storage_fd # type: ignore[assignment]
shared_cache[cache_key] = StorageWeakRef(storage)
return (rebuild, (type(storage),) + metadata)
这个函数首先根据环境中的sharing strategy来决定共享内存的使用方式,然后若storage原本不在共享内存中的话,就把它拷到共享内存中去,比如_share_fd_()的实现。
04
小结
看到这里,我们可以大概得出一个结论了。
Worker进程从Dataset中读出来的Tensor本身是普通的CPU Tensor,但当把它放到multiprocessing.Queue中去的时候,这个Tensor的数据会被拷到共享内存中,Queue只会发送这个Tensor所具有的meta信息,主进程接到这些meta信息之后,就可以从共享内存中的数据重新构建Tensor。
这里有一点值得注意,如果你想要验证这个过程,在发送进程调用multiprocessing.Queue.put()之后,立即调用Tensor.is_shared()并不会返回True,因为put()是非阻塞的,只有当Tensor被QueueFeedThread序列化完成之后再调用is_shared(),才会得到预期中的结果。
05
Dataloader的小心机
在default_collate中,有这样一段小代码:
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
含义是,如果batch中的数据已经是Tensor了,那么,如果这是一个Worker进程,就开一段共享内存把这个batch放进去。因为collate的时候无论如何都会有一次内存拷贝(除非底层的Dataset有其他保证),那么这个操作就省掉了之后放进队列中的那一次内存拷贝。
不过我随便找了几个自定义了collate_fn的模型看了一下他们写的collate过程,是没有把这一点考虑进去的。这也算是Dataloader的一个小心机吧,有缘人就用得上。
猜您喜欢:
等你着陆!【GAN生成对抗网络】知识星球!
CVPR 2021专题1:GAN的改进
CVPR 2021 | GAN的说话人驱动、3D人脸论文汇总
CVPR 2021 | 图像转换 今如何?几篇GAN论文
【CVPR 2021】通过GAN提升人脸识别的遗留难题
CVPR 2021生成对抗网络GAN部分论文汇总
经典GAN不得不读:StyleGAN
最新最全20篇!基于 StyleGAN 改进或应用相关论文
超100篇!CVPR 2020最全GAN论文梳理汇总!
附下载 | 《Python进阶》中文版
附下载 | 经典《Think Python》中文版
附下载 | 《Pytorch模型训练实用教程》
附下载 | 最新2020李沐《动手学深度学习》
附下载 | 《可解释的机器学习》中文版
附下载 |《TensorFlow 2.0 深度学习算法实战》
附下载 | 超100篇!CVPR 2020最全GAN论文梳理汇总!
附下载 |《计算机视觉中的数学方法》分享