Introduction to Integrating Aclnn Operators into the ATB Plugin

Overview

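The main idea behind integrating Aclnn operators into the ATB Plugin is as follows:
  1. Create a custom class that inherits from ATB's Operation class.
  2. Use the aclnn interfaces to implement several specific interfaces of the Operation class.

This section focuses on the second point and describes in detail how Aclnn operators are integrated into the ATB Plugin.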
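
The two steps can be sketched with a small self-contained mock. Everything in the `mock` namespace, and the `MockGeluOperation` class, is a hypothetical stand-in invented for illustration. The real ATB Operation class is declared in the ATB headers, and its Setup and Execute signatures take a VariantPack and a Context, as the listings later in this section show.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>

// Hypothetical stand-in for the ATB Operation base class (not the real API).
namespace mock {
using Status = int32_t;
constexpr Status NO_ERROR = 0;

class Operation {
public:
    virtual ~Operation() = default;
    virtual std::string GetName() const = 0;
    virtual uint32_t GetInputNum() const = 0;
    virtual uint32_t GetOutputNum() const = 0;
    virtual Status Setup(uint64_t &workspaceSize) = 0;
    virtual Status Execute(uint8_t *workspace, uint64_t workspaceSize) = 0;
};
}  // namespace mock

// Step 1: the custom class inherits from the base class.
class MockGeluOperation : public mock::Operation {
public:
    explicit MockGeluOperation(std::string name) : opName_(std::move(name))
    {
        assert(!opName_.empty());  // an operation needs a name for logging
    }

    std::string GetName() const override { return opName_; }
    uint32_t GetInputNum() const override { return 1; }
    uint32_t GetOutputNum() const override { return 1; }

    // Step 2: in a real plugin these bodies would call the aclnn interfaces.
    mock::Status Setup(uint64_t &workspaceSize) override
    {
        workspaceSize = 0;  // placeholder: this mock needs no scratch memory
        return mock::NO_ERROR;
    }

    mock::Status Execute(uint8_t *, uint64_t) override { return mock::NO_ERROR; }

private:
    std::string opName_;
};
```

Because callers only see the base-class interface, a framework can hold the custom operation through a `mock::Operation` reference or pointer and never needs to know that aclnn is used underneath.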

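Figure 1 The custom Operation class inherits from the ATB Operation class

Table 1 Interfaces exposed by the Operation class

Interface name    Must be implemented by the custom Operation class
Setup
Execute
InferShape
GetInputNum
GetOutputNum
GetName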

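Implementing the Setup interface

The Setup interface has two main purposes:
  1. Create the aclnn input and output tensors and store them in the VariantPack.
  2. Call the aclnn interface to obtain the executor and the workspace size, and return the workspace size to the caller.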

atb::Status AclnnBaseOperation::Setup(
    const atb::VariantPack &variantPack, uint64_t &workspaceSize, atb::Context *context)
{
    LOG_INFO(opName_ + " setup start");

    // Call into the subclass to create the input/output tensors and store them in the VariantPack
    int ret = CreateAclnnVariantPack(variantPack);
    if (ret != 0)
    {
        LOG_ERROR(opName_ + " call CreateAclnnVariantPack fail, error: " + std::to_string(ret));
        return atb::ERROR_INVALID_PARAM;
    }

    // Call into the subclass, which calls the aclnn interface to obtain the executor and workspace size
    ret = SetAclnnWorkspaceExecutor();
    if (ret != 0)
    {
        LOG_ERROR(
            opName_ + " call SetAclnnWorkspaceExecutor fail, error: " + std::to_string(ret));
        return atb::ERROR_INVALID_PARAM;
    }
    // Return the computed workspaceSize to the caller
    workspaceSize = workspaceSize_;
    LOG_INFO(opName_ + " setup end");
    return ret;
}
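
The Setup listing above delegates to two subclass hooks and then hands back a cached workspace size. A minimal sketch of that caching pattern follows, assuming the usual aclnn convention of a first-stage call that reports the workspace size and produces an executor. `MockTensor`, `MockExecutor`, `MockGeluGetWorkspaceSize`, `MockGeluSetupSketch`, the error codes, and the sizing rule are all hypothetical and only imitate the shape of the real calls.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins; real code would use the aclnn tensor and executor types.
struct MockTensor {
    uint64_t elementCount = 0;
};
struct MockExecutor {
    uint64_t plannedElements = 0;
};

// Imitates a first-stage "GetWorkspaceSize" call: it validates the tensors,
// reports the scratch memory needed, and fills in an executor for later use.
inline int MockGeluGetWorkspaceSize(const MockTensor &self, const MockTensor &out,
                                    uint64_t &workspaceSize, MockExecutor &executor)
{
    if (self.elementCount == 0 || self.elementCount != out.elementCount) {
        return -1;  // invented error code for this sketch
    }
    workspaceSize = self.elementCount * sizeof(float);  // invented sizing rule
    executor.plannedElements = self.elementCount;
    return 0;
}

class MockGeluSetupSketch {
public:
    // Follows the same order as the Setup listing: tensors first, then executor.
    int Setup(uint64_t elementCount, uint64_t &workspaceSize)
    {
        int ret = CreateAclnnVariantPack(elementCount);
        if (ret != 0) {
            return ret;
        }
        ret = SetAclnnWorkspaceExecutor();
        if (ret != 0) {
            return ret;
        }
        assert(hasExecutor_);            // invariant after a successful first stage
        workspaceSize = workspaceSize_;  // hand the cached size back to the caller
        return 0;
    }

    bool HasExecutor() const { return hasExecutor_; }

private:
    // Builds the operator's input and output tensors.
    int CreateAclnnVariantPack(uint64_t elementCount)
    {
        input_.elementCount = elementCount;
        output_.elementCount = elementCount;  // GELU output matches its input
        return 0;
    }

    // Runs the first stage and caches its results for the later Execute call.
    int SetAclnnWorkspaceExecutor()
    {
        int ret = MockGeluGetWorkspaceSize(input_, output_, workspaceSize_, executor_);
        hasExecutor_ = (ret == 0);
        return ret;
    }

    MockTensor input_;
    MockTensor output_;
    MockExecutor executor_;
    uint64_t workspaceSize_ = 0;
    bool hasExecutor_ = false;
};
```

Caching the size and executor as members is what lets Execute run later without repeating the first stage. On failure the caller's `workspaceSize` is left untouched in this sketch.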

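Implementing the Execute interface

The main job of the Execute interface is to launch the operator task using the device memory address supplied by the user. In the flow for integrating Aclnn operators into the ATB Plugin, this step calls the aclnn interfaces to bind the inputs and launch the operator.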

atb::Status AclnnBaseOperation::Execute(
    const atb::VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize, atb::Context *context)
{
    LOG_INFO(opName_ + " execute start");
    if (!context) {
        LOG_ERROR(opName_ + " execute fail, context param is null");
        return atb::ERROR_INVALID_PARAM;
    }
    // Get the execution stream
    aclrtStream stream = context->GetExecuteStream();
    if (!stream) {
        LOG_ERROR(opName_ + " execute fail, execute stream in context is null");
        return atb::ERROR_INVALID_PARAM;
    }
    // Update the data addresses passed in by the caller
    int ret = UpdateAclnnVariantPack(variantPack);
    if (ret != 0) {
        LOG_ERROR(opName_ + " call UpdateAclnnVariantPack fail, error: " + std::to_string(ret));
        return atb::ERROR_CANN_ERROR;
    }
    LOG_INFO("Input workspaceSize " + std::to_string(workspaceSize) + " localCache workspaceSize " +
             std::to_string(workspaceSize_));
    ret = ExecuteAclnnOp(workspace, stream); // Call the aclnn interface
    if (ret != 0) {
        LOG_ERROR(opName_ + " call ExecuteAclnnOp fail, error: " + std::to_string(ret));
        return atb::ERROR_CANN_ERROR;
    }
    LOG_INFO(opName_ + " execute end");
    return ret;
}
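
From the caller's side, Setup and Execute form a simple sequence: ask for the workspace size, allocate that much memory, then execute. The sketch below illustrates only that ordering. `MockOperation`, `RunOnce`, the error codes, and the sizing rule are hypothetical, and a host-side `std::vector` stands in for the device memory a real caller would allocate.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical operation with the same Setup/Execute split as in the text.
// It does not call ATB or aclnn.
class MockOperation {
public:
    int Setup(uint64_t elementCount, uint64_t &workspaceSize)
    {
        requiredWorkspace_ = elementCount * sizeof(float);  // invented sizing rule
        workspaceSize = requiredWorkspace_;
        ready_ = true;
        return 0;
    }

    int Execute(uint8_t *workspace, uint64_t workspaceSize)
    {
        if (!ready_) {
            return -1;  // Execute was called before Setup
        }
        if (workspaceSize < requiredWorkspace_) {
            return -2;  // the caller supplied too little workspace
        }
        if (requiredWorkspace_ > 0 && workspace == nullptr) {
            return -3;  // workspace is required but missing
        }
        ++launchCount_;  // stands in for launching the operator on a stream
        return 0;
    }

    uint32_t LaunchCount() const { return launchCount_; }

private:
    uint64_t requiredWorkspace_ = 0;
    bool ready_ = false;
    uint32_t launchCount_ = 0;
};

// Caller-side flow: query the size, allocate, then execute.
inline int RunOnce(MockOperation &op, uint64_t elementCount)
{
    uint64_t workspaceSize = 0;
    int ret = op.Setup(elementCount, workspaceSize);
    if (ret != 0) {
        return ret;
    }
    // Host memory stands in for the device memory a real caller would allocate.
    std::vector<uint8_t> workspace(workspaceSize);
    assert(workspace.size() == workspaceSize);
    return op.Execute(workspace.data(), workspaceSize);
}
```

Keeping allocation on the caller's side means the operation never owns workspace memory, which is why Execute receives both the pointer and its size.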

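The InferShape interface

This interface performs shape inference: it derives the shapes of the output tensors from the shapes of the input tensors and the operator parameters. Its implementation depends only on the characteristics and interface of the operator itself.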

atb::Status GeluOperation::InferShape(
    const atb::SVector<atb::TensorDesc> &inTensorDesc, atb::SVector<atb::TensorDesc> &outTensorDesc) const
{
    LOG_INFO(opName_ + " InferShape start");
    outTensorDesc.at(0).format = inTensorDesc.at(0).format;
    outTensorDesc.at(0).dtype = inTensorDesc.at(0).dtype;
    outTensorDesc.at(0).shape.dimNum = inTensorDesc.at(0).shape.dimNum;

    if (inTensorDesc.at(0).shape.dimNum == DIM3)
    {
        LOG_INFO("[input0 dimNum = 3] CHECK " + opName_ + " input shape: [input0] " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM0]) + ", " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM1]) + ", " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM2]));
        outTensorDesc.at(0).shape.dims[DIM0] = inTensorDesc.at(0).shape.dims[DIM0];
        outTensorDesc.at(0).shape.dims[DIM1] = inTensorDesc.at(0).shape.dims[DIM1];
        outTensorDesc.at(0).shape.dims[DIM2] = inTensorDesc.at(0).shape.dims[DIM2];
    }
    else if (inTensorDesc.at(0).shape.dimNum == DIM2)
    {
        LOG_INFO("[input0 dimNum = 2] CHECK " + opName_ + " input shape: [input0] " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM0]) + ", " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM1]));
        outTensorDesc.at(0).shape.dims[DIM0] = inTensorDesc.at(0).shape.dims[DIM0];
        outTensorDesc.at(0).shape.dims[DIM1] = inTensorDesc.at(0).shape.dims[DIM1];
    }
    else
    {
        LOG_ERROR(opName_ + " invalid dimNum = " + std::to_string(inTensorDesc.at(0).shape.dimNum));
        return atb::ERROR_INVALID_PARAM; // report the unsupported rank instead of returning success
    }

    LOG_INFO(opName_ + " InferShape end");
    return atb::NO_ERROR;
}
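
The GELU listing handles two- and three-dimensional inputs in separate branches. Because GELU is elementwise, its output shape always equals its input shape, so one possible alternative is to copy every dimension in a loop. The sketch below is not the implementation shown above. `MockDims`, `MockTensorDesc`, `InferElementwiseShape`, and the `kMaxDim` bound are hypothetical stand-ins chosen for illustration.

```cpp
#include <cassert>
#include <cstdint>

constexpr uint64_t kMaxDim = 8;  // assumed upper bound on rank for this sketch

// Hypothetical stand-ins for a tensor descriptor and its shape.
struct MockDims {
    int64_t dims[kMaxDim];
    uint64_t dimNum;
};
struct MockTensorDesc {
    int dtype;
    int format;
    MockDims shape;
};

// Elementwise shape inference: the output descriptor mirrors the input.
// Returns 0 on success and -1 (invented code) for an unsupported rank.
inline int InferElementwiseShape(const MockTensorDesc &in, MockTensorDesc &out)
{
    if (in.shape.dimNum == 0 || in.shape.dimNum > kMaxDim) {
        return -1;
    }
    out.dtype = in.dtype;
    out.format = in.format;
    out.shape.dimNum = in.shape.dimNum;
    for (uint64_t i = 0; i < in.shape.dimNum; ++i) {
        out.shape.dims[i] = in.shape.dims[i];
    }
    return 0;
}
```

The loop removes the need for one branch per rank, and the early return makes an unsupported rank visible to the caller as an error.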

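The GetInputNum and GetOutputNum interfaces

These two interfaces return the number of input tensors and the number of output tensors the operator requires, respectively. Like InferShape, they depend only on the characteristics and interface of the operator itself.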

uint32_t GeluOperation::GetInputNum() const
{
    return 1; // number of GELU input tensors
}

uint32_t GeluOperation::GetOutputNum() const
{
    return 1; // number of GELU output tensors
}