decompose_network

Function Usage

Decomposes the input PyTorch model object based on the decomposition information file saved by auto_decomposition, and returns the decomposed model object together with the layer names before and after decomposition.

Constraints

  • The input model must be an object of the torch.nn.Module type.
  • This API can modify only the structure of convolutions constructed with torch.nn.Conv2d().
  • The structure of the input model must be consistent with the structure used when the decomposition information file was generated by calling auto_decomposition.

Prototype

model, changes = decompose_network(model, decompose_info_path)

Parameters

  • model (Input): PyTorch model object to be decomposed. When calling this API, you are advised to place the model on the CPU rather than the GPU, to avoid running out of GPU memory during decomposition. Type: torch.nn.Module.
  • decompose_info_path (Input): Path of the decomposition information file generated by calling auto_decomposition. Type: string.
  • model (Return): Model object whose structure has been changed to the structure after tensor decomposition. Type: torch.nn.Module.
  • changes (Return): Dictionary mapping each layer name before tensor decomposition to the list of layer names after decomposition, for example, {'conv1': ['conv1.0', 'conv1.1'], 'conv2': ['conv2.0', 'conv2.1'], ...}. Type: dict.

Returns

A tuple (model, changes): the model object after tensor decomposition, and a dictionary mapping the layer names before tensor decomposition to the corresponding layer names after decomposition.
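The changes dictionary maps each original layer name to the names of the layers that replace it. When the reverse lookup is needed, for example to trace a layer of the decomposed model back to the convolution it came from, a small helper can invert the mapping. This is a minimal sketch in plain Python: invert_changes is a hypothetical helper, not part of amct_pytorch, and the sample dictionary only mirrors the documented format.

```python
from typing import Dict, List


def invert_changes(changes: Dict[str, List[str]]) -> Dict[str, str]:
    """Map each layer name after decomposition back to its original layer name.

    Assumes `changes` has the documented shape:
    {original_name: [decomposed_name, ...], ...}.
    """
    inverted: Dict[str, str] = {}
    for original, decomposed in changes.items():
        for name in decomposed:
            inverted[name] = original
    return inverted


# Illustrative input shaped like the documented `changes` return value.
changes = {"conv1": ["conv1.0", "conv1.1"], "conv2": ["conv2.0", "conv2.1"]}
origin_of = invert_changes(changes)
print(origin_of["conv1.1"])  # conv1
```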

Outputs

None

Examples

from amct_pytorch.tensor_decompose import decompose_network

net = Net()  # Build the model object (Net stands for your own model class).
# Load the decomposition information file and modify the model structure accordingly.
net, changes = decompose_network(
    model=net,
    decompose_info_path="decomposed_path/decompose_info.json",  # Path of the decomposition information file saved by auto_decomposition.
)
  1. If training is involved, call this API before the model parameters are passed to the optimizer. If DistributedDataParallel (DDP) is used, call this API before the model is wrapped with DDP.
  2. This API modifies the input model object in place; that is, the model object passed in by the user is changed after decomposition. Exception: if the input model is itself a torch.nn.Conv2d object, this API does not modify it, and the returned decomposed model is a newly constructed torch.nn.Module object.
  3. This API only modifies the model structure and does not update the convolution weights after decomposition; the weights keep the default initial values of torch.nn.Conv2d(). If fine-tuning is required, save the weights of the decomposed model after calling auto_decomposition, load them after calling this API, and then fine-tune the model.
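Before loading saved weights into the decomposed model (note 3), it can help to confirm that the saved weights cover every decomposed layer. The sketch below assumes the weights were saved as a PyTorch state_dict, in which a torch.nn.Conv2d layer named X is stored under the key 'X.weight' (plus 'X.bias' when bias is enabled), and that each decomposed layer listed in changes is such a convolution. find_missing_weights is a hypothetical helper, not part of amct_pytorch; it works on plain key names, so in practice you would pass it the keys of the state_dict loaded with torch.load.

```python
from typing import Dict, Iterable, List


def find_missing_weights(changes: Dict[str, List[str]],
                         state_dict_keys: Iterable[str]) -> List[str]:
    """Return decomposed layer names that have no '<name>.weight' entry.

    Assumes each decomposed layer stores its weight under
    '<layer name>.weight' in the saved state_dict.
    """
    keys = set(state_dict_keys)
    missing = []
    for decomposed in changes.values():
        for name in decomposed:
            if f"{name}.weight" not in keys:
                missing.append(name)
    return missing


# Illustrative data: conv2.1 has no saved weight in this example.
changes = {"conv1": ["conv1.0", "conv1.1"], "conv2": ["conv2.0", "conv2.1"]}
saved_keys = ["conv1.0.weight", "conv1.1.weight", "conv1.1.bias", "conv2.0.weight"]
print(find_missing_weights(changes, saved_keys))  # ['conv2.1']
```

An empty result suggests the saved weights match the decomposed structure; a non-empty result usually means the weights were saved from a model whose structure differs from the one passed to decompose_network.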