图内标定SuperKernel范围
功能简介
SuperKernel是一种算子二进制融合技术,与源码融合不同,它聚焦于内核函数(Kernel)的二进制调度方案优化,在已编译的二进制代码基础上融合创建一个超级Kernel函数(简称SuperKernel),以调用子函数的方式调用多个其它内核函数,达到优化计算任务、提升性能和资源利用率目的。
相对于单算子下发,SuperKernel技术可以优化任务调度的等待时间和调度开销,同时利用task间隙资源进一步优化算子头开销。
实现SuperKernel的原理如下:
- 通过SuperKernel融合策略识别可被融合的子图。
- 将子图内的算子按SuperKernel融合规则合并为一个大Kernel,在新Kernel内通过生成一段子Kernel调用代码将子图上所有Kernel入口函数完成一次调用,并基于图的依赖完成同步插入。
TorchAir提供标定SuperKernel范围的能力,支持用户根据实际业务需求对融合范围内的算子进行标记和优化配置。
使用约束
- 该功能仅适用于静态图场景。
- 该功能仅适用于
Atlas A3 训练系列产品/Atlas A3 推理系列产品 。 - 注意,SuperKernel融合会按网络中算子顺序依次识别能否被融合,若识别到不可融合的算子,则生成第一段SuperKernel,同时自动跳过该算子进行第二段SuperKernel融合。
使用方法
- 用户自行分析模型脚本中可被融合的算子。
- 标定SuperKernel范围。
使用如下with语句块(super_kernel),语句块内算子均被融合为一个超级Kernel进行计算。
1
with torchair.scope.super_kernel(scope: str, options: str = ''):
- scope:表示上下文算子被融合的SuperKernel名,相同的scope代表相同的范围,由用户控制。
- options:表示融合SuperKernel的编译选项,缺省情况下,系统编译模式采用所有编译选项(参见表1)的默认值。
同时支持用户自定义组合编译选项,配置格式形如"<option1>=<value1>:<option2>=<value2>:<option3>=......",多个选项时用英文冒号分割。
表1 编译选项说明 选项
功能说明
stream-fusion
SuperKernel内的多流配置,支持如下取值:
- 0(缺省值):表示SuperKernel内算子最多支持配置两条流。
- 1:表示SuperKernel内算子可设置的流数≥1。
本场景下标定范围内的算子资源共用,即不同流上的Cube和Vector算子并行执行,提高了运行效率。
说明:MicroBatch场景下,stream与核绑定且资源完全独立,建议用户使用图内多流表达功能配置多流,此时不推荐同时配置stream-fusion=1。
使用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | # 导入TorchAir框架 import torch import numpy as np import torch_npu import torchair from torchair.configs.compiler_config import CompilerConfig if __name__ == "__main__": config = CompilerConfig() config.debug.graph_dump.type = "pbtxt" npu_backend = torchair.get_npu_backend(compiler_config=config) # 定义模型model class ModelOrigin(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x1, x2, scale, offset, bias, pertoken_scale, weight_scale): quant_matmul_res_origin = torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias, pertoken_scale=pertoken_scale, output_dtype=torch.bfloat16) swiglu_res_origin = torch_npu.npu_dequant_swiglu_quant(quant_matmul_res_origin, weight_scale=weight_scale) return swiglu_res_origin, quant_matmul_res_origin class ModelSuperKernel(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x1, x2, scale, offset, bias, pertoken_scale, weight_scale): # 将 npu_quant_matmul和npu_dequant_swiglu_quant融合为superKernel,标记为sp1 with torchair.scope.super_kernel("sp1",""): quant_matmul_res_origin = torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias, pertoken_scale=pertoken_scale, output_dtype=torch.bfloat16) swiglu_res_origin = torch_npu.npu_dequant_swiglu_quant(quant_matmul_res_origin, weight_scale=weight_scale) return swiglu_res_origin, quant_matmul_res_origin m = 864 k = 7168 n = 4096 bias_flag = False cpu_x1 = torch.randint(-10, 10, (m, k), dtype=torch.int8) cpu_x2 = torch.randint(-10, 10, (n, k), dtype=torch.int8) cpu_x2 = torch_npu.npu_format_cast(cpu_x2.npu().transpose(1,0).contiguous(), 29) scale = torch.randn((n,), dtype=torch.float32) # print("scale:", scale) pertoken_scale = torch.randn((m,), dtype=torch.float32) # print("pertoken_scale:", pertoken_scale) bias = torch.randint(-1,1, (n,), dtype=torch.bfloat16) weight_scale = torch.randn((m, n), dtype=torch.float32).npu() # 使用torchair图模式后端编译模型 model_no_sk = torch.compile(ModelOrigin(), backend=npu_backend, dynamic=False) print("-------------------- run no sk -----------------------------------") swiglu_res_origin,quant_matmul_res_origin = model_no_sk(cpu_x1.npu(), cpu_x2, scale.npu(), None, None, pertoken_scale.npu(), weight_scale) model_sk = torch.compile(ModelSuperKernel(), backend=npu_backend, dynamic=False) print("-------------------- run sk -----------------------------------") swiglu_res_sk,quant_matmul_res_sk = model_sk(cpu_x1.npu(), cpu_x2, scale.npu(), None, None, pertoken_scale.npu(), weight_scale) res = np.array_equal(swiglu_res_origin[0].cpu().numpy(), swiglu_res_sk[0].cpu().numpy()) res = res and np.array_equal(swiglu_res_origin[1].cpu().numpy(), swiglu_res_sk[1].cpu().numpy()) if res: print("Precision ====== Success!!!") else: print("Precision ====== Failed.") |
父主题: max-autotune模式功能