create_prune_retrain_model

Applicability

Supported features by product:

Atlas A3 training series products/Atlas A3 inference series products

  • Filter-level sparsity: √
  • 2:4 structured sparsity API: √

Atlas A2 training products/Atlas A2 inference products

  • Filter-level sparsity: √
  • 2:4 structured sparsity API: √

Atlas 200I/500 A2 inference product

  • Filter-level sparsity: √
  • 2:4 structured sparsity API: √

Atlas inference series products

  • Filter-level sparsity: √
  • 2:4 structured sparsity API: x

Atlas training products

  • Filter-level sparsity: √
  • 2:4 structured sparsity API: x

Note: For products marked with x, calling the API for the 2:4 structured sparsity feature reports no error, but yields no performance benefit.

Description

API for filter-level sparsity or 2:4 structured sparsity. Only one of the two sparsity features can be enabled at a time.

  • Filter-level sparsity: Sparsifies a graph based on the sparsity configuration file input by the user. This function inserts the filter-level sparsity mask operator into the graph to achieve fake-sparsity for inference.
  • 2:4 structured sparsity: Sparsifies a graph based on the sparsity configuration file input by the user. This function inserts the 2:4 structured sparsity operator into the graph to achieve fake-sparsity for inference.
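To make the two sparsity patterns concrete, here is a minimal pure-Python sketch of the masks involved. It assumes that "fake sparsity" means the weights stay dense and are multiplied by a 0/1 mask, that 2:4 sparsity keeps two of every four consecutive weights, and that filter-level sparsity zeroes whole filters. The selection criteria (largest magnitude, smallest L1 norm) and the helper names two_four_mask, filter_mask, and apply_mask are illustrative assumptions, not AMCT's actual operators.

```python
def two_four_mask(weights):
    """Return a 0/1 mask keeping the 2 largest-magnitude values in each group of 4."""
    if len(weights) % 4 != 0:
        raise ValueError("length must be a multiple of 4")
    mask = []
    for start in range(0, len(weights), 4):
        group = weights[start:start + 4]
        # Indices of the two largest magnitudes (ties: lower index wins).
        keep = sorted(range(4), key=lambda i: (-abs(group[i]), i))[:2]
        mask.extend(1 if i in keep else 0 for i in range(4))
    return mask


def filter_mask(filters, num_pruned):
    """Return one 0/1 flag per filter, zeroing the num_pruned smallest-L1-norm filters."""
    norms = [sum(abs(w) for w in f) for f in filters]
    order = sorted(range(len(filters)), key=lambda i: (norms[i], i))
    pruned = set(order[:num_pruned])
    return [0 if i in pruned else 1 for i in range(len(filters))]


def apply_mask(weights, mask):
    """Multiply weights by the mask; the weight list itself stays dense."""
    return [w * m for w, m in zip(weights, mask)]


row = [0.1, -0.9, 0.5, 0.0, 1.0, 2.0, -3.0, 0.5]
masked_row = apply_mask(row, two_four_mask(row))
```

In this sketch every group of four entries in masked_row has exactly two zeros, while a filter-level mask removes entire filters at once.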

Prototype

create_prune_retrain_model(graph, outputs, record_file, config_defination)

Parameters

  • graph (Input): tf.Graph of the model for sparsity. Data type: tf.Graph.
  • outputs (Input): Output of the user model. Data type: a list of strings, for example, [output1, output2, ...].
  • record_file (Input): Path (including the file name) of the file that records sparsity information. This file records the cascade correlations between filter-level sparsity nodes or 2:4 structured sparsity nodes. Data type: string.
  • config_defination (Input): Sparsity configuration file input by the user, which specifies the sparsity configuration of each layer in the tf.Graph. The simplified configuration file prune.cfg is generated based on the retrain_config_tf.proto file. The *.proto file is stored in /amct_tensorflow/proto/ under the AMCT installation directory. For details about the parameters in the *.proto file and the generated sparsity configuration file prune.cfg, see Simplified QAT Configuration File. Data type: string.
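The data types listed above can be checked before the call. The following is a minimal sketch only: check_prune_args is a hypothetical helper, not part of AMCT, and it verifies nothing beyond the types stated in the parameter descriptions (a list of strings for outputs, strings for the two paths). It does not inspect the graph argument, so it needs no TensorFlow installation.

```python
def check_prune_args(outputs, record_file, config_defination):
    """Raise TypeError or ValueError if the arguments do not match the documented types.

    Hypothetical helper for illustration; AMCT may perform its own checks.
    """
    # outputs: documented as a list of strings.
    if not isinstance(outputs, list):
        raise TypeError("outputs must be a list of strings")
    if not outputs:
        raise ValueError("outputs must not be empty")
    for name in outputs:
        if not isinstance(name, str):
            raise TypeError("every entry in outputs must be a string")

    # record_file and config_defination: documented as strings (file paths).
    for label, value in (("record_file", record_file),
                         ("config_defination", config_defination)):
        if not isinstance(value, str):
            raise TypeError(label + " must be a string")
        if not value:
            raise ValueError(label + " must not be empty")
```

A common mistake this catches is passing a single output name as a bare string instead of a one-element list.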

Returns

None

Example

amct.create_prune_retrain_model(graph, [operation_name_1, operation_name_2], './tmp/record.txt', './tmp/sample_prune.cfg')
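The call above can be wrapped so that the directory for the record file exists beforehand. This is a sketch under stated assumptions: run_prune_retrain is a hypothetical wrapper, the AMCT module is passed in as amct_module (presumably the module imported as amct in the example), and the directory creation assumes the API writes record_file rather than reading an existing one.

```python
import os


def run_prune_retrain(amct_module, graph, output_names, record_file, config_file):
    """Prepare the record file's directory, then call the API positionally."""
    record_dir = os.path.dirname(record_file)
    if record_dir:
        # Assumption: the API writes record_file, so its directory must exist.
        os.makedirs(record_dir, exist_ok=True)
    # Argument order follows the prototype:
    # graph, outputs, record_file, config_defination.
    return amct_module.create_prune_retrain_model(
        graph, output_names, record_file, config_file)
```

Passing the module as an argument keeps the wrapper importable on machines without AMCT installed; the return value is simply whatever the API returns, which is documented as None.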