auto_decomposition

Function Usage

(Optional) Performs tensor decomposition on the input PyTorch model object, returns the decomposed model object together with the names of the layers before and after decomposition, and saves the decomposition information file if a path is given.

Constraints

  • The input model must be an object of the torch.nn.Module type.
  • This API decomposes only convolutions constructed with torch.nn.Conv2d.
  • This API automatically decomposes the convolutional layers that meet the decomposition conditions. For details about the conditions, see Restrictions.
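To see which layers are eligible at all, the sketch below lists every torch.nn.Conv2d layer in a model by name. These names should be a superset of the keys that can appear in the returned changes dictionary. It is a minimal sketch: the helper name conv2d_candidates is hypothetical, it assumes Conv2d subclasses count as Conv2d (isinstance matching, which may differ from AMCT's own check), and it does not evaluate the further decomposition conditions, so a listed layer is a candidate, not a guarantee.

```python
import torch.nn as nn


def conv2d_candidates(model: nn.Module) -> list[str]:
    # Hypothetical helper: collect the names of all torch.nn.Conv2d layers.
    # Assumption: subclasses of Conv2d are treated like Conv2d (isinstance).
    # Whether AMCT actually decomposes each layer depends on conditions not checked here.
    return [name for name, m in model.named_modules() if isinstance(m, nn.Conv2d)]


net = nn.Sequential(
    nn.Conv2d(3, 16, 3),   # candidate, named '0'
    nn.ReLU(),
    nn.Conv2d(16, 32, 3),  # candidate, named '2'
    nn.BatchNorm2d(32),    # not a convolution, ignored
)
print(conv2d_candidates(net))  # ['0', '2']
```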

Prototype

model, changes = auto_decomposition(model, decompose_info_path=None)

Parameters

  • model (input; torch.nn.Module): PyTorch model object, with pre-trained weights, to be decomposed. You are advised to place the model on the CPU rather than the GPU before calling this API, because decomposition on the GPU can run out of GPU memory.
  • decompose_info_path (input, optional; str): Path of the decomposition information file. The file is written in JSON format, so a .json file name extension is recommended. Defaults to None, in which case no decomposition information file is saved.
  • model (return; torch.nn.Module): Model object after tensor decomposition.
  • changes (return; dict): Dictionary mapping each decomposed layer name to the names of the layers that replace it, for example, {'conv1': ['conv1.0', 'conv1.1'], 'conv2': ['conv2.0', 'conv2.1'], ...}.

Returns

A tuple (model, changes): the model object after tensor decomposition, and a dictionary mapping the layer names before decomposition to the layer names after decomposition.
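Because changes maps each original layer to its replacement layers, inverting it is a quick way to trace a post-decomposition layer, or a parameter key in the new state_dict, back to the layer it came from. The sketch below uses only the standard library and the example dictionary given for changes above. The helper names invert_changes and original_layer are hypothetical, and the sketch assumes parameter keys follow PyTorch's usual "<module name>.<parameter name>" form.

```python
from typing import Dict, List, Optional


def invert_changes(changes: Dict[str, List[str]]) -> Dict[str, str]:
    # Map every new layer name back to the original layer it replaced.
    return {new: old for old, news in changes.items() for new in news}


def original_layer(param_name: str, origin: Dict[str, str]) -> Optional[str]:
    # "conv1.1.weight" -> module "conv1.1" -> original layer "conv1".
    # Returns None for parameters of layers that were not decomposed.
    module_name = param_name.rsplit(".", 1)[0]
    return origin.get(module_name)


changes = {"conv1": ["conv1.0", "conv1.1"], "conv2": ["conv2.0", "conv2.1"]}
origin = invert_changes(changes)
print(sorted(changes))                           # ['conv1', 'conv2']
print(original_layer("conv2.1.weight", origin))  # conv2
print(original_layer("fc.weight", origin))       # None
```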

Outputs

If decompose_info_path is not None, the decomposition information file is saved to that path in JSON format.

Example

import torch
from amct_pytorch.tensor_decompose import auto_decomposition

net = Net()                                                  # (*) Build a model object (user-defined model class).
net.load_state_dict(torch.load("src_path/weights.pth", map_location="cpu"))  # (*) Load model weights onto the CPU.
net, changes = auto_decomposition(                           # Perform tensor decomposition on net.
    model=net,
    decompose_info_path="decomposed_path/decompose_info.json"
)

  1. If training is involved, call this API before the model parameters are passed to the optimizer. If DDP (DistributedDataParallel) is used, call this API before the model is wrapped with DDP.
  2. This API modifies the input model object in place; that is, the model object you pass in is changed by decomposition. Exception: if the input model is itself a torch.nn.Conv2d object, this API does not modify it; if decomposition occurs, a newly constructed torch.nn.Module object is returned instead. Therefore, always use the returned model object.
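The ordering rule in note 1 exists because a PyTorch optimizer keeps references to the parameter tensors it receives at construction time; modules swapped into the model afterwards are invisible to it. The sketch below demonstrates this with plain PyTorch. fake_decompose is a hypothetical stand-in for auto_decomposition that merely replaces one Conv2d with two chained convolutions (mimicking the conv to conv.0/conv.1 renaming); it is not AMCT's decomposition algorithm. The sketch also rebinds the returned model, as note 2 recommends.

```python
import torch
import torch.nn as nn


def fake_decompose(model: nn.Module) -> nn.Module:
    # Hypothetical stand-in (NOT the AMCT algorithm): swap model.conv for a
    # 1x1 conv followed by a conv with the original kernel, named conv.0 / conv.1.
    old = model.conv
    model.conv = nn.Sequential(
        nn.Conv2d(old.in_channels, 4, kernel_size=1, bias=False),
        nn.Conv2d(4, old.out_channels, kernel_size=old.kernel_size,
                  padding=old.padding, bias=old.bias is not None),
    )
    return model


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


def optimizer_param_ids(opt: torch.optim.Optimizer) -> set:
    # Identities of every tensor the optimizer will update.
    return {id(p) for group in opt.param_groups for p in group["params"]}


# Wrong order: the optimizer still holds the original (removed) conv parameters.
wrong = TinyNet()
opt_wrong = torch.optim.SGD(wrong.parameters(), lr=0.1)
wrong = fake_decompose(wrong)
stale = not any(id(p) in optimizer_param_ids(opt_wrong) for p in wrong.parameters())

# Right order: decompose first, then hand the new parameters to the optimizer.
right = fake_decompose(TinyNet())
opt_right = torch.optim.SGD(right.parameters(), lr=0.1)
tracked = all(id(p) in optimizer_param_ids(opt_right) for p in right.parameters())

print(stale, tracked)  # True True
```

With the real API the same order applies: decompose (on the CPU), move the returned model to the target device, wrap it with DDP if needed, and only then construct the optimizer from its parameters.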