cache_compile
功能说明
开启模型编译缓存功能时需要调用该接口实现模型编译缓存。
函数原型
def cache_compile(func, *, config: Optional[CompilerConfig] = None, backend: Optional[Any] = None, dynamic: bool = True, cache_dir: Optional[str] = None, global_rank: Optional[int] = None, tp_rank: Optional[int] = None, pp_rank: Optional[int] = None, ge_cache: bool = False, **kwargs) -> Callable
参数说明
参数 |
输入/输出 |
说明 |
是否必选 |
||
|---|---|---|---|---|---|
func |
输入 |
模型编译缓存的函数。 |
是 |
||
config |
输入 |
图编译配置,CompilerConfig类的实例化,缺省情况下采用TorchAir自动生成的配置。 说明:
本场景下不支持同时配置Dynamo导图功能、RefData类型转换功能。 |
否 |
||
backend |
输入 |
后端选择,默认值为"None",通过torchair.get_npu_backend()获取。
|
否 |
||
dynamic |
输入 |
是否按照输入动态trace,bool类型。 该参数继承了PyTorch原有特性,详细介绍请参考LINK。 默认True,进行动态trace。 |
否 |
||
cache_dir |
输入 |
缓存文件落盘的根目录,支持绝对路径和相对路径。
缺省时${cache_dir}为“.torchair_cache”(若无会新建),${work_dir}为当前工作目录,${model_info}为模型信息,${func}为封装的func函数。 说明:
|
否 |
||
global_rank |
输入 |
分布式训练时的rank,int类型。取值范围为[0, world_size-1],其中world_size是参与分布式训练的总进程数。 一般情况下TorchAir会自动通过torch.distributed.get_rank()获取缺省值。 |
否 |
||
tp_rank |
输入 |
指张量模型并行rank,int类型,取值是global_rank中划分为TP域的rank id。 |
否 |
||
pp_rank |
输入 |
指流水线并行rank,int类型,取值是global_rank中划分为PP域的rank id。 |
否 |
||
custom_decompositions |
输入 |
手动指定模型运行时使用的decomposition(将较大算子操作分解为小算子实现)。 用户根据实际情况配置,以Add算子为例示例代码如下:
|
否 |
||
ge_cache |
输入 |
是否缓存GE图编译结果,bool类型。
说明:
|
否 |
||
* |
输入 |
预留参数项,用于后续功能拓展。 |
否 |
返回值说明
返回一个Callable对象。
约束说明
- 如果图中包含依赖随机数生成器(RNG)的算子(例如randn、bernoulli、dropout等),不支持使用本功能。
- 该功能不支持同时配置Dynamo导图功能、RefData类型转换功能。
- 该功能跳过了Dynamo的JIT编译环节、Guards、Ascend IR图编译环节,与torch.compile原始方案相比多了如下限制:
- 缓存要与执行计算图一一对应,若重编译则缓存失效。
- Guards阶段被跳过且不会触发JIT编译,要求生成模型的脚本和加载缓存的脚本一致。
- CANN包跨版本缓存无法保证兼容性,如果版本升级,需要清理缓存目录并重新进行Ascend IR计算图编译生成缓存。
调用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 | import dataclasses import logging from typing import List import torch import torch_npu import torchair from torchair import logger from torchair.configs.compiler_config import CompilerConfig config = CompilerConfig() logger.setLevel(logging.INFO) # InputMeta为仿照VLLM(Versatile Large Language Model)框架的入参结构 @dataclasses.dataclass class InputMeta: data: torch.Tensor is_prompt: bool class Model(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(2, 1) self.linear2 = torch.nn.Linear(2, 1) for param in self.parameters(): torch.nn.init.ones_(param) # 通过torchair.inference.cache_compile实现编译缓存 self.cached_prompt = torchair.inference.cache_compile(self.prompt, config=config) self.cached_decode = torchair.inference.cache_compile(self.decode, config=config) def forward(self, x: InputMeta, kv: List[torch.Tensor]): # 添加调用新函数的判断逻辑 if x.is_prompt: return self.cached_prompt(x, kv) return self.cached_decode(x, kv) def _forward(self, x, kv): return self.linear2(x.data) + self.linear2(kv[0]) # 重新封装为prompt函数 def prompt(self, x, y): return self._forward(x, y) # 重新封装为decode函数 def decode(self, x, y): return self._forward(x, y) x = InputMeta(data=torch.randn(2, 2).npu(), is_prompt=True) kv = [torch.randn(2, 2).npu()] model = Model().npu() # 注意无需调用torch.compile进行编译,直接执行model # 执行prompt res_prompt = model(x, kv) x.is_prompt = False # 执行decode res_decode = model(x, kv) |