训练流程

Estimator简介

Estimator API属于TensorFlow的高阶API,在2018年发布的TensorFlow 1.10版本中引入,它可极大简化机器学习的编程过程。Estimator有很多优势,例如:对分布式的良好支持、简化了模型的创建工作、有利于模型开发者之间的代码分享等。

使用Estimator进行训练脚本开发的流程为:

表1 训练流程说明

过程

描述

数据预处理

创建输入函数input_fn。

模型构建

构建模型函数model_fn。

运行配置

实例化Estimator,并传入Runconfig类对象作为运行参数。

执行训练

在Estimator上调用训练方法Estimator.train(),利用指定输入对模型进行固定步数的训练。