Sparsity Process
This section describes the layers that support 2:4 structured sparsity, the API call sequence, and an example.
Due to hardware restrictions, the Atlas 200/300/500 Inference Product and the Atlas Training Series Product do not support the 2:4 structured sparsity feature. Enabling it on these products yields little performance benefit.
AMCT supports retraining-based 2:4 structured sparsity. For the sparsity example, see Sample List. The layers that support this feature are listed as follows.
| Supported Layer Type | Restriction | Remarks |
|---|---|---|
| MatMul | transpose_a=False; weight data type: float32 or float64 | - |
| Conv2D | Weight data type: float32 or float64 | The weights are of type const and do not have dynamic inputs (such as placeholders). |
| Conv2DBackpropInput | - | |
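As a rough illustration of what the 2:4 pattern means for the weights of these layers: in every group of four consecutive weights, at most two remain nonzero. The sketch below is pure Python and does not use AMCT. The function name `prune_2_4` and the keep-the-two-largest-magnitudes rule are assumptions made for illustration; this section does not state which selection criterion or grouping axis AMCT uses.

```python
def prune_2_4(row):
    """Zero the two smallest-magnitude values in each group of four.

    Illustrative only: the magnitude-based selection is an assumption,
    not necessarily the criterion AMCT applies.
    """
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    pruned = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest-magnitude entries in this group.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        pruned.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return pruned


weights = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.01]
print(prune_2_4(weights))
# [0.9, 0.0, 0.0, -0.7, 0.0, 0.3, -0.4, 0.0]
```

Half of the weights in each group become zero, which is the regular layout that hardware with 2:4 support can exploit.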
API Call Sequence
Figure 1 (API call sequence) shows the API call sequence for 2:4 structured sparsity.
The user implements the operations in blue, while those in gray are implemented by using AMCT APIs. Specifically, import the AMCT package into the source TensorFlow network inference code and call the APIs at the appropriate points to perform sparsity.
- Construct a training graph and call the create_prune_retrain_model API to modify the not-yet-sparsified graph based on the sparsity configuration file, that is, to insert the 2:4 structured sparsity operator into the graph.
- Train the model and save a checkpoint.
- Construct an inference graph and call the create_prune_retrain_model API to modify the not-yet-sparsified graph based on the configuration file, that is, to insert the 2:4 structured sparsity operator into the graph and generate a sparsity record file.
- Create a session to restore the training parameters and freeze the inference graph into a PB model.
- Call the save_prune_retrain_model API to export the sparsified .pb inference model based on the record file and the frozen model file. During the process, the 2:4 structured sparsity operator is removed.
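The call order above can be sketched as a small driver function. This is a dry-run sketch under stated assumptions, not AMCT's own code: `run_2_4_sparsity_flow`, `train_and_checkpoint`, `freeze_to_pb`, the `RecordingAMCT` stand-in, and the checkpoint and PB paths are all hypothetical names introduced for illustration. Only the two AMCT API names and their keyword arguments are taken from this section. In real use you would pass the imported `amct_tensorflow` module instead of the stand-in.

```python
def run_2_4_sparsity_flow(amct, train_graph, test_graph, outputs, config_file,
                          record_file, save_path, train_and_checkpoint,
                          freeze_to_pb):
    """Run the five steps in order; user callbacks cover the blue boxes."""
    # 1. Insert the sparsity operator into the training graph.
    amct.create_prune_retrain_model(graph=train_graph, outputs=outputs,
                                    record_file=record_file,
                                    config_defination=config_file)
    # 2. User code: train the modified graph and return the checkpoint path.
    checkpoint = train_and_checkpoint(train_graph)
    # 3. Insert the operator into the inference graph; writes the record file.
    amct.create_prune_retrain_model(graph=test_graph, outputs=outputs,
                                    record_file=record_file,
                                    config_defination=config_file)
    # 4. User code: restore the checkpoint, freeze to PB, return the PB path.
    masked_pb = freeze_to_pb(test_graph, checkpoint)
    # 5. Export the sparse model; the inserted operator is removed here.
    amct.save_prune_retrain_model(pb_model=masked_pb, outputs=outputs,
                                  record_file=record_file, save_path=save_path)
    return masked_pb


class RecordingAMCT:
    """Hypothetical stand-in for the amct module; it only logs the calls."""

    def __init__(self):
        self.calls = []

    def create_prune_retrain_model(self, **kwargs):
        self.calls.append(("create_prune_retrain_model", kwargs["graph"]))

    def save_prune_retrain_model(self, **kwargs):
        self.calls.append(("save_prune_retrain_model", kwargs["pb_model"]))


stub = RecordingAMCT()
run_2_4_sparsity_flow(
    stub, "train_graph", "test_graph", ["logits"], "./retrain.cfg",
    "./tmp/record.txt", "./result/user_model",
    train_and_checkpoint=lambda graph: "./ckpt/model.ckpt",
    freeze_to_pb=lambda graph, ckpt: "./tmp/masked.pb")
for name, arg in stub.calls:
    print(name, arg)
```

Passing the module in as a parameter keeps the order of calls testable without a TensorFlow or AMCT installation.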
Examples
- Take the following steps to get started. Update the sample code based on your situation.
- Tweak the arguments passed to AMCT API calls as required. Retraining-based sparsity relies on your training results, so ensure that a TensorFlow training script that yields satisfactory training accuracy is available.
- Import the AMCT package and set the log level.
    import tensorflow as tf
    import amct_tensorflow as amct
    amct.set_logging_level(print_level='info', save_level='info')
- (Optional) Build a graph, read the trained parameters, and run inference on the graph in the TensorFlow environment to validate the inference script and environment setup. (Update the sample code based on your situation.)
This step is recommended as it guarantees a properly functioning source model for inference with acceptable accuracy. You can use a subset from the test dataset to improve the efficiency.
    user_test_evaluate_model(evaluate_model, test_data)
- Build a training graph. (Update the sample code based on your situation.)
    train_graph = user_load_train_graph()
- Call AMCT to perform training with the sparsity operator.
- Insert the sparsity operator into the training graph.
Construct a training graph (that is, set is_training to True for the BN), and then call the create_prune_retrain_model API (corresponding to 1 in Figure 1) to modify the not-yet-sparsified graph based on the sparsity configuration file (corresponding to 2 in Figure 1). The create_prune_retrain_model API inserts a 2:4 structured sparsity operator into the graph to achieve fake-sparsity for inference. For details about how to construct a sparsity configuration file, see Simplified QAT Configuration File.
    record_file = './tmp/record.txt'
    simple_cfg = './retrain.cfg'
    amct.create_prune_retrain_model(graph=train_graph,
                                    outputs=user_model_outputs,
                                    record_file=record_file,
                                    config_defination=simple_cfg)
- Implement gradient descent optimization on the modified graph and train the graph on the training dataset. (Update the sample code based on your situation.)
- Implement gradient descent optimization on the modified graph. Call the adaptive learning rate optimizer (for example, RMSPropOptimizer) to implement gradient descent optimization (perform this step after 1).
    optimizer = tf.compat.v1.train.RMSPropOptimizer(
        ARGS.learning_rate, momentum=ARGS.momentum)
    train_op = optimizer.minimize(loss)
- Train the model. Create a session to train the model, save the trained parameters as a checkpoint file, and obtain the result of the sparsity operator (corresponding to 3 and 4 in Figure 1).
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(outputs)
        # Save the trained parameters as a checkpoint file.
        saver_save.save(sess, retrain_ckpt, global_step=0)
- Build an inference graph. (Update the sample code based on your situation.)
    test_graph = user_load_test_graph()
- Call AMCT to implement 2:4 structured sparsity.
- Insert the sparsity operator into the inference graph. Call the create_prune_retrain_model API to modify the not-yet-sparsified inference graph (with is_training set to False for BN) based on the configuration file (corresponding to 5 in Figure 1) and generate a sparsity record file (corresponding to 6 in Figure 1).
    record_file = './tmp/record.txt'
    simple_cfg = './retrain.cfg'
    amct.create_prune_retrain_model(graph=test_graph,
                                    outputs=user_model_outputs,
                                    record_file=record_file,
                                    config_defination=simple_cfg)
- Restore the checkpoint weights trained in 4 and save them as a PB model. Create a session, restore the training parameters, and freeze the inference graph into a PB model (corresponding to 7 and 8 in Figure 1).
    variables_to_restore = tf.compat.v1.global_variables()
    saver_restore = tf.compat.v1.train.Saver(variables_to_restore)
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        # Restore training parameters.
        saver_restore.restore(sess, retrain_ckpt)
        # Save the model as a PB model. output.name[:-2] strips the ':0' tensor suffix.
        constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, test_graph.as_graph_def(), [output.name[:-2] for output in outputs])
        with tf.io.gfile.GFile(masked_pb_path, 'wb') as f:
            f.write(constant_graph.SerializeToString())
- Save the model and delete the inserted structured sparsity operator to implement sparsity. Export the .pb inference model that has undergone 2:4 structured sparsity based on the record file and the frozen model (corresponding to 6 and 9 in Figure 1).
    pruned_model_path = './result/user_model'
    amct.save_prune_retrain_model(pb_model=masked_pb_path,
                                  outputs=user_model_outputs,
                                  record_file=record_file,
                                  save_path=pruned_model_path)
- (Optional) Run inference on the sparse model user_model_pruned.pb in the TensorFlow environment based on the test dataset to test the accuracy. (Update the sample code based on your situation.) Check the accuracy loss of the sparse model by comparing its accuracy with that of the source model (see 2).
    pruned_model = './result/user_model_pruned.pb'
    user_do_inference(pruned_model, test_data)
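Besides checking accuracy, it can be useful to confirm that the exported weights really follow the 2:4 pattern. The following pure-Python sketch assumes you have already read the weights of one sparsified layer into a flat list, ordered along the axis on which the groups of four are formed; this section does not state which axis AMCT uses. `satisfies_2_4` and `zero_fraction` are hypothetical helper names, not AMCT APIs.

```python
def satisfies_2_4(values):
    """Return True if every group of four values has at most two nonzeros."""
    if len(values) % 4 != 0:
        raise ValueError("length must be a multiple of 4")
    for start in range(0, len(values), 4):
        nonzeros = sum(1 for v in values[start:start + 4] if v != 0)
        if nonzeros > 2:
            return False
    return True


def zero_fraction(values):
    """Return the fraction of entries that are exactly zero."""
    return sum(1 for v in values if v == 0) / len(values)


# Example weights for illustration only; not taken from a real model.
layer_weights = [0.9, 0.0, 0.0, -0.7, 0.0, 0.3, -0.4, 0.0]
print(satisfies_2_4(layer_weights), zero_fraction(layer_weights))
```

A layer that passes this check has a zero fraction of at least 0.5; a lower value suggests the layer was skipped or the weights were read along a different axis.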
