昇腾社区首页
中文
注册

NPUBroadcastGlobalVariablesCallback构造函数

产品支持情况

产品

是否支持

Atlas A3 训练系列产品/Atlas A3 推理系列产品

Atlas A2 训练系列产品

Atlas 800I A2 推理产品/A200I A2 Box 异构组件

x

Atlas 200I/500 A2 推理产品

x

Atlas 推理系列产品

x

Atlas 训练系列产品

Atlas 200/300/500 推理产品

x

针对Atlas A3 推理系列产品,仅支持在线推理特性。

功能说明

Keras场景下对变量进行广播,使得在分布式场景下每个device上的变量初始值保持一致。

函数原型

1
def __init__(self, root_rank)

参数说明

参数名

输入/输出

描述

root_rank

输入

标识将哪个device的变量广播到其他的device上。

返回值

调用示例

迁移前:

1
2
3
4
5
6
callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)]

import numpy as np
data = np.random.random((1000, 100))
labels = np random.randint(2, size=(1000,1))
model.fit(data, labels, epochs=10, batch_size=32, callbacks=callbacks)

迁移后:

1
2
3
4
5
6
7
from npu_bridge.npu_init import *
callbacks = [NPUBroadcastGlobalVariablesCallback(0)]

import numpy as np
data = np.random.random((1000, 100))
labels = np random.randint(2, size=(1000,1))
model.fit(data, labels, epochs=10, batch_size=32, callbacks=callbacks)