昇腾社区首页
中文
注册

mindietorch.compile

函数功能

将原始TorchScript模型/ExportedProgram进行编译优化生成可在昇腾处理上加速推理的TorchScript模型/nn.Module。

函数原型

def compile(module: Any, input sprecision_policy = _enums.PrecisionPolicy.PREF_FP32, truncate_long_and_double = True, require_full_compilation = False, allow_tensor_replace_int = False, min_block_size = 3, torch_executed_ops = [], soc_version = "Ascend310P3", optimization_level = 0) -> 
torch.nn.Module | torch.jit.ScriptModule

参数说明

参数名称

参数类型

参数说明

是否必选

module

torch.jit.ScriptModule 或 torch.nn.Module 或 torch.export.ExportedProgram

编译优化前的PyTorch模型。

默认值:无

inputs

List[torch.Tensor] 或 List[mindietorch.Input]或 List[[mindietorch.Input]]

模型输入。

默认值:无

precision_policy

Enum

设置模型的推理精度策略,支持混合精度PREF_FP32、FP16以及FP32精度。

默认值:PREF_FP32

truncate_long_and_double

Bool

是否允许long和double类型转换。

默认值:True

require_full_compilation

Bool

是否整图编译,仅在编译TorchScript模型时生效。

默认值:False

allow_tensor_replace_int

Bool

是否允许采用Tensor代替Int,,仅在编译TorchScript模型时生效。

默认值:False

min_block_size

Int

切分子图的最少节点数量,,仅在编译TorchScript模型时生效。

默认值:3

torch_executed_ops

List[String]

强制在torch上执行的算子,如:["aten::add"],仅在编译TorchScript模型时生效。

默认值:无

soc_version

String

芯片型号。

默认值:Ascend310P3

optimization_level

Int

模型优化等级,取值:0表示不优化,1表示图优化,2表示算子优化,仅在编译TorchScript模型时生效。

默认值:0

ir

Str

模型编译方式,取值:torchscript 或 ts表示编译TorchScript模型(此时仅支持module传入torch.jit.ScriptModule),dynamo表示编译ExportedProgram(此时支持module传入torch.nn.Module 或 torch.export.ExportedProgram)。

默认值:default,代表编译TorchScript模型。

注:参数相关约束请参见compile

返回值说明

编译优化后的torchscript模型或nn.Module。

mindietorch.compile接口由于存在参数校验,在输入非法数据时,可能会抛出异常。故用户必须在try/except语句块内进行调用以及异常处理,防止在使用的过程中出现异常抛出导致程序退出的情况。