Swish
Function Usage
Swish is a widely used activation function in neural networks. It is computed as follows, where β is a constant and PAR is the number of elements the vector unit processes in one iteration:

$$
dst_i = \frac{src_i}{1 + e^{-\beta \cdot src_i}}, \quad i = 0, 1, \ldots, PAR - 1
$$

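As a host-side illustration only (plain C++, not Ascend C), the sketch below evaluates src / (1 + e^(-β·src)) in double precision and walks `dataSize` elements in chunks of at most `par` elements, mirroring the per-iteration PAR described above. `SwishScalar`, `SwishByIterations`, and the `par` argument are hypothetical names for this sketch. The real PAR is fixed by the hardware and the data type, and the real interface is the Ascend C `Swish` prototype on this page.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Swish for one element: x / (1 + e^(-beta * x)), i.e. x * sigmoid(beta * x).
double SwishScalar(double x, double beta) {
    return x / (1.0 + std::exp(-beta * x));
}

// Host-side model (hypothetical, not the Ascend C API): applies SwishScalar to
// the first dataSize elements of src, at most `par` elements per iteration.
// `par` stands in for PAR; the real value is set by the hardware and data type.
// Returns the number of iterations used; the last one may be partial.
std::size_t SwishByIterations(std::vector<double>& dst,
                              const std::vector<double>& src,
                              std::size_t dataSize, double beta,
                              std::size_t par) {
    assert(par > 0);
    // Same precondition as the real API: dataSize <= min(src size, dst size).
    assert(dataSize <= src.size() && dataSize <= dst.size());
    std::size_t iterations = 0;
    for (std::size_t base = 0; base < dataSize; base += par) {
        const std::size_t count = std::min(par, dataSize - base);
        for (std::size_t i = 0; i < count; ++i) {
            dst[base + i] = SwishScalar(src[base + i], beta);
        }
        ++iterations;
    }
    return iterations;
}
```

For example, 5 elements with `par = 2` take 3 iterations (2 + 2 + 1). With β = 1 the function reduces to x · sigmoid(x).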
Prototype
```cpp
template <typename T, bool isReuseSource = false>
__aicore__ inline void Swish(const LocalTensor<T> &dstLocal, const LocalTensor<T> &srcLocal, uint32_t dataSize, const T &scalarValue)
```
Parameters
| Parameter | Description |
|---|---|
| T | Data type of the operand. |
| isReuseSource | Whether the source operand can be modified. This parameter is reserved. Pass the default value false. |

| Parameter | Input/Output | Description |
|---|---|---|
| dstLocal | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. |
| srcLocal | Input | Source operand. It must have the same data type as the destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. |
| dataSize | Input | Number of elements actually computed. Value range: dataSize ∈ [0, min(srcLocal.GetSize(), dstLocal.GetSize())]. |
| scalarValue | Input | β in the activation function. The supported data types are half and float. β must have the same data type as the source and destination operands. |
Returns
None
Availability
Constraints
- For details about the alignment requirements of the operand address offset, see General Restrictions.
- The source operand address must not overlap the destination operand address.
- Currently, only the ND format is supported.
- dataSize must not exceed the number of elements stored in srcLocal or in dstLocal.
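The dataSize and non-overlap constraints above can be checked on the host before a kernel is launched. The sketch below is a hypothetical helper, not part of Ascend C: `TensorView`, `RangesOverlap`, and `SwishArgsValid` are illustrative names, addresses are treated as plain byte offsets, and the overlap test conservatively compares the whole source and destination buffers rather than only the first dataSize elements. Alignment (see General Restrictions) is not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical description of a tensor: start address in bytes and the
// number of elements it holds.
struct TensorView {
    std::uintptr_t addr;
    std::size_t numElems;
};

// True if the byte ranges [a, a + aBytes) and [b, b + bBytes) intersect.
bool RangesOverlap(std::uintptr_t a, std::size_t aBytes,
                   std::uintptr_t b, std::size_t bBytes) {
    return a < b + bBytes && b < a + aBytes;
}

// Mirrors the documented preconditions of Swish.
// elemBytes is sizeof(T): 2 for half, 4 for float.
bool SwishArgsValid(const TensorView& dst, const TensorView& src,
                    std::size_t dataSize, std::size_t elemBytes) {
    assert(elemBytes > 0);
    // dataSize must lie in [0, min(src size, dst size)].
    if (dataSize > src.numElems || dataSize > dst.numElems) {
        return false;
    }
    // Source and destination buffers must not overlap.
    return !RangesOverlap(src.addr, src.numElems * elemBytes,
                          dst.addr, dst.numElems * elemBytes);
}
```

Checking whole buffers may reject some calls that would actually be safe; that trade-off keeps the helper simple.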
Example
```cpp
#include "kernel_operator.h"

template <typename srcType>
class KernelSwish {
public:
    __aicore__ inline KernelSwish() {}
    __aicore__ inline void Init(GM_ADDR srcGm, GM_ADDR dstGm, uint32_t inputSize, srcType scalar)
    {
        dataSize = inputSize;
        scalarValue = scalar;
        srcGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ srcType *>(srcGm), dataSize);
        dstGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ srcType *>(dstGm), dataSize);
        // One buffer per queue, each large enough to hold all dataSize elements.
        pipe.InitBuffer(inQueueX, 1, dataSize * sizeof(srcType));
        pipe.InitBuffer(outQueue, 1, dataSize * sizeof(srcType));
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    // Copy the input from global memory into a VECIN local tensor.
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<srcType> srcLocal = inQueueX.AllocTensor<srcType>();
        AscendC::DataCopy(srcLocal, srcGlobal, dataSize);
        inQueueX.EnQue(srcLocal);
    }
    // Apply Swish with beta = scalarValue to dataSize elements.
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<srcType> dstLocal = outQueue.AllocTensor<srcType>();
        AscendC::LocalTensor<srcType> srcLocal = inQueueX.DeQue<srcType>();
        AscendC::Swish(dstLocal, srcLocal, dataSize, scalarValue);
        outQueue.EnQue<srcType>(dstLocal);
        inQueueX.FreeTensor(srcLocal);
    }
    // Copy the result from the VECOUT local tensor back to global memory.
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<srcType> dstLocal = outQueue.DeQue<srcType>();
        AscendC::DataCopy(dstGlobal, dstLocal, dataSize);
        outQueue.FreeTensor(dstLocal);
    }

private:
    AscendC::GlobalTensor<srcType> srcGlobal;
    AscendC::GlobalTensor<srcType> dstGlobal;
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueX;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueue;
    uint32_t dataSize = 0;
    srcType scalarValue = 0;
};

template <typename dataType>
__aicore__ void kernel_Swish_operator(GM_ADDR srcGm, GM_ADDR dstGm, uint32_t dataSize)
{
    KernelSwish<dataType> op;
    dataType scalarValue = 1.702;  // beta
    op.Init(srcGm, dstGm, dataSize, scalarValue);
    op.Process();
}
```
```
Input data (srcLocal):
[ 0.5312 -3.654 -2.92 3.787 -3.059 3.77 0.571 -0.668
 -0.09534 0.5454 -1.801 -1.791 1.563 0.878 3.973 1.799
  2.023 1.018 3.082 -3.814 2.254 -3.717 0.4675 -0.4631
 -2.47 0.9814 -0.854 3.31 3.256 3.764 1.867 -1.773]

Output data (dstLocal):
[ 0.3784 -0.007263 -0.02016 3.78 -0.01666 3.762 0.414 -0.1622
 -0.04382 0.3909 -0.0803 -0.08105 1.461 0.717 3.969 1.719
  1.96 0.8647 3.066 -0.00577 2.207 -0.006626 0.3223 -0.1448
 -0.03622 0.8257 -0.1617 3.297 3.244 3.756 1.792 -0.0825]
```
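The 32 input/output pairs above can be cross-checked on the host against the Swish formula with β = 1.702, the value passed in `kernel_Swish_operator`. The sketch below is plain C++ with hypothetical names (`SwishRef`, `CountMismatches`). The tolerance (1% relative plus 1e-3 absolute) is an assumption chosen to absorb half-precision rounding in the printed values; it is not a documented accuracy guarantee.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Input and output values copied from the result example.
const std::vector<double> kInput = {
    0.5312, -3.654, -2.92, 3.787, -3.059, 3.77, 0.571, -0.668,
    -0.09534, 0.5454, -1.801, -1.791, 1.563, 0.878, 3.973, 1.799,
    2.023, 1.018, 3.082, -3.814, 2.254, -3.717, 0.4675, -0.4631,
    -2.47, 0.9814, -0.854, 3.31, 3.256, 3.764, 1.867, -1.773};
const std::vector<double> kOutput = {
    0.3784, -0.007263, -0.02016, 3.78, -0.01666, 3.762, 0.414, -0.1622,
    -0.04382, 0.3909, -0.0803, -0.08105, 1.461, 0.717, 3.969, 1.719,
    1.96, 0.8647, 3.066, -0.00577, 2.207, -0.006626, 0.3223, -0.1448,
    -0.03622, 0.8257, -0.1617, 3.297, 3.244, 3.756, 1.792, -0.0825};

// Double-precision reference: x / (1 + e^(-beta * x)).
double SwishRef(double x, double beta) {
    return x / (1.0 + std::exp(-beta * x));
}

// Counts elements whose reference value differs from the printed output
// by more than relTol * |output| + absTol.
std::size_t CountMismatches(const std::vector<double>& in,
                            const std::vector<double>& out,
                            double beta, double relTol, double absTol) {
    assert(in.size() == out.size());
    std::size_t bad = 0;
    for (std::size_t i = 0; i < in.size(); ++i) {
        const double ref = SwishRef(in[i], beta);
        if (std::fabs(ref - out[i]) > relTol * std::fabs(out[i]) + absTol) {
            ++bad;
        }
    }
    return bad;
}
```

With β = 1.702 every element agrees within this tolerance, while β = 1 produces clear mismatches (for example, the element -3.654 would map to about -0.092 instead of -0.007263).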