In my spare time I read through the source code of PyTorch's DataLoader. Below is a brief summary, with the key code for each stage, shared for reference.
# dataloading ...
data_loader = torch.utils.data.DataLoader(
    Dataset(self.args.__dict__, random_shift=True, data_mode='train',
            transform=torchvision.transforms.Compose([
                self.params_dict['augment'],
                Stack(roll=True),                       # roll the RGB-channel
                ToTorchFormatTensor(need_scale=False),  # keep [0, 255]
                normalize,
            ])
            ),
    batch_size=self.args.batch_size, shuffle=True,
    num_workers=self.args.workers, pin_memory=True, drop_last=True)

# training/testing
for epoch in range(num_epochs):
    for idx, (input, label) in enumerate(data_loader):
        ......
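The Dataset above is project-specific; DataLoader only requires the object it receives to implement __getitem__ and __len__. A minimal sketch of such a class (ToyDataset and its contents are made up purely for illustration):

import torch
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    '''Hypothetical dataset: returns (sample, label) pairs indexed by an integer.'''
    def __init__(self, num_items=100, transform=None):
        self.num_items = num_items
        self.transform = transform

    def __getitem__(self, index):
        # DataLoader workers call dataset[i] for every sampled index i
        sample = torch.randn(3, 224, 224)   # stand-in for a decoded image
        label = index % 10                  # stand-in for a class label
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        # used by the samplers to know how many indices to generate
        return self.num_items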
The constructor signature of DataLoader is:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
class DataLoader(object):

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        ...
        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with "shuffle"')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))
        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
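The branching on shuffle / sampler / batch_sampler above can be checked with a small sketch (TensorDataset is used purely as a stand-in dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(10).float(), torch.arange(10))

# shuffle=False -> SequentialSampler wrapped in a BatchSampler
loader_seq = DataLoader(data, batch_size=4, shuffle=False)
print(type(loader_seq.sampler).__name__)        # SequentialSampler
print(type(loader_seq.batch_sampler).__name__)  # BatchSampler

# shuffle=True -> RandomSampler wrapped in a BatchSampler
loader_rnd = DataLoader(data, batch_size=4, shuffle=True)
print(type(loader_rnd.sampler).__name__)        # RandomSampler

# passing both sampler=... and shuffle=True raises ValueError, as in __init__ above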
Analysis: when no sampler is passed in, __init__ picks SequentialSampler or RandomSampler according to shuffle, then wraps it in a BatchSampler. Their source is as follows:
class SequentialSampler(Sampler):
    '''Samples elements sequentially, always in the same order.'''

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


class RandomSampler(Sampler):
    '''Samples elements randomly.'''

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        ...

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        return iter(torch.randperm(n).tolist())

    def __len__(self):
        return self.num_samples


class BatchSampler(Sampler):
    '''Wraps another sampler to yield a mini-batch of indices.'''

    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        '''Generate the indices of one mini-batch at a time.'''
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
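A quick sketch of how BatchSampler groups the indices produced by a sampler, with the expected output and the matching __len__ values shown as comments:

from torch.utils.data import SequentialSampler, BatchSampler

indices = list(range(10))
sampler = SequentialSampler(indices)  # yields 0, 1, ..., 9

print(list(BatchSampler(sampler, batch_size=4, drop_last=False)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]   -> len = (10 + 4 - 1) // 4 = 3

print(list(BatchSampler(sampler, batch_size=4, drop_last=True)))
# [[0, 1, 2, 3], [4, 5, 6, 7]]           -> len = 10 // 4 = 2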
Notes:
1. An iterator is always an iterable, but an iterable is not necessarily an iterator;
2. An iterable is usually traversed with the for ... in ... syntax;
3. Inside the DataLoader class, _DataLoaderIter is instantiated as follows:

    def __iter__(self):
        return _DataLoaderIter(self)

Note that the loader argument of _DataLoaderIter.__init__(self, loader) is the corresponding DataLoader instance.
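The same iterable/iterator split can be shown with a minimal, hypothetical pair of classes: the container's __iter__ hands back a fresh iterator object each time, exactly as DataLoader.__iter__ returns a new _DataLoaderIter for every epoch. The real _DataLoaderIter follows right after.

class MyLoader(object):             # iterable: has __iter__, no __next__
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return _MyLoaderIter(self)  # a new iterator for every for-loop / epoch


class _MyLoaderIter(object):        # iterator: has both __iter__ and __next__
    def __init__(self, loader):
        self.loader = loader        # 'loader' is the MyLoader instance
        self.idx = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.idx >= len(self.loader.data):
            raise StopIteration
        item = self.loader.data[self.idx]
        self.idx += 1
        return item


for x in MyLoader([1, 2, 3]):       # for...in calls iter() once, then next() repeatedly
    print(x)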
class _DataLoaderIter(object):
    '''Iterates once over the DataLoader's dataset, as specified by the sampler.'''

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.sample_iter = iter(self.batch_sampler)
        base_seed = torch.LongTensor(1).random_().item()
        ...
        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()  # event used to signal shutdown

            self.index_queues = []
            self.workers = []
            # spawn the worker processes
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()  # each worker process gets its own index_queue
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_utils.worker._worker_loop,  # the core loop, discussed below
                    args=(self.dataset, index_queue,
                          self.worker_result_queue, self.done_event,
                          self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                self.data_queue = queue.Queue()  # global data queue
                ...
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue
            ...
            # prime the prefetch loop:
            # at construction time, 2 * num_workers (batch_idx, sample_indices) pairs
            # are already pushed onto the index queues
            for _ in range(2 * self.num_workers):
                self._put_indices()  # the core scheduling step
In _DataLoaderIter, look first at self.workers: these are the num_workers subprocesses that actually load the data, each running the _worker_loop function.
def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
    '''
    dataset     : self.dataset
    index_queue : this worker's multiprocessing.Queue() of batch indices
    data_queue  : self.worker_result_queue, shared by all workers
    done_event  : shutdown event shared with the main process
    collate_fn  : self.collate_fn, merges a list of samples into a batch
    seed        : base_seed + i, the per-worker random seed
    init_fn     : self.worker_init_fn, optional worker initialization hook
    worker_id   : i, the index of this worker
    '''
    collate._use_shared_memory = True
    ...
    torch.set_num_threads(1)
    random.seed(seed)
    torch.manual_seed(seed)
    data_queue.cancel_join_thread()
    init_exception = None

    # initialize the worker
    if init_fn is not None:
        init_fn(worker_id)
    watchdog = ManagerWatchdog()

    while watchdog.is_alive():
        try:
            r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue
        if r is None:
            assert done_event.is_set()
            return
        elif done_event.is_set():
            continue
        # note: idx is the index of the whole batch; batch_indices are the sample indices in it
        idx, batch_indices = r
        try:
            # read the (data, label) pairs at the sampled indices from the dataset
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
            del samples
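Stripped of the watchdog, done_event, and exception handling, the producer pattern of _worker_loop reduces to a toy sketch (toy_worker_loop and the plain-list "dataset" below are stand-ins for illustration, not the real implementation):

import multiprocessing

def toy_worker_loop(dataset, index_queue, data_queue, collate_fn, worker_id):
    # mirror of _worker_loop: block on index_queue, fetch samples, push the batch
    while True:
        r = index_queue.get()
        if r is None:                      # poison pill -> shut down
            break
        idx, batch_indices = r             # (batch index, sample indices)
        samples = collate_fn([dataset[i] for i in batch_indices])
        data_queue.put((idx, samples))

if __name__ == '__main__':
    dataset = list(range(100))             # stand-in dataset: dataset[i] == i
    index_queue = multiprocessing.Queue()
    data_queue = multiprocessing.Queue()
    w = multiprocessing.Process(
        target=toy_worker_loop,
        args=(dataset, index_queue, data_queue, list, 0))
    w.daemon = True
    w.start()

    index_queue.put((0, [3, 1, 4, 1]))     # what _put_indices would send
    print(data_queue.get())                # (0, [3, 1, 4, 1])
    index_queue.put(None)
    w.join()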
def _put_indices(self):
    assert self.batches_outstanding < 2 * self.num_workers
    # indices are pulled lazily, one batch at a time, instead of materializing
    # everything in memory up front
    indices = next(self.sample_iter, None)
    if indices is None:
        return
    self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
    # round-robin: keep worker_queue_idx in the range 0 .. num_workers - 1
    self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
    self.batches_outstanding += 1
    # send_idx advances together with self.sample_iter; it tags every batch
    # sent to the workers with a unique, monotonically increasing index
    self.send_idx += 1
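As a sanity check of this round-robin bookkeeping, a pure-Python simulation (queues are replaced by lists, and put_indices is a hypothetical stand-in for the method above):

num_workers = 3
index_queues = [[] for _ in range(num_workers)]                # one "queue" per worker
sample_iter = iter([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])   # batches of sample indices

send_idx = 0
worker_queue_idx = 0
batches_outstanding = 0

def put_indices():
    global send_idx, worker_queue_idx, batches_outstanding
    assert batches_outstanding < 2 * num_workers
    indices = next(sample_iter, None)
    if indices is None:
        return
    index_queues[worker_queue_idx].append((send_idx, indices))
    worker_queue_idx = (worker_queue_idx + 1) % num_workers    # round-robin
    batches_outstanding += 1
    send_idx += 1

for _ in range(2 * num_workers):   # the prefetch loop in __init__
    put_indices()

print(index_queues)
# [[(0, [0, 1]), (3, [6, 7])], [(1, [2, 3]), (4, [8, 9])], [(2, [4, 5])]]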
References:
https://blog.csdn.net/g11d111/article/details/81504637