AdjustSoftMaxRes
产品支持情况
产品 |
是否支持 |
|---|---|
Atlas 350 加速卡 |
√ |
√ |
|
√ |
|
x |
|
√ |
|
x |
|
x |
功能说明
本接口用于调整SoftMax的计算结果为指定的值。主要用于对SoftMax相关计算结果做后处理。当输入的max中存在指定的值的时候,会调整对应的softmaxres中的结果为输入的自定义的值。以上调整方式为按行进行,即当max某一行的值为某个值时,调整当前softmaxres对应一行的值都为输入的值。
为方便理解,通过Python脚本实现的方式,表达其计算公式如下,其中res是输入也是输出,max\from\to\res_shape都为输入。
1 2 3 4 5 6 | def adjust_softmax_res(res, max, from, to, res_shape): for i in res_shape[0]: if max[i] == from: for j in res_shape[1]: res[i][j] = to return |
函数原型
1 2 | template <typename T1, typename T2, bool isDataFormatNZ = false, uint8_t stepSizeMode = 0> __aicore__ inline bool AdjustSoftMaxRes(const LocalTensor<T1>& softMaxRes, const LocalTensor<T2>& maxTensor, const uint32_t from, const T1 to, const SoftMaxShapeInfo& softmaxShapeInfo) |
参数说明
参数名 |
描述 |
|---|---|
T1 |
softMaxRes的数据类型。 Atlas 350 加速卡,支持的数据类型为:half、float。 |
T2 |
maxTensor的数据类型。 Atlas 350 加速卡,支持的数据类型为:half、float。 |
isDataFormatNZ |
当前输入输出的数据格式是否为NZ格式,默认数据格式为ND,即默认取值为false。 |
stepSizeMode |
maxTensor取元素的步进长度的模式。参数取值如下:
|
参数名 |
输入/输出 |
描述 |
||
|---|---|---|---|---|
softMaxRes |
输入/输出 |
既是源操作数也是目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 LocalTensor数据结构的定义请参考LocalTensor last轴长度需要32Byte对齐。 一般为softmax计算的输出结果。 |
||
maxTensor |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 softmax计算过程中reducemax的结果。
|
||
from |
输入 |
源操作数,类型为uint32_t。 需要判断的maxTensor中的值。需要注意的是,由于maxTensor中的值均为浮点数类型,因此此处需要填入的值为浮点数类型对应十六进制的值。比如当需要判断maxTensor是否有1.0这个值时,from值需要填入1.0对应的十六进制值0x3f800000。 |
||
to |
输入 |
源操作数,类型和softMaxRes的数据类型保持一致。 需要往softMaxRes中填充的值。 |
||
softmaxShapeInfo |
输入 |
softMaxRes的shape信息,结构定义如下:
需要注意,目前仅支持ND输入。 |
返回值说明
bool类型,当返回true时,表示maxTensor中存在需要判断的值,若返回false,则表示maxTensor中不存在需要判断的值。
约束说明
- 操作数地址对齐要求请参见通用地址对齐约束。
- 当参数softmaxShapeInfo中srcM != oriSrcM 或者 srcK != oriSrcK时,开发者需要对GM上的原始输入(oriSrcM, oriSrcK)在M或K方向补齐数据到(srcM, srcK),补齐的数据会参与部分运算,在输入输出复用的场景下,API的计算结果会覆盖srcTensor中补齐的原始数据,在输入输出不复用的场景下,API的计算结果会覆盖dstTensor中对应srcTensor补齐位置的数据。
调用示例
本样例中需要对SoftMax计算结果做后处理,判断maxTensor中是否存在0xFF7FFFFF,如果存在刷新对应结果为0。本样例中实现的是固定shape为输入x[32, 8],输出y[32, 32]的AdjustSoftMaxResCustom算子。输入softMaxRes的shape大小为[32, 8],maxTensor的shape大小为[32,8],数据类型均为float。
1 2 3 4 5 6 7 8 9 | // srcLocal:softmax计算结果 AscendC::SoftMax(srcLocal, ...) // maxLocal:softmax中间结果,reducemax的结果 // FROM: 判断maxLocal中是否存在值等于FROM的元素 // TO: 当maxLocal中存在值等于FROM的元素时,srcLocal中对应行的元素将被替换为TO // srcShape:描述srcLocal的shape信息 AscendC::SoftMaxShapeInfo srcShape = {height, width, height, width}; AscendC::AdjustSoftMaxRes<float, float>(srcLocal, maxLocal, FROM, TO, srcShape); |
结果示例如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 | 输入数据(srcLocal): [ 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 ] 输入数据(maxLocal): [ 7. 7. 7. 7. 7. 7. 7. 7. 15. 15. 15. 15. 15. 15. 15. 15. 23. 23. 23. 23. 23. 23. 23. 23. 31. 31. 31. 31. 31. 31. 31. 31. 39. 39. 39. 39. 39. 39. 39. 39. 47. 47. 47. 47. 47. 47. 47. 47. 55. 55. 55. 55. 55. 55. 55. 55. 63. 63. 63. 63. 63. 63. 63. 63. 71. 71. 71. 71. 71. 71. 71. 71. 79. 79. 79. 79. 79. 79. 79. 79. 87. 87. 87. 87. 87. 87. 87. 87. 95. 95. 95. 95. 95. 95. 95. 95. 103. 103. 103. 103. 103. 103. 103. 103. 111. 111. 111. 111. 111. 111. 111. 111. 119. 119. 119. 119. 119. 119. 119. 119. 127. 127. 127. 127. 127. 127. 127. 127. 135. 135. 135. 135. 135. 135. 135. 135. 143. 143. 143. 143. 143. 143. 143. 143. 151. 151. 151. 151. 151. 151. 151. 151. 159. 159. 159. 159. 159. 159. 159. 159. 167. 167. 167. 167. 167. 167. 167. 167. 175. 175. 175. 175. 175. 175. 175. 175. 183. 183. 183. 183. 183. 183. 183. 183. 191. 191. 191. 191. 191. 191. 191. 191. 199. 199. 199. 199. 199. 199. 199. 199. 207. 207. 207. 207. 207. 207. 207. 207. 215. 215. 215. 215. 215. 215. 215. 215. 223. 223. 223. 223. 223. 223. 223. 223. 231. 231. 231. 231. 231. 231. 231. 231. 239. 239. 239. 239. 239. 239. 239. 239. 247. 247. 247. 247. 247. 247. 247. 247. 255. 255. 255. 255. 255. 255. 255. 255. ] 输出数据(srcLocal): [ 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 0.0005766128 0.001567396 0.004260624 0.011581578 0.031481992 0.08557693 0.23262219 0.63233274 ] |