paged_attention_inference_demo.cpp

For prerequisites and build commands, see the operator invocation example. This example supports only Atlas inference series products.

Compared with Example 1, this example mainly differs in the following points:

Scenario: basic scenario.
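The demo builds the five PagedAttention inputs in order (query, keyCache, valueCache, blockTables, contextLens), queries the workspace size with Setup, runs Execute, and then releases all resources. The sizes used below (2 tokens, 32 heads, head dimension 128, 16 cache blocks of 128 tokens each) are demo values only.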

#include <iostream>
#include <vector>
#include <numeric>
#include <stdlib.h>
#include "acl/acl.h"
#include "atb/operation.h"
#include "atb/types.h"
#include "atb/atb_infer.h"

#include "demo_util.h"

const uint32_t NTOKENS = 2;                             // number of tokens
const uint32_t BATCH_SIZE = NTOKENS;                    // batch size
const uint32_t MAX_SEQ_LEN = 1024;                      // maximum sequence length
const uint32_t HEAD_NUM = 32;                           // number of query heads
const uint32_t KV_HEAD_NUM = 32;                        // number of key/value heads
const uint32_t HEAD_SIZE = 128;                         // head dimension
const uint32_t BLOCK_NUM = 16;                          // number of KV cache blocks
const uint32_t BLOCK_SIZE = 128;                        // tokens per cache block
const uint32_t MAX_CONTEXT_LEN = 1024;                  // maximum context length
std::vector<int32_t> contextLensData(BATCH_SIZE, 256);  // host-side data for contextLens
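// With contextLensData = 256 and BLOCK_SIZE = 128, each of the two sequences
// occupies ceil(256 / 128) = 2 of the BLOCK_NUM = 16 cache blocks.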

/**
 * @brief Prepare all input tensors for the atb::VariantPack
 * @param contextPtr pointer to the ATB context
 * @param stream stream used for host-to-device copies
 * @return atb::SVector<atb::Tensor> input tensors of the atb::VariantPack
 * @note all host-side tensors must be passed in here
 */
atb::SVector<atb::Tensor> PrepareInTensor(atb::Context *contextPtr, aclrtStream stream)
{
    // Create the query tensor
    std::vector<uint16_t> queryData(NTOKENS * HEAD_NUM * HEAD_SIZE, 0x3C00); // 0x3C00 is 1.0 in float16
    atb::Tensor query = CreateTensorFromVector(
        contextPtr, stream, queryData, ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, {NTOKENS, HEAD_NUM, HEAD_SIZE}, ACL_FLOAT16);
    // Create the key cache and value cache tensors
    std::vector<uint16_t> kvCacheData(BLOCK_NUM * BLOCK_SIZE * KV_HEAD_NUM * HEAD_SIZE, 0x3C00);
    atb::Tensor kCache = CreateTensorFromVector(contextPtr,
        stream,
        kvCacheData,
        ACL_FLOAT16,
        aclFormat::ACL_FORMAT_FRACTAL_NZ,
        {BLOCK_NUM, HEAD_SIZE * KV_HEAD_NUM / 16, BLOCK_SIZE, 16}, ACL_FLOAT16);
    atb::Tensor vCache = CreateTensorFromVector(contextPtr,
        stream,
        kvCacheData,
        ACL_FLOAT16,
        aclFormat::ACL_FORMAT_FRACTAL_NZ,
        {BLOCK_NUM, HEAD_SIZE * KV_HEAD_NUM / 16, BLOCK_SIZE, 16}, ACL_FLOAT16);
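    // FRACTAL_NZ layout: each cache block, logically (BLOCK_SIZE, KV_HEAD_NUM * HEAD_SIZE),
    // stores its inner axis as 16-element fractals, giving a per-block shape of
    // (KV_HEAD_NUM * HEAD_SIZE / 16, BLOCK_SIZE, 16) = (32 * 128 / 16, 128, 16) = (256, 128, 16).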
    // Create blockTables
    uint32_t maxNumBlocksPerQuery = (MAX_CONTEXT_LEN + BLOCK_SIZE - 1) / BLOCK_SIZE;
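    // maxNumBlocksPerQuery = ceil(MAX_CONTEXT_LEN / BLOCK_SIZE) = ceil(1024 / 128) = 8.
    // Row i of blockTables maps the logical blocks of sequence i to physical block IDs
    // in the KV cache; the demo fills it with random in-range IDs, where a real
    // serving engine's block allocator would assign them.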
    std::vector<int32_t> blockTablesData(NTOKENS * maxNumBlocksPerQuery, 0);
    for (size_t i = 0; i < blockTablesData.size(); i++) {
        blockTablesData.at(i) = rand() % (BLOCK_NUM - 1);
    }
    atb::Tensor blockTables = CreateTensor(ACL_INT32, aclFormat::ACL_FORMAT_ND, {NTOKENS, maxNumBlocksPerQuery});
    CHECK_STATUS(aclrtMemcpy(blockTables.deviceData,
        blockTables.dataSize,
        blockTablesData.data(),
        sizeof(int32_t) * blockTablesData.size(),
        ACL_MEMCPY_HOST_TO_DEVICE));
    // Create contextLens, a host-side tensor
    atb::Tensor contextLens = CreateTensor(ACL_INT32, aclFormat::ACL_FORMAT_ND, {BATCH_SIZE});
    contextLens.hostData = contextLensData.data();
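    // contextLens is read on the host side, so only hostData is set;
    // no device-side copy is needed.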
    // Place all input tensors into the SVector in the order the operation expects
    atb::SVector<atb::Tensor> inTensors = {query, kCache, vCache, blockTables, contextLens};
    return inTensors;
}

/**
 * @brief Create a PagedAttention operation and set its parameters
 * @return atb::Operation * pointer to the created operation
 */
atb::Operation *PrepareOperation()
{
    atb::infer::PagedAttentionParam paOpParam;
    paOpParam.headNum = HEAD_NUM;
    paOpParam.kvHeadNum = KV_HEAD_NUM;
    paOpParam.qkScale = 0.08838834764831843;  // 1 / sqrt(HEAD_SIZE) = 1 / sqrt(128)
    atb::Operation *paOp = nullptr;
    CHECK_STATUS(atb::CreateOperation(paOpParam, &paOp));
    return paOp;
}

int main(int argc, char **argv)
{
    // Select the device, create the context, and set the stream
    CHECK_STATUS(aclInit(nullptr));
    int32_t deviceId = 0;
    CHECK_STATUS(aclrtSetDevice(deviceId));
    atb::Context *context = nullptr;
    CHECK_STATUS(atb::CreateContext(&context));
    void *stream = nullptr;
    CHECK_STATUS(aclrtCreateStream(&stream));
    context->SetExecuteStream(stream);

    // PagedAttention example
    atb::Operation *paOp = PrepareOperation();
    // Prepare the input tensors
    atb::VariantPack paVariantPack;
    paVariantPack.inTensors = PrepareInTensor(context, stream);  // set the input tensors
    atb::Tensor tensorOut = CreateTensor(ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, {NTOKENS, HEAD_NUM, HEAD_SIZE});
    paVariantPack.outTensors.push_back(tensorOut);  // set the output tensor

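    // ATB operations run in two phases: Setup computes the workspace size for this
    // variant pack, and Execute then launches the kernels on the bound stream.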
    uint64_t workspaceSize = 0;
    // Query the required workspace size
    CHECK_STATUS(paOp->Setup(paVariantPack, workspaceSize, context));
    uint8_t *workspacePtr = nullptr;
    if (workspaceSize > 0) {
        CHECK_STATUS(aclrtMalloc((void **)(&workspacePtr), workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
    }
    // Execute the PagedAttention operation
    CHECK_STATUS(paOp->Execute(paVariantPack, workspacePtr, workspaceSize, context));
    CHECK_STATUS(aclrtSynchronizeStream(stream));  // synchronize the stream and wait for device-side tasks to finish
    CHECK_STATUS(aclrtFree(tensorOut.deviceData));
    for (atb::Tensor &inTensor : paVariantPack.inTensors) {
        CHECK_STATUS(aclrtFree(inTensor.deviceData));
    }
    if (workspaceSize > 0) {
        CHECK_STATUS(aclrtFree(workspacePtr));
    }
    CHECK_STATUS(atb::DestroyOperation(paOp));  // destroy the operation first (per-object resource)
    CHECK_STATUS(aclrtDestroyStream(stream));
    CHECK_STATUS(atb::DestroyContext(context));  // destroy the context last (global resource)
    CHECK_STATUS(aclFinalize());
    std::cout << "PA demo success!" << std::endl;
    return 0;
}
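On success the program prints "PA demo success!". For reference only: once the CANN and ATB environments are initialized (via their respective set_env.sh scripts), a build command of roughly the following shape typically works; the paths and flags here are assumptions, so follow the operator invocation example referenced above for the authoritative command:

    g++ -std=c++17 paged_attention_inference_demo.cpp -I "${ASCEND_HOME_PATH}/include" -I "${ATB_HOME_PATH}/include" -L "${ASCEND_HOME_PATH}/lib64" -L "${ATB_HOME_PATH}/lib" -lascendcl -latb -o paged_attention_inference_demo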