Estimator API属于TensorFlow的高阶API,在2018年发布的TensorFlow 1.10版本中引入,它可极大简化机器学习的编程过程。Estimator有很多优势,例如:对分布式的良好支持、简化了模型的创建工作、有利于模型开发者之间的代码分享等。
TensorFlow 2.6版本中继续支持该高阶API,如果需要沿用在TensorFlow1.x版本中的用法,则可以通过compat.v1模块调用,调用方式如下:
tf.compat.v1.estimator.Estimator
使用compat.v1的Estimator进行训练脚本开发的流程为:
下面介绍如何迁移此类Estimator训练脚本,以便在昇腾AI处理器上进行训练。
对于以下步骤中涉及修改的python文件,新增以下头文件引用,用于导入NPU相关库。
import npu_device from npu_device.compat.v1.npu_init import * npu_device.compat.enable_v1()
一般情况下,此部分代码无需改造。如下情况需要进行适配修改:
dataset = dataset.batch(batch_size, drop_remainder=True)
这可能会丢弃文件中的最后几个样本,以确保每个批量都具有静态形状 (batch_size)。但需要注意的是:推理时,当最后一次迭代的推理数据量小于batch size时,需要补齐空白数据到batch size,因为有些脚本最后会加个断言,验证结果的数量要和验证数据的数量一致,此种情况会导致训练失败。
assert num_written_lines == num_actual_predict_examples
一般情况下,此部分代码无需改造。如下情况需要进行适配修改:
def gelu(x): cdf = 0.5 * (1.0 + tf.tanh( (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) return x*cdf layers = gelu()
迁移后的代码:
layers = npu_unary_ops.gelu(x)
TensorFlow通过Runconfig配置运行参数,用户需要按照如下示例,更改config相关配置。
session_config=tf.compat.v1.ConfigProto(allow_soft_placement=True,log_device_placement=False)) config=tf.estimator.RunConfig( session_config=session_config, model_dir=FLAGS.model_dir, save_checkpoints_steps=FLAGS.save_checkpoints_steps,
迁移后的代码:
session_config=tf.compat.v1.ConfigProto(allow_soft_placement=True,log_device_placement=False)) custom_op = sess_config.graph_options.rewrite_options.custom_optimizers.add() custom_op.name = "NpuOptimizer" sess_config.graph_options.rewrite_options.remapping = rewriter_config_pb2.RewriterConfig.OFF npu_config=NPURunConfig( session_config=sess_config, model_dir=FLAGS.model_dir, save_checkpoints_steps=FLAGS.save_checkpoints_steps,
Estimator的迁移只需要更改其config参数为上述npu_config,并将TensorFlow的Estimator迁移为NPUEstimator。
TensorFlow原始代码:
mnist_classifier=tf.compat.v1.estimator.Estimator( model_fn=cnn_model_fn, config=config, model_dir="/tmp/mnist_convnet_model")
迁移后的代码:
mnist_classifier=NPUEstimator( model_fn=cnn_model_fn, config=npu_config, model_dir="/tmp/mnist_convnet_model" )
mnist_classifier.train( input_fn=train_input_fn, steps=20000, hooks=[logging_hook])