reduce_scatter
Applicability
| Product | Supported |
|---|---|
| | √ |
| | √ |
| | ☓ |
| | √ |
| | √ |
Description
Operation API of the ReduceScatter operator. It evenly divides the input data of every rank in a communicator into rank_size parts and reduces (sum, prod, max, or min) each part across all ranks. Each reduced part is then written to the output buffer of one rank, selected by rank ID: rank i receives the reduction of the i-th part.
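To make the data movement concrete, the following single-process sketch simulates ReduceScatter semantics on plain Python lists. It is an illustration only, not HCCL's implementation: the function name `simulate_reduce_scatter` is hypothetical, each rank's input is modeled as a flat list (standing in for the first dimension of the tensor), and the reduction names mirror the `reduction` parameter.

```python
import operator
from functools import reduce

# Reduction ops keyed by the names accepted by the `reduction` parameter.
REDUCE_OPS = {"sum": operator.add, "prod": operator.mul, "max": max, "min": min}


def simulate_reduce_scatter(inputs, reduction):
    """Hypothetical simulation of ReduceScatter on one host.

    inputs: one list per rank, all of the same length; that length must be an
    integer multiple of the rank size (len(inputs)).
    Returns one list per rank: rank i gets the element-wise reduction of the
    i-th part of every rank's input.
    """
    rank_size = len(inputs)
    length = len(inputs[0])
    if any(len(x) != length for x in inputs):
        raise ValueError("all ranks must supply inputs of the same length")
    if length % rank_size != 0:
        raise ValueError("input length must be an integer multiple of rank_size")
    part = length // rank_size
    op = REDUCE_OPS[reduction]
    outputs = []
    for rank in range(rank_size):
        lo, hi = rank * part, (rank + 1) * part
        # Take this rank's part from every rank, then reduce column by column.
        parts = [x[lo:hi] for x in inputs]
        outputs.append([reduce(op, column) for column in zip(*parts)])
    return outputs
```

For two ranks holding `[1, 2, 3, 4]` and `[10, 20, 30, 40]`, a `"sum"` reduction gives rank 0 `[11, 22]` and rank 1 `[33, 44]`.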

Function Prototype
```python
def reduce_scatter(tensor, reduction, rank_size, group="hccl_world_group", fusion=0, fusion_id=-1)
```
Parameters
| Parameter | Input/Output | Description |
|---|---|---|
| tensor | Input | TensorFlow tensor. For the Atlas 300I Duo inference card, the supported data types are int8, int16, int32, float16, and float32. The size of the first dimension of the tensor must be an integer multiple of rank_size. |
| reduction | Input | String type. Reduction operation type: max, min, prod, or sum. Note: For the Atlas 300I Duo inference card, the prod, max, and min operations do not support the int16 data type in the current version. |
| rank_size | Input | Int type. Number of devices in the group. Maximum value: 32768. |
| group | Input | String type, at most 128 bytes including the terminating character. Group name: either hccl_world_group or a user-defined group name. Default: hccl_world_group. |
| fusion | Input | Int type. ReduceScatter operator fusion flag. Default: 0. If set to 2, ReduceScatter operators with the same fusion_id are fused during network compilation. |
| fusion_id | Input | Int type. ReduceScatter operator fusion ID. Default: -1. If fusion is set to 2, ReduceScatter operators with the same fusion_id are fused during network compilation. |
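The limits in the parameter table can be checked before building the graph. The sketch below is a hypothetical helper (`check_reduce_scatter_params` is not part of `hccl_ops`); it takes the input tensor's shape as a plain tuple, and the minimum rank_size of 1 is an assumption, since only the maximum is documented.

```python
# Limits taken from the parameter table; the minimum rank_size of 1 is assumed.
MAX_RANK_SIZE = 32768
MAX_GROUP_BYTES = 128  # includes the terminating character
VALID_REDUCTIONS = {"sum", "prod", "max", "min"}


def check_reduce_scatter_params(shape, reduction, rank_size, group="hccl_world_group"):
    """Hypothetical helper: return a list of problems (empty if the arguments look valid)."""
    problems = []
    if reduction not in VALID_REDUCTIONS:
        problems.append(f"unsupported reduction: {reduction!r}")
    if not 1 <= rank_size <= MAX_RANK_SIZE:
        problems.append(f"rank_size must be in [1, {MAX_RANK_SIZE}]")
    elif not shape or shape[0] % rank_size != 0:
        # Only checked when rank_size is valid, so the modulo cannot divide by zero.
        problems.append("first dimension must be an integer multiple of rank_size")
    # +1 accounts for the terminating character counted in the 128-byte limit.
    if len(group.encode("utf-8")) + 1 > MAX_GROUP_BYTES:
        problems.append("group name exceeds 128 bytes including the terminating character")
    return problems
```

For example, `check_reduce_scatter_params((2, 3), "sum", 2)` returns an empty list, while a `(3, 3)` shape with `rank_size = 2` is rejected.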
Returns
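The result tensor: on each rank, the reduced part of the input that corresponds to that rank's rank ID. Because the input is divided along its first dimension, the first dimension of the result is the input's first dimension divided by rank_size.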
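Assuming, as the first-dimension constraint on `tensor` suggests, that the input is divided along its first dimension, the per-rank result shape can be computed as below. `reduce_scatter_output_shape` is a hypothetical helper for illustration, not part of `hccl_ops`.

```python
def reduce_scatter_output_shape(input_shape, rank_size):
    """Hypothetical helper: per-rank result shape, assuming the split is along dim 0."""
    if rank_size < 1:
        raise ValueError("rank_size must be positive")
    first, *rest = input_shape
    if first % rank_size != 0:
        raise ValueError("first dimension must be an integer multiple of rank_size")
    # Only the first dimension shrinks; the remaining dimensions are unchanged.
    return (first // rank_size, *rest)
```

Under this assumption, the `(2, 3)` input with `rank_size = 2` used in the Example section yields a `(1, 3)` result on each rank.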
Constraints
- The calling rank must belong to the group specified by the group argument. Otherwise, the call fails.
- The input tensor size must be less than or equal to 8 GB.
- ReduceScatter operator fusion supports only the sum reduction.
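The size and fusion constraints can likewise be pre-checked. This is a hypothetical sketch (`check_reduce_scatter_constraints` is not part of `hccl_ops`) under two stated assumptions: "8 GB" is read as 8 × 1024³ bytes, and fusion is treated as enabled whenever `fusion` is nonzero. Element sizes cover only the data types listed for `tensor`.

```python
# Assumptions: 8 GB = 8 * 1024**3 bytes; fusion is enabled whenever fusion != 0.
MAX_INPUT_BYTES = 8 * 1024**3
DTYPE_BYTES = {"int8": 1, "int16": 2, "int32": 4, "float16": 2, "float32": 4}


def check_reduce_scatter_constraints(shape, dtype, reduction, fusion=0):
    """Hypothetical helper: return a list of constraint violations (empty if none)."""
    problems = []
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    size = num_elements * DTYPE_BYTES[dtype]
    if size > MAX_INPUT_BYTES:
        problems.append(f"input is {size} bytes; the limit is 8 GB")
    if fusion != 0 and reduction != "sum":
        problems.append("operator fusion supports only the sum reduction")
    return problems
```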
Example
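```python
import tensorflow as tf
from npu_bridge.hccl import hccl_ops

# The first dimension (2) must be an integer multiple of rank_size (2).
tensor = tf.random_uniform((2, 3), minval=1, maxval=10, dtype=tf.float32)
rank_size = 2
# Sum-reduce across the ranks; each rank receives one of the rank_size reduced parts.
result = hccl_ops.reduce_scatter(tensor, "sum", rank_size)
```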