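In large-model inference, some graphs contain Ref-type operators (such as Assign and ScatterUpdate, analogous to in-place operators in PyTorch) that overwrite the memory of their inputs. In that case, the Data-type graph inputs supplied by the user can be converted to the RefData type during graph construction. This avoids redundant data copies and improves model execution efficiency.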
```python
import torch
import torch_npu
import torchair as tng

config = tng.CompilerConfig()
# Switch that enables the RefData type for graph inputs
config.experimental_config.enable_ref_data = True
npu_backend = tng.get_npu_backend(compiler_config=config)
...
model = Model()
model = torch.compile(model, backend=npu_backend, dynamic=False)
```
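The copy saving described at the start of this section can be sketched with a toy model in plain Python. This is not TorchAir or Graph Engine code: `ToyGraphInput` and `run_add_inplace` are hypothetical names. The model also assumes, as a simplification, that a Data input is read-only, so an in-place update has to be written to a new buffer and then copied back into the caller's tensor, while a RefData input lets the operator write straight into the caller's memory.

```python
from typing import List


class ToyGraphInput:
    """Hypothetical stand-in for a graph input; not a TorchAir class."""

    def __init__(self, buffer: List[float], ref: bool):
        self.buffer = buffer  # memory owned by the caller (the user's tensor)
        self.ref = ref        # True models RefData, False models Data (assumption)
        self.copies = 0       # number of elements copied back to the caller


def run_add_inplace(inp: ToyGraphInput, value: float) -> None:
    """Model an in-place 'x += value' node (like x.add_(1)) on a graph input."""
    if inp.ref:
        # RefData (toy model): the node writes directly into the caller's buffer.
        for i in range(len(inp.buffer)):
            inp.buffer[i] += value
    else:
        # Data (toy model): the input is read-only, so the result goes to a new buffer...
        out = [v + value for v in inp.buffer]
        # ...and is copied back so the caller still observes in-place semantics.
        for i, v in enumerate(out):
            inp.buffer[i] = v
        inp.copies += len(out)


data_in = ToyGraphInput([1.0] * 9, ref=False)
ref_in = ToyGraphInput([1.0] * 9, ref=True)
run_add_inplace(data_in, 1)
run_add_inplace(ref_in, 1)
print(data_in.buffer == ref_in.buffer, data_in.copies, ref_in.copies)  # True 9 0
```

Both modes leave the caller with the same values; in this toy model the difference is only the copy-back traffic, which grows with the tensor size.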
The following sample uses an online inference scenario:
```python
import torch
import torch_npu
import torchair as tng
from torch import nn
from torchair.configs.compiler_config import CompilerConfig


class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()

    def forward(self, x):
        # add_ is an in-place op: it writes the result back into x's memory
        return x.add_(1)


device = torch.device("npu:0")

config = CompilerConfig()
# Enable conversion of Data graph inputs to RefData
config.experimental_config.enable_ref_data = True

input0 = torch.ones((3, 3), dtype=torch.float32)
input0 = input0.to(device)

model = Network()
npu_backend = tng.get_npu_backend(compiler_config=config)
model = torch.compile(model, fullgraph=True, backend=npu_backend, dynamic=True)

# Run the compiled model once so the graph is actually built and executed
output = model(input0)
```
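With RefData conversion enabled and TorchAir Python-side logging turned on, a message like the following is printed, showing that the Data input arg0_1 was replaced by a RefData node in graph graph_1: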
[DEBUG] TORCHAIR 20240607 02:06:15 Replace RefData_5_3_20_20_1200_400_20_1_0_140251860631280:RefData with arg0_1:Data in graph graph_1
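To check in a script or CI job whether the conversion actually took place, the log can be scanned for this message. The sketch below assumes the message format matches the single DEBUG line shown above; that format is not a documented, stable interface, and `RefDataReplacement` and `find_replacements` are hypothetical helper names.

```python
import re
from typing import Iterable, List, NamedTuple


class RefDataReplacement(NamedTuple):
    ref_data: str  # name of the RefData node that was introduced
    data: str      # name of the original Data input it replaced
    graph: str     # graph in which the replacement happened


# Pattern modeled on the example log line; an assumption, not a stable format.
_PATTERN = re.compile(r"Replace (\S+):RefData with (\S+):Data in graph (\S+)")


def find_replacements(lines: Iterable[str]) -> List[RefDataReplacement]:
    """Collect every Data-to-RefData replacement reported in the given log lines."""
    found = []
    for line in lines:
        m = _PATTERN.search(line)
        if m:
            found.append(RefDataReplacement(*m.groups()))
    return found


log = [
    "[DEBUG] TORCHAIR 20240607 02:06:15 Replace "
    "RefData_5_3_20_20_1200_400_20_1_0_140251860631280:RefData "
    "with arg0_1:Data in graph graph_1"
]
print(find_replacements(log)[0].data)  # arg0_1
```

An empty result for a graph that contains in-place operators would suggest the switch did not take effect for that graph.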