import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
# Integer code selecting the quantized output dtype: 2 maps to torch.qint8,
# any other value maps to torch.quint4x2 (packed 4-bit).
# NOTE(review): 2 presumably mirrors GE's DT_INT8 enum value — confirm.
attr_dst_type = 2
if attr_dst_type == 2:
    attr_dst_type_torch = torch.qint8
else:
    attr_dst_type_torch = torch.quint4x2

# Random inputs for npu_group_quant, created on CPU, cast, then moved to the NPU.
# The randn calls stay in this order so the RNG stream matches the original.
x = torch.randn(6, 4).to(dtype=torch.float16).npu()  # 6 rows x 4 columns, fp16
scale = torch.randn(4, 4).to(dtype=torch.float32).npu()  # 4 x 4, fp32
# assumes these are cumulative per-group end-row indices into x — TODO confirm
group_index = torch.tensor([1, 4, 6, 6], dtype=torch.int32).npu()
offset = torch.randn(1).to(dtype=torch.float32).npu()  # single fp32 value
class Network(torch.nn.Module):
    """Minimal module that wraps ``torch_npu.npu_group_quant`` for compilation.

    It has no parameters or state; ``forward`` simply forwards its inputs
    to the NPU group-quantization op.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, scale, group_index, offset, dst_type):
        """Run ``torch_npu.npu_group_quant`` on the given inputs.

        Args:
            x: Tensor to quantize.
            scale: Quantization scale tensor.
                NOTE(review): presumably one row per group — confirm.
            group_index: Group boundary indices.
                NOTE(review): assumed cumulative end rows — confirm.
            offset: Quantization offset tensor, passed as ``offset=``.
            dst_type: Target quantized dtype (e.g. ``torch.qint8``),
                passed as ``dst_dtype=``.

        Returns:
            Whatever ``torch_npu.npu_group_quant`` returns.
        """
        return torch_npu.npu_group_quant(
            x, scale, group_index, offset=offset, dst_dtype=dst_type
        )
model = Network()

# Finish configuring the compiler BEFORE handing the config to the backend,
# so the graph-dump option applies whether get_npu_backend reads the config
# eagerly or lazily.
config = CompilerConfig()
config.debug.graph_dump.type = 'pbtxt'  # presumably dumps graphs as text protobuf — confirm
npu_backend = tng.get_npu_backend(compiler_config=config)

# fullgraph=True makes graph breaks an error; dynamic=True asks for
# shape-polymorphic compilation (see torch.compile docs).
model = torch.compile(model, fullgraph=True, backend=npu_backend, dynamic=True)
output_data = model(x, scale, group_index, offset=offset, dst_type=attr_dst_type_torch)