Gemm(废弃)
产品支持情况
产品 |
是否支持 |
|---|---|
x |
|
x |
|
x |
|
√ |
|
x |
|
√ |
功能说明
该接口废弃,并将在后续版本移除,请不要使用该接口。
根据输入的切分规则,将给定的两个输入张量做矩阵乘,输出至结果张量。将A和B两个输入矩阵乘法在一起,得到一个输出矩阵C。
函数原型
- 功能接口:
1 2
template <typename T, typename U, typename S> __aicore__ inline void Gemm(const LocalTensor<T>& dst, const LocalTensor<U>& src0, const LocalTensor<S>& src1, const uint32_t m, const uint32_t k, const uint32_t n, GemmTiling tilling, bool partialsum = true, int32_t initValue = 0)
- 切分方案计算接口:
1 2
template <typename T> __aicore__ inline GemmTiling GetGemmTiling(uint32_t m, uint32_t k, uint32_t n)
参数说明
参数名称 |
类型 |
说明 |
||
|---|---|---|---|---|
dst |
输出 |
目的操作数。 |
||
src0 |
输入 |
源操作数,TPosition为A1。 |
||
src1 |
输入 |
源操作数,TPosition为B1。 |
||
m |
输入 |
左矩阵Src0Local有效Height,范围:[1, 4096]。 注意:m可以不是16的倍数。 |
||
k |
输入 |
左矩阵Src0Local有效Width、右矩阵Src1Local有效Height。
注意:k可以不是16的倍数。 |
||
n |
输入 |
右矩阵Src1Local有效Width,范围:[1, 4096]。 注意:n可以不是16的倍数。 |
||
tilling |
输入 |
切分规则,类型为GemmTiling,结构体具体定义为:
参数说明请参考表3。 |
||
partialsum |
输入 |
当dst参数所在的TPosition为CO2时,通过该参数控制计算结果是否搬出。
|
||
initValue |
输入 |
表示dst是否需要初始化。
|
src0.dtype |
src1.dtype |
dst.dtype |
|---|---|---|
int8_t |
int8_t |
int32_t |
half |
half |
float |
half |
half |
half |
参数名称 |
类型 |
说明 |
||
|---|---|---|---|---|
blockSize |
uint32_t |
固定值,恒为16,一个维度内存放的元素个数。 |
||
loopMode |
LoopMode |
遍历模式,结构体具体定义为:
|
||
mNum |
uint32_t |
M轴等效数据长度参数值,范围:[1, 4096]。 |
||
nNum |
uint32_t |
N轴等效数据长度参数值,范围:[1, 4096]。 |
||
kNum |
uint32_t |
K轴等效数据长度参数值。
|
||
roundM |
uint32_t |
M轴等效数据长度参数值且以blockSize为倍数向上取整,范围:[1, 4096] |
||
roundN |
uint32_t |
N轴等效数据长度参数值且以blockSize为倍数向上取整,范围:[1, 4096] |
||
roundK |
uint32_t |
K轴等效数据长度参数值且以c0Size为倍数向上取整。
|
||
c0Size |
uint32_t |
一个block的字节长度,范围:[16或者32]。 |
||
dtypeSize |
uint32_t |
传入的数据类型的字节长度,范围:[1, 2]。 |
||
mBlockNum |
uint32_t |
M轴Block个数,mBlockNum = mNum / blockSize。 |
||
nBlockNum |
uint32_t |
N轴Block个数,nBlockNum = nNum / blockSize。 |
||
kBlockNum |
uint32_t |
K轴Block个数,kBlockNum = kNum / blockSize。 |
||
mIterNum |
uint32_t |
遍历维度数量,范围:[1, 4096]。 |
||
nIterNum |
uint32_t |
遍历维度数量,范围:[1, 4096]。 |
||
kIterNum |
uint32_t |
遍历维度数量,范围:[1, 4096]。 |
||
mTileBlock |
uint32_t |
M轴切分块个数,范围:[1, 4096]。 |
||
nTileBlock |
uint32_t |
N轴切分块个数,范围:[1, 4096]。 |
||
kTileBlock |
uint32_t |
K轴切分块个数,范围:[1, 4096]。 |
||
kTailBlock |
uint32_t |
K轴尾块个数,范围:[1, 4096]。 |
||
mTailBlock |
uint32_t |
M轴尾块个数,范围:[1, 4096]。 |
||
nTailBlock |
uint32_t |
N轴尾块个数,范围:[1, 4096]。 |
||
kHasTail |
bool |
K轴是否存在尾块。 |
||
mHasTail |
bool |
M轴是否存在尾块。 |
||
nHasTail |
bool |
N轴是否存在尾块。 |
||
kHasTailEle |
bool |
是否存在尾块元素。 |
||
kTailEle |
uint32_t |
K轴尾块元素,范围:[1, 4096]。 |
约束说明
- 参数m,k,n可以不是16对齐,但因硬件原因,操作数dst,Src0Local和Src1Local的shape需满足对齐要求,即m方向,n方向要求向上16对齐,k方向根据操作数数据类型按16或32向上对齐。
- 操作数地址对齐要求请参见通用地址对齐约束。