GroupedMatMul和MoeFinalizeRouting的融合算子,GroupedMatMul计算后的输出按照索引做combine动作。
```
torch_npu.npu_grouped_matmul_finalize_routing(Tensor x, Tensor w, Tensor group_list, *, Tensor? scale=None, Tensor? bias=None, Tensor? pertoken_scale=None, Tensor? shared_input=None, Tensor? logit=None, Tensor? row_index=None, ScalarType? dtype=None, float? shared_input_weight=1.0, int shared_input_offset=0, int? output_bs=0, int? group_list_type=1) -> Tensor
```
y: 一个2D的Tensor,不支持非连续的Tensor,输出的数据类型固定为float32,维度为(batch, n)。
| x | w | group_list | scale | bias | pertoken_scale | shared_input | logit | row_index | y |
|---|---|---|---|---|---|---|---|---|---|
| int8 | int8 | int64 | float32 | None | float32 | bfloat16 | float32 | int64 | float32 |
| int8 | int8 | int64 | float32 | None | float32 | None | None | int64 | float32 |
# Example: calling npu_grouped_matmul_finalize_routing directly (single-operator mode).
# Requires an Ascend NPU device with torch_npu installed.
import numpy as np
import torch
import torch_npu
from scipy.special import softmax

# Problem sizes: x is (m, k), each expert weight is (k, n), and m == batch * topK.
m, k, n = 576, 2048, 7168
batch = 72
topK = 8
group_num = 8

# Quantized int8 activations and per-expert int8 weights.
x = np.random.randint(-10, 10, (m, k)).astype(np.int8)
weight = np.random.randint(-10, 10, (group_num, k, n)).astype(np.int8)

# float32 dequantization scales: one row per expert, and one value per row of x.
scale = np.random.normal(0, 0.01, (group_num, n)).astype(np.float32)
pertoken_scale = np.random.normal(0, 0.01, (m, 1)).astype(np.float32)

# Every group is assigned `batch` rows, so the entries sum to m.
# NOTE(review): presumably per-group row counts, since group_list_type is left
# at its default of 1 — confirm against the operator reference.
group_list = np.array([batch] * group_num, dtype=np.int64)

