PagedCacheLoadOperation
定义
atb::Status AtbPagedCacheLoadGetWorkspaceSize(const aclTensor *keyCache, const aclTensor *valueCache, const aclTensor *blockTables, const aclTensor *contextLens, const aclTensor *key, const aclTensor *value, const aclTensor *seqStarts, int8_t kvCacheCfg, bool isSeqLensCumsumType, bool hasSeqStarts, uint64_t *workspaceSize, atb::Operation **op, atb::Context *context); atb::Status AtbPagedCacheLoad(void *workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);
AtbPagedCacheLoadGetWorkspaceSize成员
参数 |
标量/张量 |
维度(keyCache/valueCache格式为ND) |
维度(keyCache/valueCache格式为NZ) |
数据类型 |
格式 |
默认值 |
是否必选 |
描述 |
---|---|---|---|---|---|---|---|---|
keyCache |
张量 |
[num_blocks, block_size, num_heads, head_size_k] |
[num_blocks, num_heads * head_size_k // elenum_aligned, block_size, elenum_aligned] |
float16/bf16/int8 |
ND/NZ |
- |
是 |
输入tensor。 int8时 :elenum_aligned=32,其他情况为16。 |
valueCache |
张量 |
[num_blocks, block_size, num_heads, head_size_v] |
[num_blocks, num_heads * head_size_v // elenum_aligned, block_size, elenum_aligned] |
float16/bf16/int8 |
ND/NZ |
- |
是 |
输入tensor。 int8时 :elenum_aligned=32,其他情况为16。 |
blockTables |
张量 |
[batch, block_indices] |
[len(contextLens), (max(contextLens) - 1) // block_size + 1] |
int32 |
ND |
- |
是 |
输入tensor。 len(contextLens)为contextlens的长度。 |
contextLens |
张量 |
[batch]or[batch+1] |
[len(contextLens)] |
int32 |
ND |
- |
是 |
输入tensor。 len(contextLens)为contextlens的长度。 |
key |
张量 |
[num_tokens, num_heads, head_size_k] |
[sum(contextLens), num_heads * head_size_k] |
float16/bf16/int8 |
ND |
- |
是 |
输入/输出tensor。 sum(contextLens)为contextlens的各元素求和。 |
value |
张量 |
[num_tokens, num_heads, head_size_v] |
[sum(contextLens), num_heads * head_size_v] |
float16/bf16/int8 |
ND |
- |
是 |
输入/输出tensor。 sum(contextLens)为contextlens的各元素求和 。 |
seqStarts |
张量 |
[batch] |
- |
int32 |
ND |
- |
否 |
可选输入,每个batch在blocktable中对应的起始位置,只支持ND格式。 |
kvCacheCfg |
标量 |
- |
- |
int8_t |
- |
0 |
是 |
|
isSeqLensCumsumType |
标量 |
- |
- |
bool |
- |
false |
否 |
支持输入seqLens为累加和模式,即第n个batch中所需提取的元素数量为前n个batch的所需提取元素数量的累加和,只支持ND格式。 |
hasSeqStarts |
标量 |
- |
- |
bool |
- |
false |
否 |
提供SeqStart用于为每个batch提供在blockTable中的初始位置(类似于offset) ,只支持ND格式。 |