完成训练并查看结果

关键步骤操作参考如下。

  1. 查看训练结果:
    • 如果要将稀疏表数据导出npy格式,可以调用export接口。
    • 如果要导出pb模型文件,可以调用estimator的export_saved_model接口,示例如下:
      import os
      import tensorflow as tf
      if tf.__version__.startswith("1"):
          from npu_bridge.npu_init import NPURunConfig, NPUEstimator
      else:
          from npu_device.compat.v1.npu_init import NPURunConfig, NPUEstimator
      # 可参见Estimator迁移章节的“运行配置”和“创建Estimator对象”
      run_config = NPURunConfig(...)
      est = NPUEstimator(...)
      # 通常在调用完train或train_and_evaluate之后调用export_saved_model接口
      def _serving_input_fn():
          # 根据具体业务模型进行调整,下面以little demo estimator模型的输入为例
          inputs = {
              "user_ids": tf.compat.v1.placeholder(shape=(None, 32), dtype=tf.int64, name="user_ids"),
              "item_ids": tf.compat.v1.placeholder(shape=(None, 8), dtype=tf.int64, name="item_ids"),
              "label_0": tf.compat.v1.placeholder(shape=(None,), dtype=tf.float32, name="label_0"),
              "label_1": tf.compat.v1.placeholder(shape=(None,), dtype=tf.float32, name="label_1"),
          }
          return tf.estimator.export.ServingInputReceiver(features=inputs, receiver_tensors=inputs)
      target_pb_path = os.path.abspath("pb_model_path")
      # 调用estimator的export_saved_model接口进行pb保存
      export_path = est.export_saved_model(target_pb_path, _serving_input_fn).decode("utf-8")
      print(f"The export saved model path is {export_path}.")
  2. 调用terminate_config_initializer接口关闭数据流释放资源。