npu.distribute.shard_and_rebatch_dataset
功能说明
用于NPU分布式部署场景下,不同worker上数据集分片及batch大小调整。
函数原型
npu.distribute.shard_and_rebatch_dataset(dataset, global_bs)
参数说明
| 参数名 | 输入/输出 | 描述 | 
|---|---|---|
| dataset | 输入 | TensorFlow的Dataset类型。 需要进行切分的数据集。 | 
| global_bs | 输入 | 全局batch的大小。 | 
返回值
返回一个2个元素的tuple对象,第一个元素为切分后的Dataset,第二个元素为每个worker应当处理的实际batch大小。
调用示例
| 1 2 | import npu_device as npu dataset, batch_size = npu.distribute.shard_and_rebatch_dataset(dataset, batch_size) |