Porting with Keras
About Keras
Like Estimator, Keras is a high-level TensorFlow API. It builds graphs efficiently and provides APIs for training, evaluation, validation, and export.
TensorFlow 2.6 continues to support the Keras API. To use it the same way as in TF1, access it through the compat.v1 module, for example:
```python
tf.compat.v1.Session
```
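As a minimal sketch of TF1-style usage through tf.compat.v1 (assuming stock TensorFlow 2.x without the NPU plugin), the snippet below disables eager execution, builds a small graph, and evaluates it in a tf.compat.v1.Session. The tensor names and values are illustrative only.

```python
import tensorflow as tf

# TF1-style code runs in graph mode; TF 2.x executes eagerly by default.
tf.compat.v1.disable_eager_execution()

# Build a tiny graph: a placeholder fed at run time, scaled by a constant.
x = tf.compat.v1.placeholder(tf.float32, shape=[None], name="x")
y = x * 2.0

# Evaluate the graph inside a TF1-style session; the session closes on exit.
with tf.compat.v1.Session() as sess:
    result = sess.run(y, feed_dict={x: [1.0, 2.0, 3.0]})

print(result)  # [2. 4. 6.]
```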
Develop your training script with the Keras API as follows:
- Preprocess data.
- Construct your model.
- Build your model.
- Train your model.
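The four steps above map onto tf.keras calls roughly as follows. This is a minimal sketch on synthetic data that runs with stock TensorFlow 2.x on CPU; reading "build" as the model.compile() call is an assumption, and the layer sizes, optimizer, and data are illustrative only.

```python
import numpy as np
from tensorflow import keras

# 1. Preprocess data: synthetic features, normalized to zero mean and unit variance.
rng = np.random.default_rng(0)
features = rng.normal(size=(256, 4)).astype("float32")
labels = (features.sum(axis=1) > 0).astype("float32")
features = (features - features.mean(axis=0)) / features.std(axis=0)

# 2. Construct the model: define its layers.
model = keras.Sequential([
    keras.layers.Dense(16, activation="relu", input_shape=(4,)),
    keras.layers.Dense(1, activation="sigmoid"),
])

# 3. Build the model: attach the optimizer, loss, and metrics.
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# 4. Train the model.
history = model.fit(features, labels, batch_size=32, epochs=3, verbose=0)
print(history.history["loss"])
```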
Currently, only training scripts written with the TensorFlow Keras APIs (tf.keras) are supported. Native Keras APIs are not supported.
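The difference is mainly in where Keras is imported from. A minimal sketch, assuming that "native Keras" refers to the standalone keras package rather than the copy shipped inside TensorFlow:

```python
import tensorflow as tf
from tensorflow import keras  # Supported: Keras as shipped inside TensorFlow (tf.keras).

layer = keras.layers.Dense(8)

# Not supported according to this guide: the standalone ("native") Keras package, e.g.
#   import keras
#   layer = keras.layers.Dense(8)
```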
The following describes how to port Keras training scripts for training on the Ascend AI Processor.
Header File Inclusion
To import the NPU-related libraries, add the following import statements to the relevant Python files:
```python
import npu_device
from npu_device.compat.v1.npu_init import *
npu_device.compat.enable_v1()
```
Porting Configuration
A Keras training script ported to the Ascend platform loses support for certain features, such as a dynamic learning rate, so porting Keras scripts to the Ascend platform is not recommended. If you still need to run a Keras script on the Ascend platform, modify it as follows:
To train your model on the Ascend AI Processor, create a TensorFlow session with the NPU-related configuration and register the session with Keras. Close the session when training ends.
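```python
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
import npu_device
from npu_device.compat.v1.npu_init import *
npu_device.compat.enable_v1()

# Session configuration that enables the NPU graph optimizer
sess_config = tf.compat.v1.ConfigProto()
custom_op = sess_config.graph_options.rewrite_options.custom_optimizers.add()
custom_op.name = "NpuOptimizer"
# Turn off the TensorFlow remapping and memory-optimization passes
sess_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
sess_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF

# Create the session and register it with Keras.
# set_session is only available in the TF1-compatible backend API (tf.compat.v1).
sess = tf.compat.v1.Session(config=sess_config)
tf.compat.v1.keras.backend.set_session(sess)

# Preprocess data.
# Construct your model.
# Build your model.
# Train your model.

# Close the session when training ends.
sess.close()
```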
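The session lifecycle can also be wrapped in a context manager so that the session is closed even if training raises an exception. This is a minimal sketch that uses a plain ConfigProto so it runs without the NPU plugin; on the Ascend platform you would pass the NPU-configured sess_config instead. The helper name keras_session is hypothetical, and the toy model and data are illustrative only.

```python
import contextlib

import numpy as np
import tensorflow as tf

# TF1-style Keras sessions require graph mode.
tf.compat.v1.disable_eager_execution()


@contextlib.contextmanager
def keras_session(config=None):
    """Hypothetical helper: create a session, register it with Keras, always close it."""
    sess = tf.compat.v1.Session(config=config)
    tf.compat.v1.keras.backend.set_session(sess)
    try:
        yield sess
    finally:
        sess.close()


# Plain config for illustration; substitute the NPU-configured sess_config on Ascend.
config = tf.compat.v1.ConfigProto()

with keras_session(config) as sess:
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
    model.compile(optimizer="sgd", loss="mse")
    x = np.ones((8, 2), dtype="float32")
    y = np.zeros((8, 1), dtype="float32")
    model.fit(x, y, epochs=1, verbose=0)
```

Compared with calling sess.close() at the end of the script, the finally clause also releases the session when an error interrupts training.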