Procedure

This section describes the API call sequence for tensor decomposition and provides usage examples.

API Call Sequence

Figure 1 shows the API call sequence. For the decomposition example, see Sample List.

Figure 1 API call sequence for tensor decomposition
  1. Online decomposition process

    Prepare a torch.nn.Module model object that contains pre-trained weights. Before passing the model parameters to the optimizer, pass the model object to the auto_decomposition API for tensor decomposition. Then, directly fine-tune the decomposed model object.

  2. Offline decomposition process
    1. In any script, prepare a torch.nn.Module model object that contains pre-trained weights, and pass the model object and the save path of the decomposition information file to the auto_decomposition API. The API decomposes the model object and generates the decomposition information file. Then save the decomposed weights.
    2. During fine-tuning, in the training script, before passing the model parameters to the optimizer, pass the model object and the path of the decomposition information file obtained in 2.a to the decompose_network API. The API modifies the model structure based on that file. Then load the decomposed weights saved in 2.a for fine-tuning.

    In offline decomposition, you must save the decomposed weights yourself after the auto_decomposition call in 2.a, and load them yourself after the decompose_network call in 2.b.

    This design leaves reading and writing of the weight file under your control, for example, so that you can store custom information in the weight file.
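
Both flows call the decomposition API before the model parameters are passed to the optimizer. A likely reason, stated here as an assumption rather than something this section confirms, is that an optimizer keeps references to the parameter objects it receives at construction, while decomposition replaces a convolution with new layers that have new parameters. The dependency-free sketch below illustrates that ordering effect. ToyParam, ToyModel, ToyOptimizer, and toy_decompose are hypothetical stand-ins and are not PyTorch or AMCT APIs.

```python
# Toy stand-ins (hypothetical): they only mimic the fact that an optimizer
# stores references to the parameter objects it was given at construction.

class ToyParam:
    def __init__(self, value):
        self.value = value


class ToyModel:
    def __init__(self):
        # One "convolution" layer holding a single parameter.
        self.layers = {"conv": [ToyParam(1.0)]}

    def parameters(self):
        return [p for params in self.layers.values() for p in params]


class ToyOptimizer:
    def __init__(self, params):
        self.params = list(params)  # references are captured once, here

    def step(self):
        for p in self.params:
            p.value -= 0.5  # stand-in for a gradient update


def toy_decompose(model):
    """Replace the 'conv' layer with two new layers (new parameter objects)."""
    del model.layers["conv"]
    model.layers["conv_first"] = [ToyParam(1.0)]
    model.layers["conv_second"] = [ToyParam(1.0)]
    return model


# Wrong order: the optimizer is built before decomposition,
# so it still points at the parameter of the removed layer.
stale_model = ToyModel()
stale_optimizer = ToyOptimizer(stale_model.parameters())
toy_decompose(stale_model)
stale_optimizer.step()
stale_values = [p.value for p in stale_model.parameters()]

# Order used in this section: decompose first, then build the optimizer.
model = toy_decompose(ToyModel())
optimizer = ToyOptimizer(model.parameters())
optimizer.step()
updated_values = [p.value for p in model.parameters()]

print("optimizer built before decomposition:", stale_values)
print("optimizer built after decomposition: ", updated_values)
```

In the first case the parameters of the decomposed model never change, because the optimizer only updates the object it captured earlier. In the second case both new layers are updated.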

After tensor decomposition, one convolution is decomposed into two cascaded convolutions. Figure 2 shows the convolutions before and after decomposition.

Figure 2 Diagrams before and after convolution decomposition
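
To make the idea of splitting one convolution into two consecutive ones concrete, the sketch below uses the simplest case: a 1x1 convolution applied to a single pixel, which is a matrix-vector product over channels. It is only an illustration of low-rank factorization in general and does not reproduce the algorithm used by auto_decomposition. The channel counts, rank, and weight values are made up for the example.

```python
# Illustrative only: NOT the algorithm used by auto_decomposition.
# A 1x1 convolution on one pixel is a matrix-vector product, so a low-rank
# weight matrix (c_out x c_in) can be split into two smaller matrices.

def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    cols_b = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols_b] for row in a]


def param_count(matrix):
    """Number of weights in a matrix."""
    return sum(len(row) for row in matrix)


# Hypothetical factors with c_in = 6, rank = 2, c_out = 6.
first = [[1, 0, 2, -1, 0, 1],
         [0, 1, 1, 3, -2, 0]]      # first convolution: 6 channels -> 2
second = [[1, 2],
          [0, 1],
          [3, -1],
          [2, 0],
          [-1, 1],
          [1, 1]]                  # second convolution: 2 channels -> 6

# The single convolution that the two factors are equivalent to.
original = matmul(second, first)   # shape 6 x 6, rank 2

pixel = [1, 2, 3, 4, 5, 6]         # one input pixel with 6 channels
out_original = matvec(original, pixel)
out_decomposed = matvec(second, matvec(first, pixel))

print("same output:", out_original == out_decomposed)
print("weights before:", param_count(original))
print("weights after: ", param_count(first) + param_count(second))
```

Here the weight matrix is exactly rank 2, so the two-step result matches the original. For real pre-trained weights a factorization is generally only approximate, which is consistent with the fine-tuning step that follows decomposition in this section.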

Examples

In the following examples, (*) marks code that already exists in your script, and an ellipsis (...) marks omitted code. The code is for reference only. Adapt it to your actual training script.

  1. Online tensor decomposition
    In the training script, call auto_decomposition to decompose the PyTorch model with pre-trained weights, and then fine-tune the decomposed model.
    import torch
    from amct_pytorch.tensor_decompose import auto_decomposition
    net = Net()                                           # (*) Build a model object.
    net.load_state_dict(torch.load("src_path/net.pth"))   # (*) Load model weights.
    net, changes = auto_decomposition(model=net)          # Perform tensor decomposition.
    optimizer = build_optimizer(net, ...)                 # (*) Build an optimizer (pass model parameters to the optimizer).
    train(net, optimizer, ...)                            # (*) Fine-tune.
    
  2. Offline tensor decomposition
    1. In any script, call auto_decomposition to decompose the PyTorch model with pre-trained weights and save the decomposition information file and the decomposed weights.
      import torch
      from amct_pytorch.tensor_decompose import auto_decomposition
      net = Net()                                                                 # (*) Build a model object.
      net.load_state_dict(torch.load("src_path/weights.pth"))                     # (*) Load model weights.
      net, changes = auto_decomposition(                                          # Perform tensor decomposition and save the decomposition information file.
          model=net,
          decompose_info_path="decomposed_path/decompose_info.json"               # Path for storing the decomposition information file.
      )
      torch.save(net.state_dict(), "decomposed_path/decomposed_weights.pth")      # Save the decomposed model weights.
      
    2. In the training script, call decompose_network to modify the model structure as described in the decomposition information file obtained in 2.a, and load the decomposed weights saved in 2.a for fine-tuning.
      import torch
      from amct_pytorch.tensor_decompose import decompose_network
      net = Net()                                                                 # (*) Build a model object.
      net, changes = decompose_network(                                           # Load the decomposition information file to modify the model structure.
          model=net,
          decompose_info_path="decomposed_path/decompose_info.json"               # Path of the decomposition information file saved in the previous step.
      )
      net.load_state_dict(torch.load("decomposed_path/decomposed_weights.pth"))   # Load the decomposed model weights saved in the previous step.
      optimizer = build_optimizer(net, ...)                                       # (*) Build an optimizer (pass model parameters to the optimizer).
      train(net, optimizer, ...)                                                  # (*) Fine-tune.