昇腾社区首页
中文
注册

BlockCopyOperation(代码开放)

产品支持情况

硬件型号

是否支持

特殊说明

Atlas A3 推理系列产品/Atlas A3 训练系列产品

-

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

-

Atlas 训练系列产品

x

-

Atlas 推理系列产品

只支持float16数据类型。

Atlas 200I/500 A2 推理产品

x

-

功能说明

通过指定源block index和目的block index,将keyCache和valueCache中的数据从一系列源block拷贝到一系列目的block(keyCache和valueCache做相同的拷贝处理)。

算子上下文

图1 BlockCopyOperation

定义

1
2
3
struct BlockCopyParam {
    uint8_t rsv[16] = {0};
};

参数列表

成员名称

类型

默认值

描述

rsv[16]

uint8_t

{0}

预留参数。

输入

参数

维度

数据类型

格式

描述

keyCache

  • ND:[num_blocks, block_size, num_head, head_size]
  • NZ:[blockCount,numHead*headSize/16, BlockSize, 16]

float16/bf16/int8

ND/NZ

key矩阵。

valueCache

  • ND:[num_blocks, block_size, num_head, head_size]
  • NZ:[blockCount,numHead*headSize/16, BlockSize, 16]

float16/bf16/int8

ND/NZ

value矩阵。

数据类型和格式应该与keyCache保持一致。

srcBlockIndices

[src_count],最长为[num_blocks]

int32

ND

keyCache、valueCache源位置block索引,值域范围为[0,num_blocks)。

dstBlockIndices

[dst_count],最长为[num_blocks]

int32

ND

keyCache、valueCache目标位置block索引,dstBlockIndices[cumSum[i-1]:cumSum[i]-1]为srcBlockIndices[i]对应的目标位置List(特殊的,dstBlockIndices[0:cumSum[0]-1]为srcBlockIndices[0]对应的目标位置List),值域范围为[0,num_blocks)。

cumSum

[src_count]

int32

ND

cumSum[i]为源位置srcBlockIndices[i]对应的目标位置List在dstBlockIndices中的结束位置。

输出

参数

维度

数据类型

格式

描述

keyCacheNew

[num_blocks, block_size, num_head, head_size]

float16/bf16/int8

ND/NZ

block拷贝后的key矩阵,作为输出。输出keyCacheNew与输入keyCache指向同一地址,即进行原地修改。

数据类型和格式应该与keyCache保持一致。

valueCacheNew

[num_blocks, block_size, num_head, head_size]

float16/bf16/int8

ND/NZ

block拷贝后的value矩阵,作为输出。输出valueCacheNew与输入valueCache指向同一地址,即进行原地修改。

数据类型和格式应该与valueCache保持一致。

约束说明

  • 一个block index不能同时出现在srcBlockIndices和dstBlockIndices中。
  • srcBlockIndices中的block index不重复,如有重复则应合并为一项,其对应的dstBlockIndices取并集。
  • dstBlockIndices中的block index不重复,即不能有两个不同的block拷贝到同一个block中。
  • 每个srcBlockIndex至少有一个dstBlockIndex与之对应(i > 0时,cumSum[i] - cumSum[i - 1] ≥ 1;cumSum[0] ≥ 1)。
  • 根据前述4条规则,srcBlockIndices和dstBlockIndices的size应满足:dst_count ≥ src_count,src_count + dst_count ≤ num_blocks。
  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品Atlas A3 推理系列产品/Atlas A3 训练系列产品上仅支持ND格式。
  • Atlas 推理系列产品上,有如下约束:
    • Format为ND时,keyCache/valueCache后三个维度的乘积与16对齐(每个Block总字节数与32B对齐,float16场景,每个数据占用2字节)。
    • Format为NZ时,BlockSize需与16对齐。