paged_attention_demo.cpp

For prerequisites and build commands, see the operator invocation example. Currently supported only on Atlas 800I A2 inference products / Atlas A2 training series products / Atlas A3 inference series products / Atlas A3 training series products.

Scenario: mask (MASK_TYPE_NORM). A norm mask tensor is passed as an extra input and added to the scaled query-key scores before softmax.

#include <iostream>
#include <vector>
#include <numeric>
#include <stdlib.h>
#include "acl/acl.h"
#include "atb/operation.h"
#include "atb/types.h"
#include "atb/atb_infer.h"

#include "demo_util.h"

const uint32_t NTOKENS = 2;                             // number of tokens
const uint32_t BATCH_SIZE = NTOKENS;                    // batch size
const uint32_t MAX_SEQ_LEN = 1024;                      // maximum sequence length
const uint32_t HEAD_NUM = 32;                           // number of query heads
const uint32_t KV_HEAD_NUM = 32;                        // number of key/value heads
const uint32_t HEAD_SIZE = 128;                         // per-head dimension
const uint32_t BLOCK_NUM = 16;                          // number of KV cache blocks
const uint32_t BLOCK_SIZE = 128;                        // tokens per KV cache block
const uint32_t MAX_CONTEXT_LEN = 1024;                  // maximum context length
std::vector<int32_t> contextLensData(BATCH_SIZE, 256);  // host-side data for contextLens: each sequence has 256 cached tokens

/**
 * @brief Prepare all input tensors for atb::VariantPack
 * @param contextPtr pointer to the context
 * @param stream stream
 * @return atb::SVector<atb::Tensor> input tensors of atb::VariantPack
 * @note All host-side tensors must also be provided
 */
atb::SVector<atb::Tensor> PrepareInTensor(atb::Context *contextPtr, aclrtStream stream)
{
    // Create the query tensor
    std::vector<float> queryData(NTOKENS * HEAD_NUM * HEAD_SIZE, 1.0);
    atb::Tensor query = CreateTensorFromVector(
        contextPtr, stream, queryData, ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, {NTOKENS, HEAD_NUM, HEAD_SIZE});
    // Create the key cache and value cache tensors
    std::vector<float> kvCacheData(BLOCK_NUM * BLOCK_SIZE * KV_HEAD_NUM * HEAD_SIZE, 1.0);
    atb::Tensor kCache = CreateTensorFromVector(contextPtr,
        stream,
        kvCacheData,
        ACL_FLOAT16,
        aclFormat::ACL_FORMAT_ND,
        {BLOCK_NUM, BLOCK_SIZE, KV_HEAD_NUM, HEAD_SIZE});
    atb::Tensor vCache = CreateTensorFromVector(contextPtr,
        stream,
        kvCacheData,
        ACL_FLOAT16,
        aclFormat::ACL_FORMAT_ND,
        {BLOCK_NUM, BLOCK_SIZE, KV_HEAD_NUM, HEAD_SIZE});
    // Create blockTables: for each query, map its logical KV blocks to physical block indices in the caches
    uint32_t maxNumBlocksPerQuery = (MAX_CONTEXT_LEN + BLOCK_SIZE - 1) / BLOCK_SIZE;
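    // Ceiling division: with MAX_CONTEXT_LEN = 1024 and BLOCK_SIZE = 128, each query needs 1024 / 128 = 8 block-table entries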
    std::vector<int32_t> blockTablesData(NTOKENS * maxNumBlocksPerQuery, 0);
    for (size_t i = 0; i < blockTablesData.size(); i++) {
        blockTablesData.at(i) = rand() % (BLOCK_NUM - 1);
    }
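    // Note: a real serving stack would fill blockTables from its KV cache allocator; random in-range indices are enough for this demo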
    atb::Tensor blockTables = CreateTensor(ACL_INT32, aclFormat::ACL_FORMAT_ND, {NTOKENS, maxNumBlocksPerQuery});
    CHECK_STATUS(aclrtMemcpy(blockTables.deviceData,
        blockTables.dataSize,
        blockTablesData.data(),
        sizeof(int32_t) * blockTablesData.size(),
        ACL_MEMCPY_HOST_TO_DEVICE));
    // Create contextLens, a host-side tensor
    atb::Tensor contextLens = CreateTensor(ACL_INT32, aclFormat::ACL_FORMAT_ND, {BATCH_SIZE});
    contextLens.hostData = contextLensData.data();
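    // The operator reads the sequence lengths from hostData, so no device-side copy is needed for contextLens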
    // Create the norm mask; every element is set to -32768, which is treated as -inf in float16
    std::vector<float> maskData(BATCH_SIZE * MAX_SEQ_LEN, 0);
    for (size_t i = 0; i < BATCH_SIZE; ++i) {
        for (size_t j = 0; j < MAX_SEQ_LEN; ++j) {
            maskData[i * MAX_SEQ_LEN + j] = -32768;  // -32768 : -inf
        }
    }
    atb::Tensor mask = CreateTensorFromVector(
        contextPtr, stream, maskData, ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, {BATCH_SIZE, 1, MAX_SEQ_LEN});
    // Pack all input tensors into the SVector in the order expected by the operator
    atb::SVector<atb::Tensor> inTensors = {query, kCache, vCache, blockTables, contextLens, mask};
    return inTensors;
}

/**
 * @brief Create a PagedAttention Operation and set its parameters
 * @return atb::Operation * pointer to the created Operation
 */
atb::Operation *PrepareOperation()
{
    atb::infer::PagedAttentionParam paOpParam;
    paOpParam.maskType = atb::infer::PagedAttentionParam::MaskType::MASK_TYPE_NORM;
    paOpParam.headNum = HEAD_NUM;
    paOpParam.kvHeadNum = KV_HEAD_NUM;
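    // qkScale is set to 1 / sqrt(HEAD_SIZE) = 1 / sqrt(128) ≈ 0.0884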
    paOpParam.qkScale = 0.08838834764831843;
    atb::Operation *paOp = nullptr;
    CHECK_STATUS(atb::CreateOperation(paOpParam, &paOp));
    return paOp;
}

int main(int argc, char **argv)
{
    // Set the device, create the context, and set the execution stream
    CHECK_STATUS(aclInit(nullptr));
    int32_t deviceId = 0;
    CHECK_STATUS(aclrtSetDevice(deviceId));
    atb::Context *context = nullptr;
    CHECK_STATUS(atb::CreateContext(&context));
    void *stream = nullptr;
    CHECK_STATUS(aclrtCreateStream(&stream));
    context->SetExecuteStream(stream);

    // PagedAttention example
    atb::Operation *paOp = PrepareOperation();
    // Prepare the input tensors
    atb::VariantPack paVariantPack;
    paVariantPack.inTensors = PrepareInTensor(context, stream);  // set the input tensors
    atb::Tensor tensorOut = CreateTensor(ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, {NTOKENS, HEAD_NUM, HEAD_SIZE});
    paVariantPack.outTensors.push_back(tensorOut);  // set the output tensor

    uint64_t workspaceSize = 0;
    // Compute the required workspace size
    CHECK_STATUS(paOp->Setup(paVariantPack, workspaceSize, context));
    uint8_t *workspacePtr = nullptr;
    if (workspaceSize > 0) {
        CHECK_STATUS(aclrtMalloc((void **)(&workspacePtr), workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
    }
    // Execute PagedAttention
    CHECK_STATUS(paOp->Execute(paVariantPack, workspacePtr, workspaceSize, context));
    CHECK_STATUS(aclrtSynchronizeStream(stream));  // Synchronize the stream and wait for device-side tasks to finish
    CHECK_STATUS(aclrtFree(tensorOut.deviceData));
    for (atb::Tensor &inTensor : paVariantPack.inTensors) {
        CHECK_STATUS(aclrtFree(inTensor.deviceData));
    }
    if (workspaceSize > 0) {
        CHECK_STATUS(aclrtFree(workspacePtr));
    }
    CHECK_STATUS(atb::DestroyOperation(paOp));  // Destroy the operation (object-level resource) first
    CHECK_STATUS(aclrtDestroyStream(stream));
    CHECK_STATUS(atb::DestroyContext(context));  // Destroy the context (global resource) last
    CHECK_STATUS(aclFinalize());
    std::cout << "PA demo success!" << std::endl;
    return 0;
}
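
The demo frees tensorOut right after stream synchronization without inspecting the result. If you want to check the output, a small host-side readback can be inserted after aclrtSynchronizeStream and before the aclrtFree calls. The sketch below is illustrative only: it assumes the ACL_FLOAT16 output layout {NTOKENS, HEAD_NUM, HEAD_SIZE} used above, and the HalfToFloat helper is hypothetical (not part of demo_util.h), converting one IEEE float16 bit pattern to float for printing.

#include <cmath>
#include <cstdint>

// Hypothetical helper: convert one IEEE float16 bit pattern to float (not provided by demo_util.h).
static float HalfToFloat(uint16_t h)
{
    uint32_t sign = (h >> 15) & 0x1;
    uint32_t exp = (h >> 10) & 0x1F;
    uint32_t frac = h & 0x3FF;
    float value;
    if (exp == 0) {          // zero / subnormal
        value = std::ldexp(static_cast<float>(frac), -24);
    } else if (exp == 31) {  // inf / NaN
        value = frac ? NAN : INFINITY;
    } else {                 // normal
        value = std::ldexp(1.0f + static_cast<float>(frac) / 1024.0f, static_cast<int>(exp) - 15);
    }
    return sign ? -value : value;
}

// Copy the output tensor back to the host and print the first few values.
void PrintFirstOutputs(const atb::Tensor &tensorOut)
{
    std::vector<uint16_t> hostOut(NTOKENS * HEAD_NUM * HEAD_SIZE);
    CHECK_STATUS(aclrtMemcpy(hostOut.data(),
        hostOut.size() * sizeof(uint16_t),
        tensorOut.deviceData,
        tensorOut.dataSize,
        ACL_MEMCPY_DEVICE_TO_HOST));
    for (size_t i = 0; i < 8 && i < hostOut.size(); ++i) {
        std::cout << "out[" << i << "] = " << HalfToFloat(hostOut[i]) << std::endl;
    }
}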