使用数据预取
此场景需要增加源码并使能该功能,请用户根据实际性能情况判断是否需要替换。若替换前后性能差异较大,且耗时问题出现在数据处理环节,推荐使用此方案。
使用场景
在使用数据量较大的数据集时,可使用数据预取功能:在训练过程中,每次将数据返回给网络的同时,预先读取下一次迭代所需的数据。
操作步骤
class data_prefetcher:
    """Prefetch the next batch onto the GPU while the current batch is in use.

    The wrapped loader must yield ``(input, target)`` pairs of tensors. Each
    pair is copied to the GPU on a dedicated side stream, and the input is
    cast to float and normalized with ``self.mean`` / ``self.std``.

    NOTE: ``non_blocking=True`` copies can only overlap with compute when the
    host tensors live in pinned memory (e.g. ``DataLoader(pin_memory=True)``).
    """

    def __init__(self, loader):
        """Wrap ``loader`` and immediately start loading the first batch.

        Args:
            loader: An iterable yielding ``(input, target)`` tensor pairs.
        """
        self.loader = iter(loader)
        # Side stream on which the host-to-device copies and normalization run.
        self.stream = torch.cuda.Stream()
        # Per-channel statistics scaled to the 0-255 pixel range and shaped
        # (1, 3, 1, 1) for broadcasting.
        # NOTE(review): assumes inputs are (N, 3, H, W) with 0-255 values; the
        # constants look like the usual ImageNet mean/std -- confirm for your data.
        self.mean = torch.tensor(
            [0.485 * 255, 0.456 * 255, 0.406 * 255]
        ).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor(
            [0.229 * 255, 0.224 * 255, 0.225 * 255]
        ).cuda().view(1, 3, 1, 1)
        # With Amp, it isn't necessary to manually convert mean/std to half.
        self.preload()

    def preload(self):
        """Fetch the next batch and enqueue its GPU copy on the side stream.

        When the loader is exhausted, ``next_input`` and ``next_target`` are
        both set to ``None``.
        """
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return

        # If record_stream() (see next()) doesn't work, an alternative is to
        # allocate the device tensors on the main stream with
        # torch.empty_like(..., device='cuda'), make the side stream wait on
        # the main stream, and fill them here with copy_(..., non_blocking=True).
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        """Return the prefetched batch and start prefetching the following one.

        Returns:
            A tuple ``(input, target)``; both are ``None`` once the loader
            has been exhausted.
        """
        # The main stream must not consume the batch before the side stream
        # has finished copying and normalizing it.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch_input = self.next_input
        batch_target = self.next_target
        # The tensors were allocated on the side stream; record their use on
        # the current stream so their memory is not reused too early.
        if batch_input is not None:
            batch_input.record_stream(torch.cuda.current_stream())
        if batch_target is not None:
            batch_target.record_stream(torch.cuda.current_stream())
        self.preload()
        return batch_input, batch_target
修改前:
train_datasets = CustomData()
train_dataloaders = torch.utils.data.DataLoader(train_datasets, shuffle=True)
# Plain iteration: each batch is fetched only when the loop asks for it.
for data in train_dataloaders:
    current_iter += 1
    inputs, labels = data
    # ... training step using inputs / labels goes here ...
修改后:
train_datasets = CustomData()
train_dataloaders = torch.utils.data.DataLoader(train_datasets, shuffle=True)
# Enable data prefetching: wrap the loader and pull the first batch.
prefetcher = data_prefetcher(train_dataloaders)
inputs, labels = prefetcher.next()
# prefetcher.next() returns (None, None) once the loader is exhausted.
while inputs is not None:
    current_iter += 1
    # ... training step using inputs / labels goes here ...
    # Fetch the batch for the next iteration.
    inputs, labels = prefetcher.next()
父主题: 优化数据处理