定制WarmStart

使用介绍

如需在Rec SDK训练框架使用定制WarmStart功能,需要在模型代码中创建NPUEstimator对象之前创建一个warmstart的配置tf.estimator.WarmStartSettings,然后将这个配置传给NPUEstimator中warm_start_from参数即可。

示例代码

定制WarmStart的使用示例如下:

示例1:模型从warm_start路径中加载稀疏表user_table。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
import tensorflow as tf
from tf_adapter import NPUEstimator
 
warm_settings=tf.estimator.WarmStartSettings(ckpt_to_initialize_from="./warm_start",vars_to_warm_start ="user_table", var_name_to_prev_var_name=None)
est = NPUEstimator(
        model_fn=get_model_fn(create_fs_params, cfg, access_and_evict),
        params=params,
        model_dir=params.model_dir,
        config=run_config,
        warm_start_from=warm_settings 
)
表1 tf.estimator.WarmStartSettings参数说明

参数

参数类型

说明

ckpt_to_initialize_from

str(path)

指定从哪个检查点开始初始化。

vars_to_warm_start

str/正则表达式/list[str]/list[variables]

指定从哪些变量开始初始化。

dense层变量的指定方式保持与tf原生一致;embedding参数支持:正则表达式、str(表名)和list(表名list)。

var_name_to_vocab_info

dict

指定词汇表信息,用于恢复嵌入矩阵。

var_name_to_prev_var_name

dict

用于存储变量名到warmstart路径中变量名的映射关系。

说明:

embedding表名目前不支持名称映射。

示例2:模型从warm_start_1路径中加载所有参数,然后模型从warm_star_2t路径中加载embedding表user_table、item_table,替代已经加载了的warm_start1路径中的稀疏表结果。模型从warm_start_3路径中加载mlp_layer_w参数,替代warm_start_1的加载结果。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
import tensorflow as tf
from tf_adapter import NPUEstimator
 
ckpt_to_initialize_from_list = ["./warm_start_1", "./warm_start_2", "./warm_start_3"]
vars_to_warm_start_list=[".*",  ["user_table", "item_table"], "mlp_layer_w" ]
var_name_to_prev_var_name_list = [{}, {}, {}]
warm_settings=tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=ckpt_to_initialize_from_list,
        vars_to_warm_start = vars_to_warm_start_list,
        var_name_to_prev_var_name=var_name_to_prev_var_name_list )
 
 est = NPUEstimator(
        model_fn=get_model_fn(create_fs_params, config, access_and_evict),
        params=params,
        model_dir=params.model_dir,
        config=run_config
        warm_start_from=warm_settings
    )
表2 支持多路径warmstart功能中tf.estimator.WarmStartSettings参数说明

参数

参数类型

说明

ckpt_to_initialize_from

List(str(path))

指定从哪个检查点开始初始化

vars_to_warm_start

List(str/正则表达式/list[str]/list[variables])

指定从哪些变量开始初始化。

dense层变量的指定方式保持与tf原生一致;embedding参数支持:正则表达式、str(表名)和list(表名list)。

var_name_to_vocab_info

List(dict)

指定词汇表信息,用于恢复嵌入矩阵。

var_name_to_prev_var_name

List(dict)

用于存储变量名到warmstart路径中变量名的映射关系。

embedding表名目前不支持名称映射。