TorchAir图模式配置示例如下,仅供参考,请根据实际情况修改自定义的脚本。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | # 必须先导torch_npu再导torchair import torch import torch_npu import torchair # Patch方式实现集合通信入图(可选) from torchair import patch_for_hcom patch_for_hcom() # 定义模型Model class Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, y): return torch.add(x, y) # 实例化模型model model = Model().npu() # 获取TorchAir提供的默认npu backend,自行配置config功能 config = torchair.CompilerConfig() npu_backend = torchair.get_npu_backend(compiler_config=config) # 使用npu backend进行compile opt_model = torch.compile(model, backend=npu_backend) # 使用编译后的model去执行 x = torch.randn(2, 2).npu() y = torch.randn(2, 2).npu() opt_model(x, y) |
torch.compile为PyTorch原生接口,官网介绍参见LINK,接口原型如下:
1 | torch.compile(model=None, *, fullgraph=False, dynamic=None, backend='inductor', mode=None, options=None, disable=False) |
TorchAir通过torchair.get_npu_backend接口获取NPU图编译后端npu_backend,将其作为backend入参实现昇腾NPU图模式计算,此时torch.compile参数配置如表1所示。
参数名 |
参数说明 |
备注 |
---|---|---|
model |
入图部分的模型或者函数,必选参数。 |
- |
fullgraph |
bool类型,可选参数。是否捕获整图进行优化。
|
参数含义与原生PyTorch compile接口一致,单击LINK获取官网介绍。 |
dynamic |
bool类型或None,可选参数。是否使用动态图Trace。
|
|
backend |
后端选择,缺省值为"inductor",目前昇腾NPU暂不支持。 昇腾NPU成图只有一种后端,通过torchair.get_npu_backend接口获取,必选参数。 |
通过图编译后端compiler_config参数配置图模式功能,配置项参见compiler_config功能。 |
mode |
开销模式,内存开销模式选择,昇腾NPU暂不支持。缺省值为None。 |
- |
options |
优化选项,昇腾NPU暂不支持。缺省值为None。 |
- |
disable |
bool类型,可选参数。是否关闭torch.compile能力。
|
参数含义与原生PyTorch compile接口一致,单击LINK获取官网介绍。 |
TorchAir提供的NPU图编译后端npu_backend,支持如下功能配置,详细介绍参见CompilerConfig类。
成员名 |
功能说明 |
---|---|
debug |
配置debug调试类功能,配置形式为config.debug.xxx,包括如下功能: |
export |
配置离线导图相关功能,配置形式为config.export.xxx,具体介绍参见Dynamo导图功能。 |
dump_config |
配置图模式下数据dump功能,配置形式为config.dump_config.xxx,具体参见算子输入输出dump功能(图模式)。 |
fusion_config |
配置图融合相关功能,配置形式为config.fusion_config.xxx,具体参见算子融合规则配置功能。 |
experimental_config |
配置各种试验功能,配置形式为experimental_config.xxx,包括如下功能: |
inference_config |
配置推理场景相关功能,配置形式为config.inference_config.xxx,如动态shape图分档执行功能。 |
ge_config |
配置GE图相关功能,配置形式为config.ge_config.xxx,包括如下功能: |
mode |
配置相关的调度模式配置,配置形式为config.mode.xxx,如reduce-overhead执行模式(aclgraph)。 |