NPUBroadcastGlobalVariablesCallback构造函数
产品支持情况
产品  | 
是否支持  | 
|---|---|
√  | 
|
x  | 
|
x  | 
|
√  | 
|
√  | 
|
x  | 
 针对
功能说明
Keras场景下对变量进行广播,使得在分布式场景下每个device上的变量初始值保持一致。
函数原型
1 | def __init__(self, root_rank)  | 
参数说明
参数名  | 
输入/输出  | 
描述  | 
|---|---|---|
root_rank  | 
输入  | 
标识将哪个device的变量广播到其他的device上。  | 
返回值
无
调用示例
迁移前:
1 2 3 4 5 6  | callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)] import numpy as np data = np.random.random((1000, 100)) labels = np random.randint(2, size=(1000,1)) model.fit(data, labels, epochs=10, batch_size=32, callbacks=callbacks)  | 
迁移后:
1 2 3 4 5 6 7  | from npu_bridge.npu_init import * callbacks = [NPUBroadcastGlobalVariablesCallback(0)] import numpy as np data = np.random.random((1000, 100)) labels = np random.randint(2, size=(1000,1)) model.fit(data, labels, epochs=10, batch_size=32, callbacks=callbacks)  |