Workflow

This section describes the layers that support 2:4 structured sparsity, the API call sequence, and an example.

Due to hardware restrictions, the Atlas inference series products and Atlas training products do not support the 2:4 structured sparsity feature, so enabling the feature on them brings little performance benefit.

AMCT supports retraining-based 2:4 structured sparsity. For a sparsity example, see Sample List. Table 1 lists the layers that support this feature and their restrictions.

Table 1 Layers that support 2:4 structured sparsity and their restrictions

| Supported Layer Type | Restriction | Remarks |
|---|---|---|
| MatMul | transpose_a=False; weight data type: float32 or float64 | - |
| Conv2D | Weight data type: float32 or float64 | The weights are of type const and do not have dynamic inputs (such as placeholders). |
| Conv2DBackpropInput | Weight data type: float32 or float64 | - |

API Call Sequence

Figure 1 shows the API call sequence of 2:4 structured sparsity. Training runs on CPU or GPU with the TensorFlow framework. AMCT APIs are called from the inference script of the open-source framework to compress the model. The compressed model must then be converted by the ATC tool into an offline model adapted to the Ascend AI Processor before it can be used for inference on the Ascend AI Processor.

Figure 1 API call sequence for 2:4 structured sparsity

In the figure, the operations in blue are implemented by the user, and those in gray are implemented by AMCT APIs. Specifically, import the AMCT package into the original TensorFlow inference code and call the AMCT APIs at the appropriate points to perform sparsity.

The workflow goes through the following steps:
  1. Construct a training graph and call the create_prune_retrain_model API to modify it based on the sparsity configuration file, that is, to insert 2:4 structured sparsity operators into the graph.
  2. Train the model and save the trained parameters as a checkpoint file.
  3. Construct an inference graph and call the create_prune_retrain_model API to modify it based on the same configuration file, that is, to insert 2:4 structured sparsity operators into the graph and generate a sparsity record file.
  4. Create a session to restore the trained parameters, and freeze the inference graph into a .pb model.
  5. Call the save_prune_retrain_model API to export a .pb inference model whose weights are sparsified according to the record file and the frozen model. During this process, the 2:4 structured sparsity operators are removed.
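Conceptually, 2:4 structured sparsity keeps at most two nonzero values in every group of four consecutive weights. The plain-Python sketch below illustrates the idea under stated assumptions: groups are formed along each row of a 2-D weight matrix, and the two largest-magnitude values in each group are kept. That is a common selection rule, not necessarily how AMCT's sparsity operator groups or selects weights. `mask_2_4`, `sparsify_2_4`, and `fake_sparse_matvec` are hypothetical helpers, not AMCT APIs. `fake_sparse_matvec` mimics the retraining phase (the mask is applied in the forward pass while the dense weights stay intact), and `sparsify_2_4` mimics the export phase (the zeros are baked into the weights).

```python
from typing import List


def mask_2_4(row: List[float]) -> List[int]:
    """Return a 0/1 mask keeping the two largest-magnitude values in each group of four (assumed rule)."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    mask = [0] * len(row)
    for start in range(0, len(row), 4):
        group = range(start, start + 4)
        # Sort the group's indices by absolute value, largest first, and keep the top two.
        keep = sorted(group, key=lambda i: abs(row[i]), reverse=True)[:2]
        for i in keep:
            mask[i] = 1
    return mask


def sparsify_2_4(weights: List[List[float]]) -> List[List[float]]:
    """Return new rows with the pruned positions set to 0.0; the input is not modified."""
    sparse = []
    for row in weights:
        mask = mask_2_4(row)
        sparse.append([w if keep else 0.0 for w, keep in zip(row, mask)])
    return sparse


def fake_sparse_matvec(weights: List[List[float]], x: List[float]) -> List[float]:
    """Forward pass with the 2:4 mask applied on the fly; the dense weights stay unchanged."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in sparsify_2_4(weights)]


if __name__ == "__main__":
    dense = [[0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.1]]
    print(sparsify_2_4(dense))  # [[0.9, 0.0, 0.0, -0.7, 0.0, 0.3, -0.4, 0.0]]
    print(fake_sparse_matvec(dense, [1.0] * 8))
```

Keeping the dense weights during retraining lets a pruned position become active again if its magnitude grows, which is one reason retraining-based schemes often recover accuracy better than one-shot pruning.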

Example

Note:
  1. The following steps are for reference only. Update the sample code based on your situation.
  2. Adjust the arguments passed to AMCT API calls as required. Retraining-based sparsity depends on your training results, so make sure that a TensorFlow training script that achieves satisfactory training accuracy is available.

Perform the following steps:
  1. Import TensorFlow and the AMCT package, and set the log level.
    import tensorflow as tf
    import amct_tensorflow as amct
    amct.set_logging_level(print_level='info', save_level='info')
    
  2. (Optional) Build a graph, read the trained parameters, and run inference on the graph in the TensorFlow environment to validate the inference script and environment setup. (Update the sample code based on your situation.)

    This step is recommended because it confirms that the original model runs inference properly with acceptable accuracy. To save time, you can use a subset of the test dataset.

    user_test_evaluate_model(evaluate_model, test_data)
    
  3. Build a training graph. (Update the sample code based on your situation.)
    train_graph = user_load_train_graph()
    
  4. Use AMCT to retrain the model with the sparsity operators inserted.
    1. Insert the sparsity operators into the training graph.

      Construct a training graph (that is, set is_training to True for the batch normalization (BN) layers), and then call the create_prune_retrain_model API (corresponding to 1 in Figure 1) to modify the graph based on the sparsity configuration file (corresponding to 2 in Figure 1). The API inserts 2:4 structured sparsity operators into the graph to perform fake sparsity: sparsity is simulated in the graph, and the weights are actually sparsified only when the model is exported. For details about how to construct a sparsity configuration file, see Simplified QAT Configuration File.

      record_file = './tmp/record.txt'
      simple_cfg = './retrain.cfg'
      amct.create_prune_retrain_model(graph=train_graph,
                                      outputs=user_model_outputs,
                                      record_file=record_file,
                                      config_defination=simple_cfg)
      
    2. Add gradient descent optimization to the modified graph and train it on the training dataset. (Update the sample code based on your situation.)
      1. Add gradient descent optimization to the modified graph.
        Use an adaptive learning rate optimizer (for example, RMSPropOptimizer) for gradient descent optimization. Perform this step after the sparsity operators are inserted (sub-step 1), so that the optimizer is built on the graph that already contains them.
        # loss and ARGS come from your own training script.
        with train_graph.as_default():
            optimizer = tf.compat.v1.train.RMSPropOptimizer(
                ARGS.learning_rate, momentum=ARGS.momentum)
            train_op = optimizer.minimize(loss)
        
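      2. Train the model.
        Create a session to train the model with the sparsity operators, and save the trained parameters as a checkpoint file (corresponding to 3 and 4 in Figure 1).
        # user_train_steps and user_train_feed() are hypothetical placeholders for your own training loop.
        with train_graph.as_default():
            saver_save = tf.compat.v1.train.Saver()
        with tf.compat.v1.Session(graph=train_graph) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            for _ in range(user_train_steps):
                sess.run(train_op, feed_dict=user_train_feed())
            # Save the trained parameters as a checkpoint file (no global_step suffix, so restore(retrain_ckpt) finds it).
            saver_save.save(sess, retrain_ckpt)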
        
  5. Build an inference graph. (Update the sample code based on your situation.)
    test_graph = user_load_test_graph()
    
  6. Use AMCT to apply 2:4 structured sparsity.
    1. Insert the sparsity operators into the inference graph.
      Call the create_prune_retrain_model API to modify the inference graph (with is_training set to False for the BN layers) based on the configuration file (corresponding to 5 in Figure 1). The API inserts 2:4 structured sparsity operators into the graph and generates a sparsity record file (corresponding to 6 in Figure 1).
      record_file = './tmp/record.txt'
      simple_cfg = './retrain.cfg'
      amct.create_prune_retrain_model(graph=test_graph,
                                      outputs=user_model_outputs,
                                      record_file=record_file,
                                      config_defination=simple_cfg)
      
    2. Restore the checkpoint weights trained in step 4 and save them as a .pb model.
      Create a session to restore the trained parameters and freeze the inference graph into a .pb model (corresponding to 7 and 8 in Figure 1).
      # outputs: the output tensors of test_graph; masked_pb_path: where the frozen model is written.
      with test_graph.as_default():
          variables_to_restore = tf.compat.v1.global_variables()
          saver_restore = tf.compat.v1.train.Saver(variables_to_restore)
      with tf.compat.v1.Session(graph=test_graph) as sess:
          sess.run(tf.compat.v1.global_variables_initializer())
          # Restore training parameters.
          saver_restore.restore(sess, retrain_ckpt)
          # Freeze variables into constants; op.name is the node name without the ":N" suffix.
          constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
              sess, test_graph.as_graph_def(), [output.op.name for output in outputs])
          # Save the model as a .pb model.
          with tf.io.gfile.GFile(masked_pb_path, 'wb') as f:
              f.write(constant_graph.SerializeToString())
      
    3. Save the model: remove the inserted 2:4 structured sparsity operators and apply the sparsity to the weights.
      Export the .pb inference model sparsified according to the record file and the frozen model (corresponding to 6 and 9 in Figure 1).
      pruned_model_path = './result/user_model'
      amct.save_prune_retrain_model(pb_model=masked_pb_path,
                                    outputs=user_model_outputs,
                                    record_file=record_file,
                                    save_path=pruned_model_path)
      
  7. (Optional) Run inference on the sparsified model user_model_pruned.pb in the TensorFlow environment with the test dataset to check its accuracy. (Update the sample code based on your situation.)
    Compare the accuracy of the sparsified model with that of the original model (see step 2).
    pruned_model = './result/user_model_pruned.pb'
    user_do_inference(pruned_model, test_data)
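Steps 2 and 7 compare the accuracy of the original and sparsified models, optionally on a subset of the test data. The evaluation functions (user_test_evaluate_model, user_do_inference) are left to you. As a hedged sketch, the helpers below assume a classification model whose `predict` callable returns one score per class; they compute top-1 accuracy on the first `limit` samples and report the accuracy drop. `top1_accuracy`, `accuracy_drop`, and the toy predictors are hypothetical stand-ins for running the TensorFlow models, not part of AMCT.

```python
from typing import Callable, List, Sequence, Tuple

Sample = Tuple[List[float], int]                       # (input features, true class index)
Predictor = Callable[[List[float]], Sequence[float]]   # returns one score per class


def top1_accuracy(predict: Predictor, samples: Sequence[Sample], limit: int = 0) -> float:
    """Top-1 accuracy; limit > 0 evaluates only the first `limit` samples."""
    subset = samples[:limit] if limit > 0 else samples
    if len(subset) == 0:
        raise ValueError("no samples to evaluate")
    correct = 0
    for features, label in subset:
        scores = predict(features)
        # The predicted class is the index of the highest score.
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        correct += int(predicted == label)
    return correct / len(subset)


def accuracy_drop(original: Predictor, sparsified: Predictor,
                  samples: Sequence[Sample], limit: int = 0) -> float:
    """Accuracy of the original model minus that of the sparsified model."""
    return top1_accuracy(original, samples, limit) - top1_accuracy(sparsified, samples, limit)


def toy_original(x: List[float]) -> List[float]:
    # Hypothetical stand-in: the scores are simply the inputs.
    return x


def toy_sparsified(x: List[float]) -> List[float]:
    # Hypothetical stand-in that makes different decisions.
    return [x[1], x[0]]


if __name__ == "__main__":
    data = [([0.1, 0.9], 1), ([0.8, 0.2], 0), ([0.4, 0.6], 0), ([0.3, 0.7], 1)]
    print(top1_accuracy(toy_original, data))                   # 0.75
    print(top1_accuracy(toy_original, data, limit=2))          # 1.0
    print(accuracy_drop(toy_original, toy_sparsified, data))   # 0.5
```

Evaluating both models on the same subset keeps the comparison fair; a small `limit` is useful for quick checks, but the final accuracy comparison should use the full test dataset.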