LogSoftMax
产品支持情况
|
产品 |
是否支持 |
|---|---|
|
Atlas 350 加速卡 |
√ |
|
|
√ |
|
|
√ |
|
|
x |
|
|
√ |
|
|
x |
|
|
x |
功能说明
对输入tensor做LogSoftmax计算。计算公式如下 :


为方便理解,通过Python脚本实现的方式表达计算公式如下,其中src是源操作数(输入),dst、sum、max为目的操作数(输出)。
1 2 3 4 5 6 7 8 9 10 |
def log_softmax(src): #基于last轴进行rowmax(按行取最大值)处理 max = np.max(src, axis=-1, keepdims=True) sub = src - max exp = np.exp(sub) #基于last轴进行rowsum(按行求和)处理 sum = np.sum(exp, axis=-1, keepdims=True) dst = exp / sum dst = np.log10(dst) return dst, max, sum |
实现原理
以float类型,ND格式,shape为[m, k]的输入Tensor为例,描述LogSoftMax高阶API内部算法框图,如下图所示。
计算过程分为如下几步,均在Vector上进行:
- reducemax步骤:对输入x的每一行数据求最大值得到[m, 1],计算结果会保存到一个临时空间temp中;
- broadcast步骤:对temp中的数据([m, 1])做一个按datablock为单位的填充,比如float类型下,把[m, 1]扩展成[m, 8],同时输出max;
- sub步骤:对输入x的所有数据按行减去max;
- exp步骤:对sub之后的所有数据求exp;
- reducesum步骤:对exp后的结果的每一行数据求和得到[m, 1],计算结果会保存到临时空间temp中;
- broadcast步骤:对temp([m, 1])做一个按datablock为单位的填充,比如float类型下,把[m, 1]扩展成[m, 8],同时输出sum;
- div步骤:对exp结果的所有数据按行除以sum;
- log步骤:对div后的所有数据按行做log10计算,输出y。
函数原型
1 2 |
template <typename T, bool isReuseSource = false, bool isDataFormatNZ = false> __aicore__ inline void LogSoftMax(const LocalTensor<T>& dst, const LocalTensor<T>& sum, const LocalTensor<T>& max, const LocalTensor<T>& src, const LocalTensor<uint8_t>& sharedTmpBuffer, const LogSoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {}) |
由于该接口的内部实现中涉及复杂的数学计算,需要额外的临时空间来存储计算过程中的中间变量。临时空间支持开发者通过sharedTmpBuffer入参传入。临时空间大小BufferSize的获取方式如下:通过LogSoftMax Tiling中提供的接口获取空间范围的大小。
参数说明
|
参数名 |
描述 |
|---|---|
|
T |
操作数的数据类型。 Atlas 350 加速卡,支持的数据类型为:half、float。 |
|
isReuseSource |
是否允许修改源操作数。该参数预留,传入默认值false即可。 |
|
isDataFormatNZ |
源操作数是否为NZ格式。默认值为false。 |
|
参数名 |
输入/输出 |
描述 |
||
|---|---|---|---|---|
|
dst |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 last轴长度需要32Byte对齐。 |
||
|
sum |
输出 |
reduceSum操作数。 reduceSum操作数的数据类型需要与目的操作数保持一致。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。
|
||
|
max |
输出 |
reduceMax操作数。 reduceMax操作数的数据类型需要与目的操作数保持一致。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。
|
||
|
src |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 源操作数的数据类型需要与目的操作数保持一致。 |
||
|
sharedTmpBuffer |
输入 |
临时缓存。临时空间大小BufferSize的获取方式请参考LogSoftMax Tiling。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 |
||
|
tiling |
输入 |
LogSoftMax计算所需Tiling信息,Tiling信息的获取请参考LogSoftMax Tiling。 |
||
|
softmaxShapeInfo |
输入 |
src的shape信息。SoftMaxShapeInfo类型,具体定义如下:
注意,当输入输出的数据格式为NZ(FRACTAL_NZ)格式时,尾轴长度为reduce轴长度,即图2中的W0*W1,非尾轴为H0*H1。 |
返回值说明
无
约束说明
- 输入源数据需保持值域在[-2147483647.0, 2147483647.0]。若输入不在范围内,输出结果无效。
- 不支持源操作数与目的操作数地址重叠。
- 不支持sharedTmpBuffer与源操作数和目的操作数地址重叠。
- 操作数地址对齐要求请参见通用地址对齐约束。
- 当参数softmaxShapeInfo中srcM != oriSrcM 或者 srcK != oriSrcK时,开发者需要对GM上的原始输入(oriSrcM, oriSrcK)在M或K方向补齐数据到(srcM, srcK),补齐的数据会参与部分运算,在输入输出复用的场景下,API的计算结果会覆盖srcTensor中补齐的原始数据,在输入输出不复用的场景下,API的计算结果会覆盖dstTensor中对应srcTensor补齐位置的数据。
调用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
// dstLocal: 存放LogSoftMax计算结果的Tensor // sumLocal:存放LogSoftMax计算过程中reducesum结果的Tensor // maxLocal:存放LogSoftMax计算过程中reduceMax结果的Tensor // srcLocal:存放LogSoftMax计算的输入Tensor // sharedTmpBuffer: 存放LogSoftMax计算过程中临时缓存的Tensor // softmaxTiling:存放LogSoftMax计算所需Tiling信息,可通过LogSoftMaxTilingFunc接口获取 AscendC::SoftMaxShapeInfo softmaxInfo( /* 非尾轴长度的乘积 */ srcM, /* 尾轴长度,必须32Bytes对齐 */ srcK, /* 原始非尾轴长度的乘积 */ oriSrcM, /* 原始尾轴长度 */ oriSrcK ); AscendC::LogSoftMax<DTYPE_X, false>(dstLocal, sumLocal, maxLocal, srcLocal, sharedTmpBuffer, softmaxTiling, softmaxInfo); |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
输入数据(srcLocal): [[-100. -80. -60. -50. -30. -20. -15. -10. ] [ -9. -8. -7. -6. -5. -4. -3. -2. ] [ -1.5 -1. -0.8 -0.6 -0.5 -0.45 -0.4 -0.35 ] [ -0.3 -0.25 -0.2 -0.15 -0.1 -0.05 -0.01 -0.001] [ 0. 0.001 0.01 0.05 0.1 0.15 0.2 0.25 ] [ 0.3 0.35 0.4 0.45 0.5 0.6 0.8 1. ] [ 1.5 2. 3. 4. 5. 6. 7. 8. ] [ 9. 10. 15. 20. 30. 50. 60. 80. ]] 输出数据(maxLocal): [[-10. -10. -10. -10. -10. -10. -10. -10. ] [ -2. -2. -2. -2. -2. -2. -2. -2. ] [ -0.35 -0.35 -0.35 -0.35 -0.35 -0.35 -0.35 -0.35 ] [ -0.001 -0.001 -0.001 -0.001 -0.001 -0.001 -0.001 -0.001] [ 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 ] [ 1. 1. 1. 1. 1. 1. 1. 1. ] [ 8. 8. 8. 8. 8. 8. 8. 8. ] [ 80. 80. 80. 80. 80. 80. 80. 80. ]] 输出数据(sumLocal): [[1.0067834 1.0067834 1.0067834 1.0067834 1.0067834 1.0067834 1.0067834 1.0067834] [1.5814459 1.5814459 1.5814459 1.5814459 1.5814459 1.5814459 1.5814459 1.5814459] [5.971886 5.971886 5.971886 5.971886 5.971886 5.971886 5.971886 5.971886 ] [7.051223 7.051223 7.051223 7.051223 7.051223 7.051223 7.051223 7.051223 ] [6.880514 6.880514 6.880514 6.880514 6.880514 6.880514 6.880514 6.880514 ] [5.239974 5.239974 5.239974 5.239974 5.239974 5.239974 5.239974 5.239974 ] [1.5820376 1.5820376 1.5820376 1.5820376 1.5820376 1.5820376 1.5820376 1.5820376] [1. 1. 1. 1. 1. 1. 1. 1. ]] 输出数据(dstLocal): [[-39.08944 -30.40355 -21.71766 -17.374716 -8.688826 -4.345881 -2.1744084 -0.00293603] [ -3.2391157 -2.8048213 -2.3705268 -1.9362322 -1.5019379 -1.0676433 -0.6333489 -0.19905435] [ -1.2755501 -1.0584029 -0.971544 -0.88468516 -0.8412557 -0.8195409 -0.7978263 -0.77611154] [ -0.97811854 -0.9564038 -0.93468904 -0.9129743 -0.89125955 -0.86954486 -0.85217315 -0.84826446] [ -0.9461945 -0.94576025 -0.94185156 -0.92447984 -0.90276515 -0.88105035 -0.8593356 -0.8376209 ] [ -1.0233353 -1.0016205 -0.97990584 -0.95819116 -0.93647635 -0.89304686 -0.806188 -0.7193291 ] [ -3.022131 -2.8049836 -2.3706892 -1.9363947 -1.5021002 -1.0678058 -0.6335113 -0.1992168 ] [-30.834908 -30.400614 -28.229141 -26.057669 -21.714724 -13.028834 -8.685889 0. ]] |