Scatter (ISASI)
Function Usage
Distributes the elements of a contiguous input tensor into a result tensor, using a destination offset tensor and a base address offset to determine where each element is written.
Each element of the source operand srcLocal is written to the position in the destination operand dstLocal that is specified by the corresponding entry of dstOffsetLocal together with dstBaseAddr.
Prototype
- Computation of the first n data elements of a tensor
1 2
// Count-based overload ("first n data elements"): scatters the first `count` elements of
// srcLocal into dstLocal, at the destination positions specified by dstOffsetLocal and
// dstBaseAddr.
// NOTE(review): the example on this page uses offsets that step by sizeof(half), i.e. byte
// offsets — confirm the offset unit against the official reference.
template <typename T> __aicore__ inline void Scatter(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const LocalTensor<uint32_t>& dstOffsetLocal, const uint32_t dstBaseAddr, const uint32_t count)
- High-dimensional tensor sharding computation
- Bitwise mask mode
1 2
// Bitwise-mask overload (high-dimensional sharding): `mask` is passed as an array of uint64_t
// words. repeatTimes and srcRepStride presumably set the repeat count and the source stride
// between repeats — TODO confirm exact semantics against the official reference.
template <typename T> __aicore__ inline void Scatter(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const LocalTensor<uint32_t>& dstOffsetLocal, const uint32_t dstBaseAddr, const uint64_t mask[], const uint8_t repeatTimes, const uint8_t srcRepStride)
- Contiguous mask mode
1 2
// Contiguous-mask overload (high-dimensional sharding): `mask` is a single uint64_t value.
// repeatTimes and srcRepStride presumably set the repeat count and the source stride between
// repeats — TODO confirm exact semantics against the official reference.
template <typename T> __aicore__ inline void Scatter(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const LocalTensor<uint32_t>& dstOffsetLocal, const uint32_t dstBaseAddr, const uint64_t mask, const uint8_t repeatTimes, const uint8_t srcRepStride)
- Bitwise mask mode
Parameters
| Parameter | Description |
|---|---|
| T | Operand data type. The supported data types depend on the product model; that list was truncated in this copy of the documentation. |
Availability
Constraints
- For details about the operand address alignment requirements, see General Address Alignment Restrictions.
Example
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
#include "kernel_operator.h"

/**
 * @brief Minimal kernel demonstrating the count-based AscendC::Scatter overload.
 *
 * Pipeline: CopyIn (global memory -> VECIN queue), Compute (Scatter into a VECOUT
 * tensor), CopyOut (VECOUT queue -> global memory).
 *
 * @tparam T element type of the source and destination tensors.
 */
template <typename T>
class ScatterTest {
public:
    __aicore__ inline ScatterTest() {}

    /**
     * @brief Binds the global buffers and allocates the local queue buffers.
     * @param dstGm       global destination buffer (receives count elements of T).
     * @param srcGm       global source buffer (count elements of T).
     * @param dstOffsetGm global buffer of count uint32_t destination offsets.
     * @param count       number of elements to scatter.
     */
    __aicore__ inline void Init(__gm__ uint8_t* dstGm, __gm__ uint8_t* srcGm, __gm__ uint8_t* dstOffsetGm,
                                const uint32_t count)
    {
        m_elementCount = count;
        m_dstGlobal.SetGlobalBuffer((__gm__ T*)dstGm);
        m_srcGlobal.SetGlobalBuffer((__gm__ T*)srcGm);
        m_dstOffsetGlobal.SetGlobalBuffer((__gm__ uint32_t*)dstOffsetGm);

        // m_queIn carries both the T-typed source and the uint32_t offsets, so each of its
        // buffers must fit whichever element type is wider. The output buffer uses the same
        // per-element size. NOTE(review): this assumes every dstOffsetLocal entry (plus the
        // base address) stays within count * bytesPerElement bytes — confirm for real inputs.
        const uint32_t bytesPerElement =
            static_cast<uint32_t>(sizeof(T) > sizeof(uint32_t) ? sizeof(T) : sizeof(uint32_t));
        m_pipe.InitBuffer(m_queIn, 2, m_elementCount * bytesPerElement);
        m_pipe.InitBuffer(m_queOut, 1, m_elementCount * bytesPerElement);
    }

    /// @brief Runs the full copy-in / compute / copy-out pipeline once.
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    /// @brief Copies the source data and the offsets from global memory into m_queIn.
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> srcLocal = m_queIn.AllocTensor<T>();
        AscendC::DataCopy(srcLocal, m_srcGlobal, m_elementCount);
        m_queIn.EnQue(srcLocal);

        AscendC::LocalTensor<uint32_t> dstOffsetLocal = m_queIn.AllocTensor<uint32_t>();
        AscendC::DataCopy(dstOffsetLocal, m_dstOffsetGlobal, m_elementCount);
        m_queIn.EnQue(dstOffsetLocal);
    }

    /// @brief Scatters the source elements into a fresh output tensor and enqueues it.
    __aicore__ inline void Compute()
    {
        // Dequeue in the same order CopyIn enqueued: source first, then offsets.
        AscendC::LocalTensor<T> srcLocal = m_queIn.DeQue<T>();
        AscendC::LocalTensor<uint32_t> dstOffsetLocal = m_queIn.DeQue<uint32_t>();
        AscendC::LocalTensor<T> dstLocal = m_queOut.AllocTensor<T>();
        dstLocal.SetSize(m_elementCount);

        // The destination positions come from dstOffsetLocal alone; no extra base shift.
        constexpr uint32_t dstBaseAddr = 0;
        AscendC::Scatter(dstLocal, srcLocal, dstOffsetLocal, dstBaseAddr, m_elementCount);

        m_queIn.FreeTensor(srcLocal);
        m_queIn.FreeTensor(dstOffsetLocal);
        m_queOut.EnQue(dstLocal);
    }

    /// @brief Copies the scattered result from m_queOut back to global memory.
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> dstLocal = m_queOut.DeQue<T>();
        AscendC::DataCopy(m_dstGlobal, dstLocal, m_elementCount);
        m_queOut.FreeTensor(dstLocal);
    }

private:
    AscendC::TPipe m_pipe;
    uint32_t m_elementCount = 0;
    AscendC::GlobalTensor<uint32_t> m_dstOffsetGlobal;
    AscendC::GlobalTensor<T> m_srcGlobal;
    AscendC::GlobalTensor<T> m_dstGlobal;
    // Depth 2: the source tensor and the offset tensor are both enqueued before Compute runs.
    AscendC::TQue<AscendC::QuePosition::VECIN, 2> m_queIn;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> m_queOut;
}; // class ScatterTest

// Instantiates a kernel entry point named kernel_scatter_<T>_<count> that scatters `count`
// elements of type T, e.g. KERNEL_SCATTER(half, 128) defines kernel_scatter_half_128.
#define KERNEL_SCATTER(T, count)                                                                    \
    extern "C" __global__ __aicore__ void kernel_scatter_##T##_##count(GM_ADDR dstGm, GM_ADDR srcGm, \
                                                                       GM_ADDR dstOffsetGm)         \
    {                                                                                               \
        ScatterTest<T> op;                                                                          \
        op.Init(dstGm, srcGm, dstOffsetGm, count);                                                  \
        op.Process();                                                                               \
    }
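Input dstOffsetLocal: [254 252 250 ... 4 2 0]
Input srcLocal (128 data elements of the half type): [0 1 2 ... 125 126 127]
Output dstGlobal: [127 126 125 ... 2 1 0]
The offsets in this example step by 2, which is the size of a half element, so they act as byte offsets. srcLocal[i] is written to byte offset 254 - 2i, which is element index 127 - i of the destination. This reverses the input.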