使用MultiEpochsDataLoader
此场景需增加源码并使能,请用户根据性能情况判断是否需要替换,若性能差异较大且在数据处理时存在耗时问题推荐使用此方案。
使用场景
DataLoader在每个epoch开始的时候都会重新创建一次,因此每个epoch开始所有的worker会重新开始prefetching过程,就会引起数据读取过程的耗时。可以通过使用MultiEpochsDataLoader以减少重新创建epoch造成的数据读取耗时。
操作步骤
MultiEpochsDataLoader相关源码如下,使用该部分代码替换原有dataloader,并在代码中调用此方法。
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._DataLoader__initialized = False
self.batch_sampler = _RepeatSampler(self.batch_sampler)
self._DataLoader__initialized = True
self.iterator = super().__iter__()
def __len__(self):
return len(self.batch_sampler.sampler)
def __iter__(self):
for i in range(len(self)):
yield next(self.iterator)
class _RepeatSampler(object):
""" Sampler that repeats forever.
Args:
sampler (Sampler)
"""
def __init__(self, sampler):
self.sampler = sampler
def __iter__(self):
while True:
yield from iter(self.sampler)
父主题: 优化数据处理