npu_grouped_matmul(x, weight, *, bias=None, scale=None, offset=None, antiquant_scale=None, antiquant_offset=None, per_token_scale=None, group_list=None, activation_input=None, activation_quant_scale=None, activation_quant_offset=None, split_item=0, group_type=-1, group_list_type=0, act_type=0, output_dtype=None) -> List[torch.Tensor]
场景 | x | weight | bias | scale | antiquant_scale | antiquant_offset | output_dtype | y
---|---|---|---|---|---|---|---|---|
非量化 | torch.float16 | torch.float16 | torch.float16 | 无需赋值 | 无需赋值 | 无需赋值 | torch.float16 | torch.float16
 | torch.bfloat16 | torch.bfloat16 | torch.float32 | 无需赋值 | 无需赋值 | 无需赋值 | torch.bfloat16 | torch.bfloat16
 | torch.float32 | torch.float32 | torch.float32 | 无需赋值 | 无需赋值 | 无需赋值 | torch.float32 | torch.float32
per-channel量化 | torch.int8 | torch.int8 | torch.int32 | torch.int64 | 无需赋值 | 无需赋值 | torch.int8 | torch.int8
伪量化 | torch.float16 | torch.int8 | torch.float16 | 无需赋值 | torch.float16 | torch.float16 | torch.float16 | torch.float16
 | torch.bfloat16 | torch.int8 | torch.float32 | 无需赋值 | torch.bfloat16 | torch.bfloat16 | torch.bfloat16 | torch.bfloat16
场景 | x | weight | bias | scale | antiquant_scale | antiquant_offset | per_token_scale | output_dtype | y
---|---|---|---|---|---|---|---|---|---|
非量化 | torch.float16 | torch.float16 | torch.float16 | 无需赋值 | 无需赋值 | 无需赋值 | 无需赋值 | None/torch.float16 | torch.float16
 | torch.bfloat16 | torch.bfloat16 | torch.float32 | 无需赋值 | 无需赋值 | 无需赋值 | 无需赋值 | None/torch.bfloat16 | torch.bfloat16
 | torch.float32 | torch.float32 | torch.float32 | 无需赋值 | 无需赋值 | 无需赋值 | 无需赋值 | None/torch.float32(仅x/weight/y均为单张量) | torch.float32
per-channel量化 | torch.int8 | torch.int8 | torch.int32 | torch.int64 | 无需赋值 | 无需赋值 | 无需赋值 | None/torch.int8 | torch.int8
 | torch.int8 | torch.int8 | torch.int32 | torch.bfloat16 | 无需赋值 | 无需赋值 | 无需赋值 | torch.bfloat16 | torch.bfloat16
 | torch.int8 | torch.int8 | torch.int32 | torch.float32 | 无需赋值 | 无需赋值 | 无需赋值 | torch.float16 | torch.float16
per-token量化 | torch.int8 | torch.int8 | torch.int32 | torch.bfloat16 | 无需赋值 | 无需赋值 | torch.float32 | torch.bfloat16 | torch.bfloat16
 | torch.int8 | torch.int8 | torch.int32 | torch.float32 | 无需赋值 | 无需赋值 | torch.float32 | torch.float16 | torch.float16
伪量化 | torch.float16 | torch.int8/int4 | torch.float16 | 无需赋值 | torch.float16 | torch.float16 | 无需赋值 | None/torch.float16 | torch.float16
 | torch.bfloat16 | torch.int8/int4 | torch.float32 | 无需赋值 | torch.bfloat16 | torch.bfloat16 | 无需赋值 | None/torch.bfloat16 | torch.bfloat16
x | weight | bias | scale | antiquant_scale | antiquant_offset | per_token_scale | output_dtype | y
---|---|---|---|---|---|---|---|---|
torch.float16 | torch.float16 | torch.float16 | 无需赋值 | 无需赋值 | 无需赋值 | torch.float32 | torch.float16 | torch.float16
支持场景 | 场景说明 | 场景限制
---|---|---|
多多多 | x和weight为多张量,y为多张量。每组数据的张量是独立的。 | 
单多单 | x为单张量,weight为多张量,y为单张量。 | 
单多多 | x为单张量,weight为多张量,y为多张量。 | 
多多单 | x和weight为多张量,y为单张量。每组矩阵乘法的结果连续存放在同一个张量中。 | 
group_type | 支持场景 | 场景说明 | 场景限制
---|---|---|---|
-1 | 多多多 | x和weight为多张量,y为多张量。每组数据的张量是独立的。 | 
0 | 单单单 | x、weight与y均为单张量。 | 
0 | 单多单 | x为单张量,weight为多张量,y为单张量。 | 
0 | 多多单 | x和weight为多张量,y为单张量。每组矩阵乘法的结果连续存放在同一个张量中。 | 
2 | 单单单 | x、weight与y均为单张量 | 输入输出只支持float16的数据类型,输出y的n轴大小需要是16的倍数。
group_type | 支持场景 | 场景说明 | 场景限制
---|---|---|---|
0 | 单单单 | x、weight与y均为单张量 | 
# Example: npu_grouped_matmul in eager mode, "multi-multi-multi" layout
# (x, weight and y are each lists of independent tensors, split_item=0).
import torch
import torch_npu

# Per-group (M, K) input shapes and matching (K, N) weight shapes.
x_shapes = [(256, 256), (1024, 256), (512, 1024)]
w_shapes = [(256, 256), (256, 1024), (1024, 128)]

x = [torch.randn(*shape, device='npu', dtype=torch.float16) for shape in x_shapes]
weight = [torch.randn(*shape, device='npu', dtype=torch.float16) for shape in w_shapes]
# One bias per group, sized to that group's weight N dimension.
bias = [torch.randn(shape[1], device='npu', dtype=torch.float16) for shape in w_shapes]

# group_list is unused in the multi-tensor layout; split_item=0 keeps the
# outputs as a list of separate tensors.
group_list = None
split_item = 0

npu_out = torch_npu.npu_grouped_matmul(x, weight, bias=bias,
                                       group_list=group_list,
                                       split_item=split_item)
# Example: npu_grouped_matmul in graph mode, compiled through torchair's
# NPU backend with torch.compile.
import torch
import torch.nn as nn
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig

config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)


class GMMModel(nn.Module):
    """Thin module wrapper so the grouped matmul can be traced/compiled."""

    def __init__(self):
        super().__init__()

    def forward(self, x, weight):
        return torch_npu.npu_grouped_matmul(x, weight)


def main():
    # Per-group (M, K) inputs and matching (K, N) weights.
    x_shapes = [(256, 256), (1024, 256), (512, 1024)]
    w_shapes = [(256, 256), (256, 1024), (1024, 128)]
    x = [torch.randn(*shape, device='npu', dtype=torch.float16) for shape in x_shapes]
    weight = [torch.randn(*shape, device='npu', dtype=torch.float16) for shape in w_shapes]

    model = GMMModel().npu()
    # dynamic=False: shapes are fixed, so compile a static graph.
    model = torch.compile(model, backend=npu_backend, dynamic=False)
    custom_output = model(x, weight)


if __name__ == '__main__':
    main()