torch_npu.npu_convert_weight_to_int4pack

功能描述

将数据类型为int32的输入tensor打包为int4存放,每8个int4数据通过一个int32数据承载,并进行交叠排放。

接口原型

torch_npu.npu_convert_weight_to_int4pack(Tensor weight, int inner_k_tiles=0) -> Tensor

参数说明

输出说明

输出为Tensor类型,代表int4打包后的输出,数据类型为INT32,shape为(k, n/8), (n, k/8),数据格式支持ND。

约束说明

无。

支持的PyTorch版本

支持的型号

Atlas A2 训练系列产品

调用示例

import torch
import torch_npu

m = 128
k = 64
n = 32
trans_weight = False

cpu_x = torch.randn((m, k), dtype=torch.float16)
if trans_weight:
    cpu_weight = torch.randint(low=-8, high=8, size=(n, k), dtype=torch.int32)
    cpu_antiquantscale = torch.randn((n, 1), dtype=torch.float16)
    cpu_antiquantoffset = torch.randn((n, 1), dtype=torch.float16)
else:
    cpu_weight = torch.randint(low=-8, high=8, size=(k, n), dtype=torch.int32)
    cpu_antiquantscale = torch.randn((1, n), dtype=torch.float16)
    cpu_antiquantoffset = torch.randn((1, n), dtype=torch.float16)

weight_int4 = torch_npu.npu_convert_weight_to_int4pack(cpu_weight.npu())

if trans_weight:
    cpu_weight = cpu_weight.transpose(-1, -2)
    weight_int4 = weight_int4.transpose(-1, -2)
    cpu_antiquantscale = cpu_antiquantscale.transpose(-1, -2)
    cpu_antiquantoffset = cpu_antiquantoffset.transpose(-1, -2)

npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), weight_int4.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu())