get_initializer

功能描述

获取tensorflow.data.Iterator的初始化算子(Operation),该算子需要通过使用sess.run()来初始化Iterator。

函数原型

1
2
from mx_rec.util.initialize import ConfigInitializer
ConfigInitializer.get_instance().train_params_config.get_initializer(is_training)

参数说明

参数名

类型

可选/必选

说明

is_training

bool

必选

是否为训练模式。

  • True:训练(train)模式。
  • False:评估(eval)模式。

返回值说明

使用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
import tensorflow as tf
from mx_rec.util.initialize import ConfigInitializer
from mx_rec.graph.modifier import modify_graph_and_start_emb_cache
# train,需要开启自动改图
# train模式下,自动改图需要在计算梯度之后
计算梯度........
modify_graph_and_start_emb_cache(dump_graph=True)
with tf.compat.v1.Session() as sess:
    # 请确保已调用过modify_graph_and_start_emb_cache()接口
    initializer = ConfigInitializer.get_instance().train_params_config.get_initializer(True)
    sess.run(initializer)

参考资源

接口调用流程及示例,参见自动改图