# Shared input covers batch // 4 rows; it is cast to bfloat16 before the call.
shared_input = np.random.normal(0, 0.1, (batch // 4, n)).astype(np.float32)

# Build routing weights: pick the topK experts per token, then softmax the
# selected logits along the expert axis and flatten to length m.
logit_ori = np.random.normal(0, 0.1, (batch, group_num)).astype(np.float32)
routing = np.argsort(logit_ori, axis=1)[:, -topK:].astype(np.int32)
token_ids = np.arange(batch).reshape(-1, 1).repeat(topK, axis=1)
logit = softmax(logit_ori[token_ids, routing], axis=1).astype(np.float32)
logit = logit.reshape(m)

# Sort the flattened (token, expert) pairs by expert id; integer-dividing the
# sorted positions by topK recovers the source token of each row.
row_index = (np.argsort(routing.reshape(-1)) // topK).astype(np.int64)

# Move everything to the NPU.
x_clone = torch.from_numpy(x).npu()
weight_clone = torch.from_numpy(weight).npu()
# NOTE(review): format id 29 is presumably FRACTAL_NZ (hence the name
# weightNz) — confirm against the npu_format_cast reference.
weightNz = torch_npu.npu_format_cast(weight_clone, 29)
scale_clone = torch.from_numpy(scale).npu()
pertoken_scale_clone = torch.from_numpy(pertoken_scale).npu()
group_list_clone = torch.from_numpy(group_list).npu()
shared_input_clone = torch.from_numpy(shared_input).to(torch.bfloat16).npu()
logit_clone = torch.from_numpy(logit).npu()
row_index_clone = torch.from_numpy(row_index).npu()

shared_input_offset = batch // 2
output_bs = batch

y = torch_npu.npu_grouped_matmul_finalize_routing(
    x_clone,
    weightNz,
    group_list_clone,
    scale=scale_clone,
    pertoken_scale=pertoken_scale_clone,
    shared_input=shared_input_clone,
    logit=logit_clone,
    row_index=row_index_clone,
    shared_input_offset=shared_input_offset,
    output_bs=output_bs,
)
# Example: calling npu_grouped_matmul_finalize_routing through torch.compile
# with the TorchAir NPU backend (graph mode).
# Requires an Ascend NPU device with torch_npu and torchair installed.
import numpy as np
import torch
import torch_npu
import torchair as tng
from scipy.special import softmax
from torchair.configs.compiler_config import CompilerConfig

config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)


class Model(torch.nn.Module):
    """Thin module wrapping the fused op so that it can be compiled."""

    def __init__(self):
        super().__init__()

    def forward(self, x, weight, group_list, scale, pertoken_scale,
                shared_input, logit, row_index, shared_input_offset,
                output_bs):
        """Forward all arguments to npu_grouped_matmul_finalize_routing and return its output."""
        output = torch_npu.npu_grouped_matmul_finalize_routing(
            x,
            weight,
            group_list,
            scale=scale,
            pertoken_scale=pertoken_scale,
            shared_input=shared_input,
            logit=logit,
            row_index=row_index,
            shared_input_offset=shared_input_offset,
            output_bs=output_bs,
        )
        return output


# Problem sizes: x is (m, k), each expert weight is (k, n), and m == batch * topK.
m, k, n = 576, 2048, 7168
batch = 72
topK = 8
group_num = 8

# Quantized int8 activations and per-expert int8 weights.
x = np.random.randint(-10, 10, (m, k)).astype(np.int8)
weight = np.random.randint(-10, 10, (group_num, k, n)).astype(np.int8)

# float32 dequantization scales: one row per expert, and one value per row of x.
scale = np.random.normal(0, 0.01, (group_num, n)).astype(np.float32)
pertoken_scale = np.random.normal(0, 0.01, (m, 1)).astype(np.float32)

# Every group is assigned `batch` rows, so the entries sum to m.
# NOTE(review): presumably per-group row counts, since group_list_type is left
# at its default of 1 — confirm against the operator reference.
group_list = np.array([batch] * group_num, dtype=np.int64)

# Shared input covers batch // 4 rows; it is cast to bfloat16 before the call.
shared_input = np.random.normal(0, 0.1, (batch // 4, n)).astype(np.float32)

# Build routing weights: pick the topK experts per token, then softmax the
# selected logits along the expert axis and flatten to length m.
logit_ori = np.random.normal(0, 0.1, (batch, group_num)).astype(np.float32)
routing = np.argsort(logit_ori, axis=1)[:, -topK:].astype(np.int32)
token_ids = np.arange(batch).reshape(-1, 1).repeat(topK, axis=1)
logit = softmax(logit_ori[token_ids, routing], axis=1).astype(np.float32)
logit = logit.reshape(m)

# Sort the flattened (token, expert) pairs by expert id; integer-dividing the
# sorted positions by topK recovers the source token of each row.
row_index = (np.argsort(routing.reshape(-1)) // topK).astype(np.int64)

# Move everything to the NPU.
x_clone = torch.from_numpy(x).npu()
weight_clone = torch.from_numpy(weight).npu()
# NOTE(review): format id 29 is presumably FRACTAL_NZ (hence the name
# weightNz) — confirm against the npu_format_cast reference.
weightNz = torch_npu.npu_format_cast(weight_clone, 29)
scale_clone = torch.from_numpy(scale).npu()
pertoken_scale_clone = torch.from_numpy(pertoken_scale).npu()
group_list_clone = torch.from_numpy(group_list).npu()
shared_input_clone = torch.from_numpy(shared_input).to(torch.bfloat16).npu()
logit_clone = torch.from_numpy(logit).npu()
row_index_clone = torch.from_numpy(row_index).npu()

shared_input_offset = batch // 2
output_bs = batch

# Compile the module with static shapes for the NPU backend, then run it.
model = Model().npu()
model = torch.compile(model, backend=npu_backend, dynamic=False)
y = model(
    x_clone,
    weightNz,
    group_list_clone,
    scale_clone,
    pertoken_scale_clone,
    shared_input_clone,
    logit_clone,
    row_index_clone,
    shared_input_offset,
    output_bs,
)