Porting with sess.run

About sess.run

As a low-level TensorFlow API, sess.run is more flexible than Estimator. On the flip side, using it for model implementation can be more complex.

This API has been deprecated in TensorFlow 2.6. To use it in TensorFlow 2.6, call it using the compat.v1 module as follows:

tf.compat.v1.Session.run

Develop your training script with the sess.run API as follows:

  1. Preprocess data.
  2. Construct a model, compute the loss, and update the gradient.
  3. Create a session and initialize resources.
  4. Start training.

The following sections guide you through migrating a training script developed with sess.run so that, after porting, it can run on the Ascend AI Processor.

Header File Inclusion

To import NPU-related libraries, add the following import statements to the related Python files:

import npu_device
from npu_device.compat.v1.npu_init import *
npu_device.compat.enable_v1()

Data Preprocessing

The code snippet is ready to use in normal cases. Manual tweaking is required only in the following scenario:

If the original network script relies on dataset.batch(batch_size) returning a dynamic shape, the shape of the last step on the network may differ from that of the previous steps, because the number of remaining samples in the data flow may be less than the batch size. In this scenario, dynamic-shape compilation is triggered. To improve network compilation performance, you are advised to set drop_remainder to True, which discards the last partial batch and ensures that the shape of every step on the network is the same.
dataset = dataset.batch(batch_size, drop_remainder=True)
Note that during inference, if the number of samples in the last iteration is less than batch_size, you need to pad the inference data with blank data up to batch_size. Otherwise, an assertion in your script that the number of validation results equals the number of validation samples may fail:
assert num_written_lines == num_actual_predict_examples
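A minimal sketch of this padding step in plain Python (the helper name `pad_batch` and the blank-sample value are illustrative, not part of the NPU API): pad the last partial batch to batch_size, run inference, then discard the predictions for the padding samples.

```python
def pad_batch(samples, batch_size, blank_sample):
    """Pad a partial batch with blank samples up to batch_size.

    Returns the padded batch and the number of real samples,
    so the extra predictions can be discarded after inference.
    """
    num_real = len(samples)
    padded = samples + [blank_sample] * (batch_size - num_real)
    return padded, num_real

# Example: 3 real samples, batch size 8.
batch, num_real = pad_batch([[1.0], [2.0], [3.0]], batch_size=8, blank_sample=[0.0])
predictions = [x[0] * 2 for x in batch]   # stand-in for sess.run on the padded batch
predictions = predictions[:num_real]      # keep results for real samples only
```

Trimming with `num_real` afterwards keeps the count of written results equal to the count of actual samples, so the assertion above still holds.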

Model Construction, Loss Computation, and Gradient Update

The code snippet is ready to use in normal cases. Manual tweaking is required only in the following scenarios:

  • If tf.device is used in the original network, delete the related code.
  • Replace gelu in the original network with the corresponding CANN API.
    Original TensorFlow code:
    def gelu(x):
      cdf = 0.5 * (1.0 + tf.tanh(
          (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
      return x * cdf
    layers = gelu(x)
    

    Code after porting:

    layers = npu_unary_ops.gelu(x)
    
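The tanh-based approximation above can be sanity-checked in plain NumPy before porting (a standalone sketch, independent of the NPU API; `gelu_ref` is an illustrative name):

```python
import numpy as np

def gelu_ref(x):
    """Tanh-based GELU approximation, matching the TensorFlow snippet above."""
    cdf = 0.5 * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
    return x * cdf

print(gelu_ref(0.0))   # → 0.0
print(gelu_ref(1.0))   # ≈ 0.8412
```

Comparing these reference values against the output of npu_unary_ops.gelu is a quick way to confirm the ported activation behaves as expected.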

Session Creation and Resource Initialization

When running your training script on the Ascend AI Processor by using sess.run, note the following configurations:

  • The following configuration option is deactivated by default and should remain deactivated:

    rewrite_options.disable_model_pruning

  • The following configuration options are activated by default and should remain activated:
    • rewrite_options.function_optimization
    • rewrite_options.constant_folding
    • rewrite_options.shape_optimization
    • rewrite_options.arithmetic_optimization
    • rewrite_options.loop_optimization
    • rewrite_options.dependency_optimization
    • rewrite_options.layout_optimizer
  • The following configuration options are enabled by default and must be disabled explicitly:
    • rewrite_options.remapping
    • rewrite_options.memory_optimization

Original TensorFlow code:

# Construct an iterator.
iterator = Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)

# Obtain the batch data.
next_batch = iterator.get_next()

# Initialize the iterator.
training_init_op = iterator.make_initializer(train_dataset)

# Initialize the variables.
init = tf.compat.v1.global_variables_initializer()
sess = tf.compat.v1.Session()
sess.run(init)

# Get the number of training/validation steps per epoch.
train_batches_per_epoch = int(np.floor(train_size / batch_size))

Code after porting:

# Construct an iterator.
iterator = Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)

# Obtain the batch data.
next_batch = iterator.get_next()

# Initialize the iterator.
training_init_op = iterator.make_initializer(train_dataset)

# Initialize the variables.
init = tf.compat.v1.global_variables_initializer()

# Create a session.
config = tf.compat.v1.ConfigProto()
custom_op = config.graph_options.rewrite_options.custom_optimizers.add()
custom_op.name = "NpuOptimizer"
config.graph_options.rewrite_options.remapping = RewriterConfig.OFF  # Must be disabled explicitly.
config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF  # Must be disabled explicitly.
sess = tf.compat.v1.Session(config=config)
sess.run(init)

# Get the number of training/validation steps per epoch.
train_batches_per_epoch = int(np.floor(train_size / batch_size))

The Ascend platform supports all native functions of tf.compat.v1.Session.

It also allows you to enable functions such as automatic mixed precision. For details, see the corresponding API description.
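As an example of such a feature, automatic mixed precision can be requested through the NpuOptimizer custom optimizer (a sketch based on the CANN documentation; the parameter_map key and value are assumptions that may differ between CANN versions, so check the corresponding API description):

```python
import tensorflow as tf
from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig

config = tf.compat.v1.ConfigProto()
custom_op = config.graph_options.rewrite_options.custom_optimizers.add()
custom_op.name = "NpuOptimizer"
# Request automatic mixed precision; the parameter name below follows the
# CANN documentation and may vary between versions.
custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF
```

The resulting config object is then passed to tf.compat.v1.Session as in the ported snippet above.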

Training

The code snippet is ready to use. See the following example.

# Start epochs.
for epoch in range(num_epochs):
  # Initialize the iterator with the training dataset.
  sess.run(training_init_op)
  for step in range(train_batches_per_epoch):
    # Get the next batch of data.
    img_batch, label_batch = sess.run(next_batch)
    # Run the training op.
    _, train_loss = sess.run([train_op, loss], feed_dict={x: img_batch, y_: label_batch, is_training: True})

However, if you create a session without a with block, for example, when the session object is a class member, your ported script must call sess.close() explicitly:

sess = tf.compat.v1.Session(config=config)
sess.run(...)
sess.close()

This is because the destructor of GeOp is called in the close method of tf.compat.v1.Session. If you use a with block, __exit__ closes the session automatically, so there is no need to call sess.close():

with tf.compat.v1.Session(config=config) as sess:
    sess.run(...)
