import torch
import torch_npu
import torchair
from torchair.configs.compiler_config import CompilerConfig
# Build the torchair backend used by torch.compile for NPU graph execution.
config = CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)

# Problem sizes: x is (m, k), effective weight is (k, n) -> output (m, n).
m = 16
k = 17
n = 72
trans_weight = False   # True: weight stored as (n, k) and transposed in forward
is_weight_nz = False   # True: cast the device weight to the FRACTAL_NZ layout

cpu_x = torch.randn((m, k), dtype=torch.float16)
if trans_weight:
    # Transposed layout: weight (n, k); per-output-channel scale/offset (n, 1).
    cpu_weight = torch.randint(low=-8, high=8, size=(n, k), dtype=torch.int32)
    cpu_antiquantscale = torch.ones((n, 1), dtype=torch.float16)
    cpu_antiquantoffset = torch.zeros((n, 1), dtype=torch.float16)
else:
    # Standard layout: weight (k, n); per-output-channel scale/offset (1, n).
    cpu_weight = torch.randint(low=-8, high=8, size=(k, n), dtype=torch.int32)
    cpu_antiquantscale = torch.ones((1, n), dtype=torch.float16)
    cpu_antiquantoffset = torch.zeros((1, n), dtype=torch.float16)

npu_weight = cpu_weight.npu()
if is_weight_nz:
    # ND -> FRACTAL_NZ on-device layout (format id 29).
    # FIX: npu_weight is already on the NPU (moved above); the original's
    # extra .npu() call was redundant.
    npu_weight = torch_npu.npu_format_cast(npu_weight, 29)
# Pack the int32 quantized values into the int4 representation expected by
# npu_weight_quant_batchmatmul.
weight_int4pack = torch_npu.npu_convert_weight_to_int4pack(npu_weight)
class MyModel(torch.nn.Module):
    """Thin wrapper around torch_npu.npu_weight_quant_batchmatmul.

    Exists so the op can be captured and compiled by torch.compile with the
    torchair NPU backend.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, weight, antiquant_scale, antiquant_offset,
                quant_scale, quant_offset, bias, antiquant_group_size):
        # If the weight was materialized in the transposed (n, k) layout,
        # restore the (k, n) orientation — and likewise flip the per-channel
        # scale/offset tensors — before calling the fused matmul.
        if trans_weight:
            weight = weight.transpose(-1, -2)
            antiquant_scale = antiquant_scale.transpose(-1, -2)
            antiquant_offset = antiquant_offset.transpose(-1, -2)
        return torch_npu.npu_weight_quant_batchmatmul(
            x, weight, antiquant_scale, antiquant_offset,
            quant_scale, quant_offset, bias, antiquant_group_size)
# Instantiate, move to the NPU, compile, and run one forward pass.
cpu_model = MyModel()
model = cpu_model.npu()
# FIX: compile the NPU-resident model, not the original cpu_model — the
# original passed cpu_model to torch.compile, silently discarding the
# .npu() move above (benign only while the module has no parameters).
model = torch.compile(model, backend=npu_backend, dynamic=True, fullgraph=True)
# quant_scale / quant_offset / bias are unused here (None);
# antiquant_group_size=0 presumably selects non-grouped (per-channel)
# dequantization — TODO confirm against the torch_npu op documentation.
npu_out = model(cpu_x.npu(), weight_int4pack, cpu_antiquantscale.npu(),
                cpu_antiquantoffset.npu(), None, None, None, 0)