产品 |
是否支持 |
---|---|
√ |
|
√ |
|
√ |
|
√ |
对用户输入的PyTorch模型对象进行张量分解,得到分解后的模型对象和分解前后层的对应名称,并保存分解信息文件(可选)。
1 | model, changes = auto_decomposition(model, decompose_info_path=None) |
参数名 |
输入/输出 |
使用限制 |
---|---|---|
model |
输入 |
含义:待分解的含有预训练权重的PyTorch模型对象。在调用该接口时建议将模型放置于CPU而不是GPU上,以防分解时显存不足。 数据类型:torch.nn.Module |
decompose_info_path |
输入 |
含义:分解信息文件保存路径。将以json格式存储,因此建议使用.json扩展名。为None时不保存分解信息文件(默认)。 数据类型:string 默认值:None |
1 2 3 4 5 6 7 | from amct_pytorch.tensor_decompose import auto_decomposition net = Net() # 构建用户模型对象 net.load_state_dict(torch.load("src_path/weights.pth")) # 加载模型权重 net, changes = auto_decomposition( # 执行张量分解 model=net, decompose_info_path="decomposed_path/decompose_info.json" ) |