昇腾社区首页
中文
注册

PagedCacheLoadOperation

定义

atb::Status AtbPagedCacheLoadGetWorkspaceSize(const aclTensor *keyCache, const aclTensor *valueCache,
                                              const aclTensor *blockTables, const aclTensor *contextLens,
                                              const aclTensor *key, const aclTensor *value, const aclTensor *seqStarts,
                                              int8_t kvCacheCfg, bool isSeqLensCumsumType, bool hasSeqStarts,
                                              uint64_t *workspaceSize, atb::Operation **op, atb::Context *context);
atb::Status AtbPagedCacheLoad(void *workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);

AtbPagedCacheLoadGetWorkspaceSize成员

参数

标量/张量

维度(keyCache/valueCache格式为ND)

维度(keyCache/valueCache格式为NZ)

数据类型

格式

默认值

是否必选

描述

keyCache

张量

[num_blocks, block_size, num_heads, head_size_k]

[num_blocks, num_heads * head_size_k // elenum_aligned, block_size, elenum_aligned]

float16/bf16/int8

ND/NZ

-

输入tensor。

int8时 :elenum_aligned=32,其他情况为16。

valueCache

张量

[num_blocks, block_size, num_heads, head_size_v]

[num_blocks, num_heads * head_size_v // elenum_aligned, block_size, elenum_aligned]

float16/bf16/int8

ND/NZ

-

输入tensor。

int8时 :elenum_aligned=32,其他情况为16。

blockTables

张量

[batch, block_indices]

[len(contextLens), (max(contextLens) - 1) // block_size + 1]

int32

ND

-

输入tensor。

len(contextLens)为contextlens的长度。

contextLens

张量

[batch]or[batch+1]

[len(contextLens)]

int32

ND

-

输入tensor。

len(contextLens)为contextlens的长度。

key

张量

[num_tokens, num_heads, head_size_k]

[sum(contextLens), num_heads * head_size_k]

float16/bf16/int8

ND

-

输入/输出tensor。

sum(contextLens)为contextlens的各元素求和。

value

张量

[num_tokens, num_heads, head_size_v]

[sum(contextLens), num_heads * head_size_v]

float16/bf16/int8

ND

-

输入/输出tensor。

sum(contextLens)为contextlens的各元素求和 。

seqStarts

张量

[batch]

-

int32

ND

-

可选输入,每个batch在blocktable中对应的起始位置,只支持ND格式。

kvCacheCfg

标量

-

-

int8_t

-

0

  • K_CACHE_V_CACHE_NZ = 0:默认值:传入key_cache和value_cache, 且为NZ格式 。
  • K_CACHE_V_CACHE_ND= 1:传入key_cache和value_cache,且为ND格式。

isSeqLensCumsumType

标量

-

-

bool

-

false

支持输入seqLens为累加和模式,即第n个batch中所需提取的元素数量为前n个batch的所需提取元素数量的累加和,只支持ND格式。

hasSeqStarts

标量

-

-

bool

-

false

提供SeqStart用于为每个batch提供在blockTable中的初始位置(类似于offset) ,只支持ND格式。