Gather
Function Usage
Gathers elements from the source tensor into the destination tensor according to the byte offsets provided in the offset tensor.
Prototype
- Computation of the first n data elements of a tensor
```cpp
template <typename T>
__aicore__ inline void Gather(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal,
                              const LocalTensor<uint32_t>& srcOffsetLocal, const uint32_t srcBaseAddr,
                              const uint32_t count)
```
- High-dimensional tensor sharding computation
- Bitwise mask mode
```cpp
template <typename T>
__aicore__ inline void Gather(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal,
                              const LocalTensor<uint32_t>& srcOffsetLocal, const uint32_t srcBaseAddr,
                              const uint64_t mask[], const uint8_t repeatTimes, const uint16_t dstRepStride)
```
- Contiguous mask mode
```cpp
template <typename T>
__aicore__ inline void Gather(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal,
                              const LocalTensor<uint32_t>& srcOffsetLocal, const uint32_t srcBaseAddr,
                              const uint64_t mask, const uint8_t repeatTimes, const uint16_t dstRepStride)
```
Parameters
| Parameter | Description |
|---|---|
| T | Data type of the operand. |
| Parameter | Input/Output | Meaning |
|---|---|---|
| dstLocal | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. The start address of the LocalTensor must be 32-byte aligned. |
| srcLocal | Input | Source operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. The start address of the LocalTensor must be 32-byte aligned. The data type must be the same as that of dstLocal. |
| srcOffsetLocal | Input | Byte offset of each element to gather from src. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. The start address of the LocalTensor must be 32-byte aligned. Each offset is relative to the start base address of src, must be greater than or equal to 0, and must be a multiple of the size of the src element type; otherwise, unexpected behavior occurs. In addition, the resulting address must not exceed the UB size. |
| srcBaseAddr | Input | Start base address of srcLocal, in bytes. It must be a multiple of the size of the src element type; otherwise, unexpected behavior occurs. |
| count | Input | Number of elements to process. The value cannot exceed the number of elements in srcLocal or srcOffsetLocal. |
| mask | Input | Controls which elements participate in the computation in each iteration. In contiguous mask mode, the scalar mask specifies how many leading elements of each iteration participate; in bitwise mask mode, each bit of the mask array enables or disables the corresponding element. |
| repeatTimes | Input | Number of instruction iterations. Each iteration gathers eight data blocks (32 bytes each, 256 bytes in total). Value range: repeatTimes ∈ [0, 255]. |
| dstRepStride | Input | Address stride of the destination operand between adjacent iterations, in data blocks (32 bytes). |
Availability
Constraints
- For details about the alignment requirements of the operand address offset, see General Restrictions.
- To save memory space, you can define a tensor shared by the source and destination operands (by address overlapping). The general instruction restrictions are as follows.
- For a single repeat (repeatTimes = 1), the source operand must completely overlap the destination operand.
- For multiple repeats (repeatTimes > 1), there cannot be any dependency between the source and destination operands across iterations. For example, if the destination operand of iteration N is the source operand of iteration N+1, address overlapping is not supported.
Example
```cpp
#include "kernel_operator.h"

template <typename T>
class GatherTest {
public:
    __aicore__ inline GatherTest() {}

    __aicore__ inline void Init(__gm__ uint8_t* dstGm, __gm__ uint8_t* srcGm,
                                __gm__ uint8_t* srcOffsetGm, const uint32_t count)
    {
        m_elementCount = count;
        m_dstGlobal.SetGlobalBuffer((__gm__ T*)dstGm);
        m_srcGlobal.SetGlobalBuffer((__gm__ T*)srcGm);
        m_srcOffsetGlobal.SetGlobalBuffer((__gm__ uint32_t*)srcOffsetGm);
        m_pipe.InitBuffer(m_queIn, 2, m_elementCount * sizeof(uint32_t));
        m_pipe.InitBuffer(m_queOut, 2, m_elementCount * sizeof(uint32_t));
    }

    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> srcLocal = m_queIn.AllocTensor<T>();
        AscendC::DataCopy(srcLocal, m_srcGlobal, m_elementCount);
        m_queIn.EnQue(srcLocal);
        AscendC::LocalTensor<uint32_t> srcOffsetLocal = m_queIn.AllocTensor<uint32_t>();
        AscendC::DataCopy(srcOffsetLocal, m_srcOffsetGlobal, m_elementCount);
        m_queIn.EnQue(srcOffsetLocal);
    }

    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> srcLocal = m_queIn.DeQue<T>();
        AscendC::LocalTensor<uint32_t> srcOffsetLocal = m_queIn.DeQue<uint32_t>();
        AscendC::LocalTensor<T> dstLocal = m_queOut.AllocTensor<T>();
        srcLocal.SetSize(m_elementCount);
        AscendC::Gather(dstLocal, srcLocal, srcOffsetLocal, (uint32_t)0, m_elementCount);
        m_queIn.FreeTensor(srcLocal);
        m_queIn.FreeTensor(srcOffsetLocal);
        m_queOut.EnQue(dstLocal);
    }

    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> dstLocal = m_queOut.DeQue<T>();
        AscendC::DataCopy(m_dstGlobal, dstLocal, m_elementCount);
        m_queOut.FreeTensor(dstLocal);
    }

private:
    AscendC::TPipe m_pipe;
    AscendC::TQue<AscendC::QuePosition::VECIN, 2> m_queIn;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 2> m_queOut;
    AscendC::GlobalTensor<T> m_dstGlobal;
    AscendC::GlobalTensor<T> m_srcGlobal;
    AscendC::GlobalTensor<uint32_t> m_srcOffsetGlobal;
    uint32_t m_elementCount;
}; // class GatherTest

extern "C" __global__ __aicore__ void kernel_gather(GM_ADDR dstGm, GM_ADDR srcGm, GM_ADDR srcOffsetGm)
{
    GatherTest<half> op;
    op.Init(dstGm, srcGm, srcOffsetGm, 128);
    op.Process();
}
```
Input srcOffsetLocal: [254 252 250 ... 4 2 0]
Input srcLocal (128 half-type elements): [0 1 2 ... 125 126 127]
Output dstGlobal: [127 126 125 ... 2 1 0]