mindietorch.compile

函数功能

函数原型

def compile(module: Any, ir="default", inputs = None, **kwargs)

约束说明

mindietorch.compile接口由于存在参数校验,在输入非法数据时,可能会抛出异常。故用户必须在try/except语句块内进行调用以及异常处理,防止在使用的过程中出现异常抛出导致程序退出的情况。

参数说明

参数名称

参数类型

参数说明

是否必选

module

torch.jit.ScriptModule 或 torch.nn.Module 或 torch.export.ExportedProgram

编译优化前的PyTorch模型。

默认值:无

inputs

List[torch.Tensor] 或 List[mindietorch.Input]或 List[[mindietorch.Input]]

模型输入。

默认值:None

precision_policy

PrecisionPolicy(Enum)

设置模型的推理精度策略,支持FP16精度、FP32精度以及混合精度PREF_FP32、PREF_FP16。

默认值:PREF_FP32

truncate_long_and_double

Bool

是否允许long和double类型转换。

默认值:True

require_full_compilation

Bool

是否强制要求整图编译,若模型中存在不支持的算子,开启此项时会在编译模型时抛出异常并提示用户无法编译整图。

默认值:False

allow_tensor_replace_int

Bool

是否允许采用Tensor代替Int。

默认值:False

min_block_size

Int

切分子图的最少节点数量。

默认值:1

torch_executed_ops

List[String]

强制fallback执行的算子。

以add算子为例:

  • TorchScript路线:以libtorch风格的字符串指定算子,即torch_executed_ops=["aten::add"];
  • torch.export或者torch.compile路线:以PyTorch中的OpOverload对象指定算子,即torch_executed_ops=[torch.ops.aten.add.default]。

默认值:[]

soc_version

String

芯片型号。

默认值:Ascend310P3

optimization_level

Int

模型优化等级,取值如下:

  • 0:表示不优化;
  • 1:表示图优化;
  • 2:表示算子优化。

默认值:0

ir

String

模型编译方式,取值如下:

  • torchscript或者ts:表示编译TorchScript路线模型;此时仅支持module传入torch.jit.ScriptModule;
  • dynamo:表示编译torch.export路线模型,此时支持module传入torch.nn.Module或torch.export.ExportedProgram;
  • torch_compile:表示编译torch.compile路线模型。

默认值:default

根据module的类型进行匹配,优先编译TorchScript路线模型,其次编译torch.export路线模型,最后才是编译torch.compile路线模型。

enable_dynamic_torch_compile

Bool

是否让torch.compile捕获动态的模型,仅在编译torch.compile路线模型时生效。

默认值:False

注:参数相关约束请参见compile

返回值说明

编译优化后的TorchScript模型或nn.Module。

mindietorch.compile接口由于存在参数校验,在输入非法数据时,可能会抛出异常。故用户必须在try/except语句块内进行调用以及异常处理,防止在使用的过程中出现异常抛出导致程序退出的情况。