1 2 |
template <typename T, MaskMode mode = MaskMode::NORMAL> __aicore__ static inline void SetVectorMask(const uint64_t maskHigh, const uint64_t maskLow) |
1 2 |
template <typename T, MaskMode mode = MaskMode::NORMAL> __aicore__ static inline void SetVectorMask(int32_t len) |
参数名 |
描述 |
---|---|
T |
矢量计算操作数数据类型。 |
mode |
mask模式,MaskMode类型,取值如下:
|
参数名 |
输入/输出 |
描述 |
---|---|---|
maskHigh |
输入 |
Normal模式:对应Normal模式下的逐比特模式,可以按位控制哪些元素参与计算。传入高位mask值。 Counter模式:需要置0,本入参不生效。 |
maskLow |
输入 |
Normal模式:对应Normal模式下的逐比特模式,可以按位控制哪些元素参与计算。传入低位mask值。 Counter模式:整个矢量计算过程中,参与计算的元素个数。 |
len |
输入 |
Normal模式:对应Normal模式下的mask连续模式,表示单次迭代内表示前面连续的多少个元素参与计算。 Counter模式:整个矢量计算过程中,参与计算的元素个数。 |
无
Atlas 训练系列产品
Atlas推理系列产品AI Core
Atlas推理系列产品Vector Core
Atlas A2训练系列产品/Atlas 800I A2推理产品
Atlas 200I/500 A2推理产品
该接口仅在矢量计算API的isSetMask模板参数为false时生效,使用完成后需要使用ResetMask将mask恢复为默认值。
可结合SetMaskCount与SetMaskNorm使用,先设置mask的模式再设置mask:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
AscendC::LocalTensor<half> dstLocal; AscendC::LocalTensor<half> src0Local; AscendC::LocalTensor<half> src1Local; // Normal模式 AscendC::SetMaskNorm(); AscendC::SetVectorMask<half, AscendC::MaskMode::NORMAL>(0xffffffffffffffff, 0xffffffffffffffff); // 逐bit模式 // SetVectorMask<half, MaskMode::NORMAL>(128); // 连续模式 // 多次调用矢量计算API, 可以统一设置为Normal模式,并设置mask参数,无需在API内部反复设置,省去了在API反复设置的过程,会有一定的性能优势 // dstBlkStride, src0BlkStride, src1BlkStride = 1, 单次迭代内数据连续读取和写入 // dstRepStride, src0RepStride, src1RepStride = 8, 相邻迭代间数据连续读取和写入 AscendC::Add<half, false>(dstLocal, src0Local, src1Local, AscendC::MASK_PLACEHOLDER, 1, { 2, 2, 2, 8, 8, 8 }); AscendC::Sub<half, false>(src0Local, dstLocal, src1Local, AscendC::MASK_PLACEHOLDER, 1, { 2, 2, 2, 8, 8, 8 }); AscendC::Mul<half, false>(src1Local, dstLocal, src0Local, AscendC::MASK_PLACEHOLDER, 1, { 2, 2, 2, 8, 8, 8 }); AscendC::ResetMask(); |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
// Counter模式和tensor高维切分计算接口配合使用 AscendC::LocalTensor<half> dstLocal; AscendC::LocalTensor<half> src0Local; AscendC::LocalTensor<half> src1Local; int32_t len = 128; // 参与计算的元素个数 AscendC::SetMaskCount(); AscendC::SetVectorMask<half, AscendC::MaskMode::COUNTER>(len); // SetVectorMask<half, MaskMode::COUNTER>(0, len); AscendC::Add<half, false>(dstLocal, src0Local, src1Local, AscendC::MASK_PLACEHOLDER, 1, { 1, 1, 1, 8, 8, 8 }); AscendC::Sub<half, false>(src0Local, dstLocal, src1Local, AscendC::MASK_PLACEHOLDER, 1, { 1, 1, 1, 8, 8, 8 }); AscendC::Mul<half, false>(src1Local, dstLocal, src0Local, AscendC::MASK_PLACEHOLDER, 1, { 1, 1, 1, 8, 8, 8 }); AscendC::SetMaskNorm(); AscendC::ResetMask(); // Counter模式和tensor前n个数据计算接口配合使用 AscendC::LocalTensor<half> dstLocal; AscendC::LocalTensor<half> src0Local; half num = 2; AscendC::SetMaskCount(); AscendC::SetVectorMask<half, AscendC::MaskMode::COUNTER>(128); // 参与计算的元素个数为128 AscendC::Adds<half, false>(dstLocal, src0Local, num, 1); AscendC::Muls<half, false>(dstLocal, src0Local, num, 1); AscendC::SetMaskNorm(); AscendC::ResetMask(); |