Estimator API属于TensorFlow的高阶API,在2018年发布的TensorFlow 1.10版本中引入,它可极大简化机器学习的编程过程。Estimator有很多优势,例如:对分布式的良好支持、简化了模型的创建工作、有利于模型开发者之间的代码分享等。
使用Estimator进行训练脚本开发的流程为:
过程 |
描述 |
---|---|
数据预处理 |
创建输入函数input_fn。 |
模型构建 |
构建模型函数model_fn。 |
运行配置 |
实例化Estimator,并传入Runconfig类对象作为运行参数。 |
执行训练 |
在Estimator上调用训练方法Estimator.train(),利用指定输入对模型进行固定步数的训练。 |