mindietorch.compile

函数功能

函数原型

def compile(module: Any, ir="default", inputs=None, **kwargs)

约束说明

mindietorch.compile接口由于存在参数校验,在输入非法数据时,可能会抛出异常。故用户必须在try/except语句块内进行调用以及异常处理,防止在使用的过程中出现异常抛出导致程序退出的情况。

参数说明

参数名称

参数类型

参数说明

是否必选

module

torch.jit.ScriptModule或torch.nn.Module或torch.fx.GraphModule或torch.export.ExportedProgram

编译优化前的PyTorch模型。当ir设置为不同值时,module传入的类型限制不同,具体请参见ir字段的参数说明。

默认值:无

inputs

List[torch.Tensor]或List[List[torch.Tensor]]或List[mindietorch.Input]或List[List[mindietorch.Input]]

模型输入。

  • 若传入的是List[mindietorch.Input]类型数据,则表示静态或者动态ShapeRange输入;
  • 若传入的是List[List[mindietorch.Input]]类型数据,则表示动态分档输入;
  • 若传入的是List[torch.Tensor]或List[List[torch.Tensor]]类型数据,则会在内部转换为表示静态输入的List[mindietorch.Input]类型数据或表示动态分档输入的List[List[mindietorch.Input]]类型数据。
说明:
  • 当ir参数设置为torch_compile时,inputs不需要传入,其他情况下则必须传入合法类型的数据;
  • 当ir参数设置为dynamo时,inputs不支持传入,表示动态分档的List[List[mindietorch.Input]]类型数据。

默认值:None

precision_policy

mindietorch.PrecisionPolicy

设置模型的推理精度策略,支持FP16精度、FP32精度以及混合精度PREF_FP32、PREF_FP16。

默认值:PREF_FP32

truncate_long_and_double

bool

是否允许long和double类型转换。

默认值:True

require_full_compilation

bool

是否强制要求整图编译,若模型中存在不支持的算子,开启此项时会在编译模型时抛出异常并提示用户无法编译整图。

默认值:False

allow_tensor_replace_int

bool

是否允许采用Tensor代替Int。

默认值:False

min_block_size

int

切分子图的最少节点数量,取值范围[0,1024]。

默认值:1

default_buffer_size_vec

List[Int]

模型输出在编译阶段无法确定shape时的默认分配内存大小,一般用于动态模型,取值范围(0,36864],单位为MB。

支持List的长度等于1或者输出的个数,若等于1则所有输出的默认内存大小均为该值,若等于输出个数则为每个输出单独设置默认内存大小。

默认值:[500, ]

torch_executed_ops

List[str] 或 List[torch._ops.OpOverload]

强制fallback执行的算子。

以add算子为例:

  • TorchScript路线:以libtorch风格的字符串指定算子,即torch_executed_ops=["aten::add"];
  • torch.export或者torch.compile路线:以PyTorch中的OpOverload对象指定算子,即torch_executed_ops=[torch.ops.aten.add.default]。

默认值:[]

soc_version

str

芯片型号。

默认值:Ascend310xxx

optimization_level

int

模型优化等级,取值如下:

  • 0:表示不优化;
  • 1:表示图优化;
  • 2:表示算子优化。

默认值:0

ir

str

模型编译方式,取值如下:

  • torchscript或者ts:表示编译TorchScript路线模型;此时仅支持module传入torch.jit.ScriptModule;
  • dynamo:表示编译torch.export路线模型,此时支持module传入torch.nn.Module或torch.export.ExportedProgram;
  • torch_compile:表示编译torch.compile路线模型。

默认值:default

根据module的类型进行匹配,优先编译TorchScript路线模型,其次编译torch.export路线模型,最后才是编译torch.compile路线模型。

enable_dynamic_torch_compile

bool

是否让torch.compile捕获动态的模型,仅在编译torch.compile路线模型时生效。

默认值:False

注:参数相关约束请参见compile

返回值说明

编译优化后的TorchScript模型或nn.Module。

mindietorch.compile接口由于存在参数校验,在输入非法数据时,可能会抛出异常。故用户必须在try/except语句块内进行调用以及异常处理,防止在使用的过程中出现异常抛出导致程序退出的情况。