Sparsity Process

This section describes the layers supported by manual sparsity, the API call sequence, and an example.

AMCT supports retraining-based filter-level sparsity. For a sparsity example, see the filter-level sparsity sample in Additional Samples. The supported layers and their restrictions are listed in Table 1.

Table 1 Layers that support filter-level sparsity and their restrictions

Supported Layer Type: MatMul
Restrictions:
  • transpose_a = False, transpose_b = True/False, adjoint_a = False, adjoint_b = False
  • Weight data type: float32 or float64

Supported Layer Type: Conv2D
Restrictions:
  • Weight data type: float32 or float64
  • The weights are of type const and do not have dynamic inputs (such as placeholders).
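
The restrictions in Table 1 can be read as a simple pre-check on each layer. The sketch below is illustrative only: `supports_filter_sparsity` and the dictionary layout it takes are hypothetical and are not part of AMCT, which performs its own checks. Missing MatMul attributes are assumed to default to False.

```python
# Hypothetical pre-check mirroring Table 1; this is not an AMCT API.
SUPPORTED_WEIGHT_DTYPES = {"float32", "float64"}


def supports_filter_sparsity(layer):
    """Return True if a layer description meets the Table 1 restrictions.

    `layer` is an illustrative dict with the keys 'type', 'weight_dtype',
    'weight_is_const' (Conv2D) and 'attrs' (MatMul).
    """
    if layer.get("weight_dtype") not in SUPPORTED_WEIGHT_DTYPES:
        return False
    if layer["type"] == "Conv2D":
        # Weights must be constants, not fed through placeholders.
        return bool(layer.get("weight_is_const"))
    if layer["type"] == "MatMul":
        attrs = layer.get("attrs", {})
        # transpose_b may be True or False; the other three must be False.
        return not (attrs.get("transpose_a", False)
                    or attrs.get("adjoint_a", False)
                    or attrs.get("adjoint_b", False))
    return False


conv = {"type": "Conv2D", "weight_dtype": "float32", "weight_is_const": True}
matmul = {"type": "MatMul", "weight_dtype": "float16",
          "attrs": {"transpose_b": True}}
print(supports_filter_sparsity(conv))    # True
print(supports_filter_sparsity(matmul))  # False: float16 weights
```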

API Call Sequence

Figure 1 shows the API call sequence for filter-level sparsity.

Figure 1 API call sequence for filter-level sparsity

The operations in blue are implemented by the user, while those in gray are implemented by calling AMCT APIs. Specifically, import the AMCT package into the source TensorFlow network inference code and call the APIs at the appropriate points to apply sparsity.

The main steps are as follows:
  1. Construct a training graph and call the create_prune_retrain_model API to modify the graph according to the sparsity configuration file. The API inserts the mask operator (sparsity operator) into the graph.
  2. Train the model and save a checkpoint.
  3. Construct an inference graph and call the create_prune_retrain_model API to modify the graph according to the configuration file. The API inserts the mask operator into the graph and generates a sparsity record file.
  4. Create a session to restore the training parameters and freeze the inference graph into a PB model.
  5. Call the save_prune_retrain_model API to export the .pb inference model with filters pruned, based on the record file and the frozen model file. The mask operator is removed during this process.
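
The order of the five steps can be sketched as a single driver function. This is a minimal outline under stated assumptions, not the AMCT implementation: the two AMCT calls and their keyword arguments are taken from the examples in this section, while `run_filter_sparsity_flow`, the `user` object with its four methods, and `CallRecorder` are hypothetical names introduced only to show the sequence. The recorder lets the sketch run without TensorFlow or AMCT installed.

```python
def run_filter_sparsity_flow(amct, user, outputs, cfg, record_file,
                             masked_pb_path, save_path):
    """Run the five main steps in order.

    `amct` must expose the two AMCT calls named in the steps above.
    `user` is a hypothetical object bundling the user-implemented steps.
    """
    # Step 1: build the training graph and insert the mask operator.
    train_graph = user.build_train_graph()
    amct.create_prune_retrain_model(graph=train_graph, outputs=outputs,
                                    record_file=record_file,
                                    config_defination=cfg)
    # Step 2: train and save a checkpoint.
    ckpt = user.train_and_checkpoint(train_graph)
    # Step 3: build the inference graph and insert the mask operator;
    # this call also produces the sparsity record file.
    test_graph = user.build_test_graph()
    amct.create_prune_retrain_model(graph=test_graph, outputs=outputs,
                                    record_file=record_file,
                                    config_defination=cfg)
    # Step 4: restore the checkpoint and freeze the graph into a PB model.
    user.freeze_to_pb(test_graph, ckpt, masked_pb_path)
    # Step 5: export the pruned model; the mask operator is removed.
    amct.save_prune_retrain_model(pb_model=masked_pb_path, outputs=outputs,
                                  record_file=record_file,
                                  save_path=save_path)


class CallRecorder:
    """Stand-in that logs method names so the call order can be inspected."""

    def __init__(self, log, prefix):
        self._log = log
        self._prefix = prefix

    def __getattr__(self, name):
        def call(*args, **kwargs):
            self._log.append(self._prefix + "." + name)
            return name
        return call


calls = []
run_filter_sparsity_flow(CallRecorder(calls, "amct"),
                         CallRecorder(calls, "user"),
                         outputs=["output"], cfg="./retrain.cfg",
                         record_file="./tmp/record.txt",
                         masked_pb_path="./masked.pb",
                         save_path="./result/user_model")
print("\n".join(calls))
```

In a real script, the `amct` argument would be the imported `amct_tensorflow` module and the `user` methods would be replaced by your own graph-building, training, and freezing code.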

Examples

  • The following steps show how to get started. Update the sample code based on your situation.
  • Adjust the arguments passed to AMCT API calls as required. Sparsity retraining relies on the user's training result, so ensure that a TensorFlow training script that yields satisfactory training accuracy is available.
  1. Import the AMCT package and set the log level.
    import amct_tensorflow as amct
    amct.set_logging_level(print_level='info', save_level='info')
    
  2. (Optional) Build a graph, read the trained parameters, and run inference on the graph in the TensorFlow environment to validate the inference script and environment setup. (Update the sample code based on your situation.)

    This step is recommended because it confirms that the source model runs inference correctly with acceptable accuracy. You can use a subset of the test dataset to speed up the check.

    user_test_evaluate_model(evaluate_model, test_data)
    
  3. Build a training graph. (Update the sample code based on your situation.)
    train_graph = user_load_train_graph()
    
  4. Call AMCT to perform training with the sparsity operator.
    1. Insert the sparsity operator into the training graph.

      Construct a training graph (that is, set is_training to True for the BN layers), and then call the create_prune_retrain_model API (corresponding to 1 in Figure 1) to modify the graph according to the sparsity configuration file (corresponding to 2 in Figure 1). The create_prune_retrain_model API inserts the mask operator into the graph to achieve fake sparsity during inference. For details about how to construct a sparsity configuration file, see Simplified QAT Configuration File.

      record_file = './tmp/record.txt'
      simple_cfg = './retrain.cfg'
      amct.create_prune_retrain_model(graph=train_graph,
                                      outputs=user_model_outputs,
                                      record_file=record_file,
                                      config_defination=simple_cfg)
      

      A tf.Graph does not support direct channel deletion. If the specifications of an operator in the network are modified, TensorFlow verifies the graph structure, so channels cannot be deleted in place. Therefore, the mask operator is inserted into the graph to mask the effect of the sparsified channels on network training.
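
The difference between masking a filter during training and physically removing it at export time can be illustrated on plain Python lists. This is a minimal sketch, not how AMCT works internally: ranking filters by L1 norm is an assumption chosen for illustration (the actual criterion is determined by the sparsity configuration file), and `filter_mask`, `apply_mask`, and `remove_masked` are hypothetical helpers.

```python
def filter_mask(filters, prune_ratio):
    """Return a 0/1 mask per filter, zeroing the smallest-L1-norm filters.

    L1-norm ranking is an assumption made for this illustration.
    """
    norms = [sum(abs(w) for w in f) for f in filters]
    n_prune = int(len(filters) * prune_ratio)
    # Indices of the n_prune filters with the smallest norms.
    order = sorted(range(len(filters)), key=lambda i: norms[i])
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(filters))]


def apply_mask(filters, mask):
    """Training-time view: shapes are unchanged, masked filters become zero."""
    return [[w * m for w in f] for f, m in zip(filters, mask)]


def remove_masked(filters, mask):
    """Export-time view: masked filters are physically deleted."""
    return [f for f, m in zip(filters, mask) if m]


# Four filters with two weights each (illustrative values).
filters = [[0.9, -0.8], [0.01, 0.02], [0.5, 0.4], [-0.03, 0.0]]
mask = filter_mask(filters, prune_ratio=0.5)
print(mask)                               # [1, 0, 1, 0]
print(len(apply_mask(filters, mask)))     # 4: the graph shape is preserved
print(len(remove_masked(filters, mask)))  # 2: only kept filters remain
```

The masked form keeps every tensor shape intact, which is why it can live inside a graph that TensorFlow has already verified, while the removed form corresponds to the pruned model exported at the end of the flow.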

    2. Implement gradient descent optimization on the modified graph and train the graph on the training dataset. (Update the sample code based on your situation.)
      1. Implement gradient descent optimization on the modified graph.
        Use an adaptive learning rate optimizer (for example, RMSPropOptimizer) to implement gradient descent optimization. Perform this step after inserting the sparsity operator in 1.
        optimizer = tf.compat.v1.train.RMSPropOptimizer(
            ARGS.learning_rate, momentum=ARGS.momentum)
        train_op = optimizer.minimize(loss)
        
      2. Train the model.
        Create a session to train the model, save the trained parameters as a checkpoint file, and obtain the result of the sparsity operator (corresponding to 3 and 4 in Figure 1).
        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            # Minimal example: a real script would also run train_op over the training dataset.
            sess.run(outputs)
            # Save the trained parameters as a checkpoint file.
            saver_save.save(sess, retrain_ckpt, global_step=0)
        
  5. Build an inference graph. (Update the sample code based on your situation.)
    test_graph = user_load_test_graph()
    
  6. Call AMCT to implement filter-level sparsity.
    1. Insert the sparsity operator into the inference graph.
      Call the create_prune_retrain_model API to modify the inference graph (with is_training set to False for the BN layers) according to the configuration file (corresponding to 5 in Figure 1) and generate a sparsity record file (corresponding to 6 in Figure 1).
      record_file = './tmp/record.txt'
      simple_cfg = './retrain.cfg'
      amct.create_prune_retrain_model(graph=test_graph,
                                      outputs=user_model_outputs,
                                      record_file=record_file,
                                      config_defination=simple_cfg)
      
    2. Restore the checkpoint weights trained in 4 and save them as a PB model.
      Create a session, restore the training parameters, and freeze the inference graph into a PB model (corresponding to 7 and 8 in Figure 1).
      variables_to_restore = tf.compat.v1.global_variables()
      saver_restore = tf.compat.v1.train.Saver(variables_to_restore)
      with tf.compat.v1.Session() as sess:
          sess.run(tf.compat.v1.global_variables_initializer())
          # Restore training parameters.
          saver_restore.restore(sess, retrain_ckpt)
          # Save the model as a PB model. output.name[:-2] strips the ':0' suffix
          # from each output tensor name to get the node name.
          constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
              sess, test_graph.as_graph_def(), [output.name[:-2] for output in outputs])
          with tf.io.gfile.GFile(masked_pb_path, 'wb') as f:
              f.write(constant_graph.SerializeToString())
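
The expression `output.name[:-2]` in the code above turns a tensor name such as `logits:0` into the node name that `convert_variables_to_constants` expects. The slice assumes a single-digit output index. A slightly more defensive variant, shown here on plain strings so that it runs without TensorFlow, splits on the last colon instead; `node_name` is a hypothetical helper, not part of TensorFlow or AMCT.

```python
def node_name(tensor_name):
    """Strip the ':<index>' output suffix from a TensorFlow tensor name."""
    name, sep, index = tensor_name.rpartition(":")
    # Only strip when the suffix really is a numeric output index.
    if sep and index.isdigit():
        return name
    return tensor_name


print(node_name("logits/BiasAdd:0"))  # logits/BiasAdd
print(node_name("split:10"))          # split
print(node_name("logits/BiasAdd"))    # logits/BiasAdd
```

With such a helper, the list passed to the freeze call could be written as `[node_name(output.name) for output in outputs]`.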
      
    3. Save the model, delete the sparsity operator, and implement filter-level sparsity.
      Export the PB inference model with channels pruned, based on the record file and the frozen model (corresponding to 6 and 9 in Figure 1).
      pruned_model_path = './result/user_model'
      amct.save_prune_retrain_model(pb_model=masked_pb_path,
                                    outputs=user_model_outputs,
                                    record_file=record_file,
                                    save_path=pruned_model_path)
      
  7. (Optional) Run inference on the pruned model user_model_pruned.pb in the TensorFlow environment based on the test dataset to test the accuracy. (Update the sample code based on your situation.)
    Check the accuracy loss of the pruned model by comparing its accuracy with that of the source model (see 2).
    pruned_model = './result/user_model_pruned.pb'
    user_do_inference(pruned_model, test_data)
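
The comparison with the source model can be reduced to a small check on the two accuracy figures. This is a minimal sketch: `accuracy_drop`, `within_budget`, and the `max_drop` threshold are hypothetical, and the numbers are illustrative rather than measured results.

```python
def accuracy_drop(source_acc, pruned_acc):
    """Accuracy lost by pruning, as a fraction (positive means worse)."""
    return source_acc - pruned_acc


def within_budget(source_acc, pruned_acc, max_drop=0.01):
    """Return True if the pruned model loses no more than `max_drop`.

    `max_drop` is a hypothetical budget; choose one that suits your task.
    """
    return accuracy_drop(source_acc, pruned_acc) <= max_drop


# Illustrative numbers only, not measured results.
print(within_budget(source_acc=0.76, pruned_acc=0.755))  # True
print(within_budget(source_acc=0.76, pruned_acc=0.70))   # False
```

If the drop exceeds the budget, possible next steps are to retrain for longer or to lower the sparsity ratio in the configuration file.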