Scatter(ISASI)
产品支持情况
|
产品 |
是否支持 |
|---|---|
|
Atlas 350 加速卡 |
√ |
|
|
x |
|
|
x |
|
|
√ |
|
|
√ |
|
|
x |
|
|
x |
功能说明
给定一个连续的输入张量和一个目的地址偏移张量,Scatter指令根据偏移地址生成新的结果张量后将输入张量分散到结果张量中。
将源操作数src中的元素按照指定的位置(由dst_offset和base_addr共同作用)分散到目的操作数dst中。
函数原型
- tensor前n个数据计算
1 2
template <typename T> __aicore__ inline void Scatter(const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<uint32_t>& dstOffset, const uint32_t dstBaseAddr, const uint32_t count)
- tensor高维切分计算
- mask逐bit模式
1 2
template <typename T> __aicore__ inline void Scatter(const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<uint32_t>& dstOffset, const uint32_t dstBaseAddr, const uint64_t mask[], const uint8_t repeatTime, const uint8_t srcRepStride)
- mask连续模式
1 2
template <typename T> __aicore__ inline void Scatter(const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<uint32_t>& dstOffset, const uint32_t dstBaseAddr, const uint64_t mask, const uint8_t repeatTime, const uint8_t srcRepStride)
- mask逐bit模式
参数说明
|
参数名 |
描述 |
|---|---|
|
T |
操作数数据类型。 Atlas 350 加速卡,支持的数据类型为:uint8_t/int8_t/uint16_t/int16_t/half/bfloat16_t/uint32_t/int32_t/float/uint64_t/int64_t |
|
参数名称 |
输入/输出 |
含义 |
|---|---|---|
|
dst |
输出 |
目的操作数,类型为LocalTensor。LocalTensor的起始地址需要32字节对齐。 |
|
src |
输入 |
源操作数,类型为LocalTensor。数据类型需与dst保持一致。 |
|
dstOffset |
输入 |
用于存储源操作数的每个元素在dst中对应的地址偏移。偏移基于dst的基地址dstBaseAddr计算,以字节为单位,取值应保证按dst数据类型位宽对齐,否则会导致非预期行为。 针对以下型号,地址偏移的取值范围不超出uint32_t的范围即可。 针对以下型号,地址偏移的取值范围如下:当操作数为8位时,取值范围为[0, 216-1];当操作数为16位时,取值范围为[0, 217-1],当操作数为32位或者64位时,不超过uint32_t的范围即可。超出取值范围可能导致非预期输出。 Atlas 350 加速卡 |
|
dstBaseAddr |
输入 |
dst的起始偏移地址,单位是字节。取值应保证按dst数据类型位宽对齐,否则会导致非预期行为。 |
|
count |
输入 |
执行处理的数据个数。 |
|
mask/mask[] |
输入 |
|
|
repeatTime |
输入 |
指令迭代次数,每次迭代完成8个datablock的数据收集,数据范围:repeatTime∈[0,255]。
特别地,针对以下型号:
操作数为8位时,每次迭代完成4个datablock(32Bytes)的数据收集。 |
|
srcRepStride |
输入 |
相邻迭代间的地址步长,单位是datablock。 |
约束说明
- dstOffset中的偏移地址不能有相同值,如果存在2个或者多个偏移重复的情况,行为是不可预期的。
- 针对Atlas 350 加速卡,uint8_t/int8_t数据类型仅支持tensor前n个数据计算接口。
调用示例
1 2 |
uint32_t m_elementCount = 128 AscendC::Scatter(dstLocal, srcLocal, dstOffsetLocal, (uint32_t)0, m_elementCount); // dstOffsetLocal 用于存储源操作数的每个元素在dst中对应的地址偏移 |
输入数据dstOffsetLocal: [254 252 250 ... 4 2 0] 输入数据srcLocal(128个half类型数据): [0 1 2 ... 125 126 127] 输出数据dstGlobal: [127 126 125 ... 2 1 0]
