npu.distribute.broadcast
产品支持情况
产品  | 
是否支持  | 
|---|---|
√  | 
|
x  | 
|
x  | 
|
x  | 
|
√  | 
|
x  | 
功能说明
用于NPU分布式部署场景下,worker间的变量同步。
函数原型
1 | npu.distribute.broadcast(values, root_rank, fusion=2, fusion_id=0, group="hccl_world_group")  | 
参数说明
参数名  | 
输入/输出  | 
描述  | 
|---|---|---|
values  | 
输入  | 
单个TensorFlow的Variable或者Variable的集合。 针对 针对 针对  | 
root_rank  | 
输入  | 
int类型。 作为root节点的rank_id,该id是group内的rank id。  | 
fusion  | 
输入  | 
int类型。 broadcast算子融合标识,支持以下取值: 
  | 
fusion_id  | 
输入  | 
int类型。 broadcast算子的融合id。 当“fusion”取值为“2”时,网络编译时会对相同fusion_id的broadcast算子进行融合。  | 
group  | 
输入  | 
String类型,最大长度为128字节,含结束符。 group名称,可以为用户自定义group或者"hccl_world_group"。  | 
返回值
无。
调用示例
将0卡上的变量广播到其他卡:
1 2 3 4 5 6  | # rank_id = 0 rank_size = 8 import npu_device as npu x = tf.Variable(tf.random.normal(shape=())) print("before broadcast", x) npu.distribute.broadcast(x, root_rank=0) print("after_broadcast", x)  | 
广播前:

广播后:
