如需在Rec SDK训练框架使用定制WarmStart功能,需要在模型代码中创建NPUEstimator对象之前创建一个warmstart的配置tf.estimator.WarmStartSettings,然后将这个配置传给NPUEstimator中warm_start_from参数即可。
定制WarmStart的使用示例如下:
1 2 3 4 5 6 7 8 9 10 11 | import tensorflow as tf from tf_adapter import NPUEstimator warm_settings=tf.estimator.WarmStartSettings(ckpt_to_initialize_from="./warm_start",vars_to_warm_start ="user_table", var_name_to_prev_var_name=None) est = NPUEstimator( model_fn=get_model_fn(create_fs_params, cfg, access_and_evict), params=params, model_dir=params.model_dir, config=run_config, warm_start_from=warm_settings ) |
参数 |
参数类型 |
说明 |
---|---|---|
ckpt_to_initialize_from |
str(path) |
指定从哪个检查点开始初始化。 |
vars_to_warm_start |
str/正则表达式/list[str]/list[variables] |
指定从哪些变量开始初始化。 dense层变量的指定方式保持与tf原生一致;embedding参数支持:正则表达式、str(表名)和list(表名list)。 |
var_name_to_vocab_info |
dict |
指定词汇表信息,用于恢复嵌入矩阵。 |
var_name_to_prev_var_name |
dict |
用于存储变量名到warmstart路径中变量名的映射关系。 说明:
embedding表名目前不支持名称映射。 |
示例2:模型从warm_start_1路径中加载所有参数,然后模型从warm_star_2t路径中加载embedding表user_table、item_table,替代已经加载了的warm_start1路径中的稀疏表结果。模型从warm_start_3路径中加载mlp_layer_w参数,替代warm_start_1的加载结果。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | import tensorflow as tf from tf_adapter import NPUEstimator ckpt_to_initialize_from_list = ["./warm_start_1", "./warm_start_2", "./warm_start_3"] vars_to_warm_start_list=[".*", ["user_table", "item_table"], "mlp_layer_w" ] var_name_to_prev_var_name_list = [{}, {}, {}] warm_settings=tf.estimator.WarmStartSettings( ckpt_to_initialize_from=ckpt_to_initialize_from_list, vars_to_warm_start = vars_to_warm_start_list, var_name_to_prev_var_name=var_name_to_prev_var_name_list ) est = NPUEstimator( model_fn=get_model_fn(create_fs_params, config, access_and_evict), params=params, model_dir=params.model_dir, config=run_config warm_start_from=warm_settings ) |
参数 |
参数类型 |
说明 |
---|---|---|
ckpt_to_initialize_from |
List(str(path)) |
指定从哪个检查点开始初始化 |
vars_to_warm_start |
List(str/正则表达式/list[str]/list[variables]) |
指定从哪些变量开始初始化。 dense层变量的指定方式保持与tf原生一致;embedding参数支持:正则表达式、str(表名)和list(表名list)。 |
var_name_to_vocab_info |
List(dict) |
指定词汇表信息,用于恢复嵌入矩阵。 |
var_name_to_prev_var_name |
List(dict) |
用于存储变量名到warmstart路径中变量名的映射关系。 embedding表名目前不支持名称映射。 |