npu.distribute.shard_and_rebatch_dataset

Description

Shards the dataset across workers and converts the global batch size into a per-worker batch size for distributed NPU training.

Prototype

npu.distribute.shard_and_rebatch_dataset(dataset, global_bs)

Parameters

Parameter    Input/Output    Description
dataset      Input           TensorFlow dataset type. Dataset to be sharded.
global_bs    Input           Global batch size.

Returns

A tuple of two elements: the sharded dataset and the per-worker mini-batch size, in that order.
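
The relationship between the inputs and the returned tuple can be illustrated with a small pure-Python sketch. This is not the npu_device implementation. The helper shard_and_rebatch, its num_workers and worker_id arguments, the round-robin assignment of elements, and the divisibility check are all assumptions made for illustration; the real function takes only dataset and global_bs and may behave differently.

```python
def shard_and_rebatch(elements, global_bs, num_workers, worker_id):
    """Hypothetical stand-in that mimics the idea on a plain Python list."""
    # Assumption: the global batch size must split evenly across workers.
    if global_bs % num_workers != 0:
        raise ValueError("global_bs must be divisible by num_workers")
    # Assumption: worker i keeps elements i, i + num_workers, i + 2 * num_workers, ...
    shard = elements[worker_id::num_workers]
    # Each worker processes an equal share of the global batch.
    per_worker_bs = global_bs // num_workers
    return shard, per_worker_bs


samples = list(range(8))
shard, per_worker_bs = shard_and_rebatch(samples, global_bs=4, num_workers=2, worker_id=1)
print(shard, per_worker_bs)  # [1, 3, 5, 7] 2
```

In this sketch, two workers with a global batch size of 4 each receive half of the samples and a per-worker batch size of 2.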

Example

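import npu_device as npu
# On input, `dataset` is an existing TensorFlow dataset and `batch_size` is the global batch size.
# On return, `dataset` is this worker's shard and `batch_size` is the per-worker batch size.
dataset, batch_size = npu.distribute.shard_and_rebatch_dataset(dataset, batch_size)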