使能RefData类型转换功能

功能简介

在大模型推理场景下,如果存在Ref类算子(例如Assign、ScatterUpdate等算子,类似于PyTorch中的inplace类算子)改写输入内存的情况,可以在构图过程中将用户输入的Data类型转换为RefData类型,以减少重复数据拷贝,提高模型执行效率。

使用方法

使用用例

以在线推理场景为样例,示例代码如下:

import torch
import torch_npu
import torchair as tng
from torch import nn
from torchair.configs.compiler_config import CompilerConfig
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
    def forward(self, x):
        return x.add_(1)

device = torch.device("npu:0")
config = CompilerConfig()
config.experimental_config.enable_ref_data = True
input0 = torch.ones((3,3), dtype=torch.float32)
input0 = input0.to(device)
model = Network()
npu_backend = tng.get_npu_backend(compiler_config=config)
model = torch.compile(model, fullgraph=True, backend=npu_backend, dynamic=True)

使能RefData数据类型转换功能后,开启TorchAir python层日志,会显示如下屏显信息:

[DEBUG] TORCHAIR 20240607 02:06:15 Replace RefData_5_3_20_20_1200_400_20_1_0_140251860631280:RefData with arg0_1:Data in graph graph_1