npu.distribute.shard_and_rebatch_dataset
产品支持情况
产品  | 
是否支持  | 
|---|---|
√  | 
|
x  | 
|
x  | 
|
x  | 
|
√  | 
|
x  | 
功能说明
用于NPU分布式部署场景下,不同worker上数据集分片及batch大小调整。
函数原型
1 | npu.distribute.shard_and_rebatch_dataset(dataset, global_bs)  | 
参数说明
参数名  | 
输入/输出  | 
描述  | 
|---|---|---|
dataset  | 
输入  | 
TensorFlow的Dataset类型。 需要进行切分的数据集。  | 
global_bs  | 
输入  | 
全局batch的大小。  | 
返回值
返回一个2个元素的tuple对象,第一个元素为切分后的Dataset,第二个元素为每个worker应当处理的实际batch大小。
调用示例
1 2  | import npu_device as npu dataset, batch_size = npu.distribute.shard_and_rebatch_dataset(dataset, batch_size)  |