关键操作步骤参考如下：
import os

import tensorflow as tf

# Select the NPU adapter module matching the installed TensorFlow major version.
if tf.__version__.startswith("1"):
    from npu_bridge.npu_init import NPURunConfig, NPUEstimator
else:
    from npu_device.compat.v1.npu_init import NPURunConfig, NPUEstimator

# Construct the run configuration and the estimator. The `...` arguments are
# placeholders; fill them in as described in the Estimator migration chapter
# ("Run configuration" and "Creating an Estimator object").
run_config = NPURunConfig(...)
est = NPUEstimator(...)


# export_saved_model is normally called after train() or train_and_evaluate().
def _serving_input_fn():
    """Build the serving input receiver used when exporting the SavedModel.

    Adjust the feature names, shapes and dtypes to your own model; the table
    below reproduces the inputs of the "little demo" Estimator model.

    Returns:
        A ``tf.estimator.export.ServingInputReceiver`` that uses the same
        dict of placeholders both as features and as receiver tensors.
    """
    # (feature name, shape, dtype) for each serving placeholder.
    input_specs = (
        ("user_ids", (None, 32), tf.int64),
        ("item_ids", (None, 8), tf.int64),
        ("label_0", (None,), tf.float32),
        ("label_1", (None,), tf.float32),
    )
    placeholders = {
        name: tf.compat.v1.placeholder(shape=shape, dtype=dtype, name=name)
        for name, shape, dtype in input_specs
    }
    return tf.estimator.export.ServingInputReceiver(
        features=placeholders, receiver_tensors=placeholders
    )


# Export the model under an absolute directory. export_saved_model returns the
# export path as bytes, so decode it before printing.
target_pb_path = os.path.abspath("pb_model_path")
export_path = est.export_saved_model(target_pb_path, _serving_input_fn)
export_path = export_path.decode("utf-8")
print(f"The export saved model path is {export_path}.")