获取tensorflow.data.Iterator的初始化算子(Operation),该算子需要通过使用sess.run()来初始化Iterator。
1 2 | from mx_rec.util.initialize import ConfigInitializer ConfigInitializer.get_instance().train_params_config.get_initializer(is_training) |
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
is_training |
bool |
必选 |
是否为训练模式。
|
1 2 3 4 5 6 7 8 9 10 11 | import tensorflow as tf from mx_rec.util.initialize import ConfigInitializer from mx_rec.graph.modifier import modify_graph_and_start_emb_cache # train,需要开启自动改图 # train模式下,自动改图需要在计算梯度之后 计算梯度........ modify_graph_and_start_emb_cache(dump_graph=True) with tf.compat.v1.Session() as sess: # 请确保已调用过modify_graph_and_start_emb_cache()接口 initializer = ConfigInitializer.get_instance().train_params_config.get_initializer(True) sess.run(initializer) |
接口调用流程及示例,参见自动改图。