昇腾社区首页
中文
注册

matmul支持pertoken量化模式

产品支持情况

硬件型号

是否支持

Atlas A3 推理系列产品 / Atlas A3 训练系列产品

Atlas A2 训练系列产品 / Atlas 800I A2 推理产品

Atlas 训练系列产品

x

Atlas 推理系列产品

x

Atlas 200I/500 A2 推理产品

x

功能说明

linear operation支持pertoken量化模式

计算公式

参数配置

成员名称

取值范围

transposeA

false/true

transposeB

false/true

hasBias

false

outDataType

ACL_FLOAT16/ACL_BF16

enAccum

false

matmulType

MATMUL_UNDEFINED

quantMode

PER_TOKEN

输入

参数

维度

数据类型

格式

描述

x

[m,k]/[batch,m,k]

int8

ND

矩阵乘的A矩阵。

weight

[k,n]/[batch,k,n]

int8

ND

矩阵乘的B矩阵,权重。

deqScale

[n]

float

ND

反量化步长。

perTokenScale

[m]

float

ND

perToken反量化步长。

输出

参数

维度

数据类型

格式

描述

output

[m, n]/[batch, m, n]

float16/bf16

ND

矩阵乘反量化计算结果。

规格说明

由于输入输出的排列组合约束较复杂,下图列举了所有输入输出属性的组合,图中没有的组合即不支持:

图1 输入输出属性排列组合

OP使用与典型场景

OP使用时,可参考算子使用指导中的使用流程部分,其中,单算子构造Operation参数的构造方法参考以下参数构造部分。

// 参数构造
atb::infer::LinearParam param;
param.transposeA = false;
param.transposeB = false;
param.hasBias = false;
param.outDataType = ACL_FLOAT16;
param.enAccum = false;
param.matmulType = MATMUL_UNDEFINED;
param.quantMode = PER_TOKEN;
# 计算示例
>>> x
tensor([[1, 2],
        [3, 4]])
>>> weight
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> deqScale
tensor([1, 2, 3])
>>> perTokenScale
tensor([1, 2])
>>> output
tensor([[9, 24, 45],
        [38, 104, 198]])
# 9 = (1 * 1 + 2 * 4) * 1 * 1
# 24 = (1 * 2 + 2 * 5) * 2 * 1
# 45 = (1 * 3 + 2 * 6) * 3 * 1
# 38 = (3 * 1 + 4 * 4) * 1 * 2
# 104 = (3 * 2 + 4 * 5) * 2 * 2
# 198 = (3 * 3 + 4 * 6) * 3 * 2