获取初始化iterator的TensorFlow算子(tf.Operation),该算子需要通过sess.run()使用。
from mx_rec.util.initialize import ConfigInitializer ConfigInitializer.get_instance().train_params_config.get_initializer(is_training)
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
is_training |
bool |
必选 |
是否为训练模式。
|
import tensorflow as tf from mx_rec.util.initialize import ConfigInitializer from mx_rec.graph.modifier import modify_graph_and_start_emb_cache # train,需要开启自动改图 # train模式下,自动改图需要在计算梯度之后 计算梯度........ modify_graph_and_start_emb_cache(dump_graph=True) with tf.compat.v1.Session() as sess: # 请确保已调用过modify_graph_and_start_emb_cache()接口 initializer = ConfigInitializer.get_instance().train_params_config.get_initializer(True) sess.run(initializer)
接口调用流程及示例,参见自动改图。