Introduction to Integrating Aclnn Operators into ATB Plugin
Overview
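The main approach to integrating an Aclnn operator into ATB Plugin is as follows:
- Create a custom class that inherits from the ATB Operation class.
- Implement a few specific interfaces of the Operation class with the aclnn APIs.
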
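This section focuses on the second point and describes in detail how an Aclnn operator is integrated into ATB Plugin.
Figure 1 Custom Operation class inheriting from the ATB Operation class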


| Interface | Must the custom Operation class implement it? |
|---|---|
| Setup | Yes |
| Execute | Yes |
| InferShape | Yes |
| GetInputNum | Yes |
| GetOutputNum | Yes |
| GetName | Yes |
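The division of work in the table can be sketched as a small, self-contained C++ model. This is not the ATB API. `Operation`, `VariantPack`, `TensorDesc` and the status constants are simplified stand-ins. The real atb::Operation interfaces also take an atb::Context and use atb::SVector. The aclnn calls are replaced by placeholders. The sketch shows one possible design that is consistent with the AclnnBaseOperation and GeluOperation code in this section. A generic base class implements Setup and Execute once on top of four hooks. Each concrete operator then implements those hooks plus InferShape, GetInputNum, GetOutputNum and GetName.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Simplified stand-ins for the ATB types (assumption: the real ones carry more
// fields, and the real Setup/Execute also take an atb::Context*).
using Status = int;
constexpr Status NO_ERROR = 0;
constexpr Status ERROR_INVALID_PARAM = 1;
constexpr Status ERROR_CANN_ERROR = 2;

struct TensorDesc { std::vector<int64_t> dims; };
struct VariantPack { std::vector<TensorDesc> inTensors, outTensors; };

// Stand-in for atb::Operation: the six interfaces from the table.
class Operation {
public:
    virtual ~Operation() = default;
    virtual std::string GetName() const = 0;
    virtual uint32_t GetInputNum() const = 0;
    virtual uint32_t GetOutputNum() const = 0;
    virtual Status InferShape(const std::vector<TensorDesc> &in, std::vector<TensorDesc> &out) const = 0;
    virtual Status Setup(const VariantPack &pack, uint64_t &workspaceSize) = 0;
    virtual Status Execute(const VariantPack &pack, uint8_t *workspace, uint64_t workspaceSize) = 0;
};

// Generic adapter: Setup/Execute are written once and delegate to aclnn-specific hooks.
class AclnnBaseOperation : public Operation {
public:
    Status Setup(const VariantPack &pack, uint64_t &workspaceSize) override {
        if (CreateAclnnVariantPack(pack) != 0 || SetAclnnWorkspaceExecutor() != 0) {
            return ERROR_INVALID_PARAM;
        }
        workspaceSize = workspaceSize_;  // size computed by the subclass hook
        return NO_ERROR;
    }
    Status Execute(const VariantPack &pack, uint8_t *workspace, uint64_t /*workspaceSize*/) override {
        if (UpdateAclnnVariantPack(pack) != 0 || ExecuteAclnnOp(workspace) != 0) {
            return ERROR_CANN_ERROR;
        }
        return NO_ERROR;
    }

protected:
    virtual int CreateAclnnVariantPack(const VariantPack &pack) = 0;  // build aclnn tensors
    virtual int SetAclnnWorkspaceExecutor() = 0;                      // aclnnXxxGetWorkspaceSize
    virtual int UpdateAclnnVariantPack(const VariantPack &pack) = 0;  // refresh device addresses
    virtual int ExecuteAclnnOp(uint8_t *workspace) = 0;               // aclnnXxx launch
    uint64_t workspaceSize_ = 0;
};

// Concrete operator: operator-specific interfaces plus placeholder hooks.
class GeluOperation : public AclnnBaseOperation {
public:
    std::string GetName() const override { return "GeluOperation"; }
    uint32_t GetInputNum() const override { return 1; }
    uint32_t GetOutputNum() const override { return 1; }
    Status InferShape(const std::vector<TensorDesc> &in, std::vector<TensorDesc> &out) const override {
        out.assign(1, in.at(0));  // elementwise: output shape equals input shape
        return NO_ERROR;
    }

protected:
    int CreateAclnnVariantPack(const VariantPack &pack) override {
        return pack.inTensors.empty() ? -1 : 0;  // placeholder for creating aclnn tensors
    }
    int SetAclnnWorkspaceExecutor() override {
        workspaceSize_ = 0;  // placeholder: real code takes this from aclnnXxxGetWorkspaceSize
        return 0;
    }
    int UpdateAclnnVariantPack(const VariantPack &) override { return 0; }
    int ExecuteAclnnOp(uint8_t *) override { return 0; }  // placeholder for the aclnn launch
};
```

With this split, supporting another aclnn operator only requires a new subclass. The Setup/Execute control flow and its error handling stay in one place.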
Implementing the Setup interface
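The Setup interface has two main purposes:
- Prepare the host-side data needed before the task is launched, such as the data structures required by the API calls and the tiling data.
- Compute the device memory (workspace) size needed to execute the operator and return it. When integrating Aclnn into ATB Plugin, this step mainly consists of creating the tensor structures and calling the aclnnXxxGetWorkspaceSize function of the corresponding aclnn operator.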
atb::Status AclnnBaseOperation::Setup(
    const atb::VariantPack &variantPack, uint64_t &workspaceSize, atb::Context *context)
{
    LOG_INFO(opName_ + " setup start");
    // Call into the subclass: create the input/output tensors and store them in the VariantPack
    int ret = CreateAclnnVariantPack(variantPack);
    if (ret != 0)
    {
        LOG_ERROR(opName_ + " call CreateAclnnVariantPack fail, error: " + std::to_string(ret));
        return atb::ERROR_INVALID_PARAM;
    }
    // Call into the subclass, which calls the aclnn API to obtain the executor and the workspace size
    ret = SetAclnnWorkspaceExecutor();
    if (ret != 0)
    {
        LOG_ERROR(opName_ + " call SetAclnnWorkspaceExecutor fail, error: " + std::to_string(ret));
        return atb::ERROR_INVALID_PARAM;
    }
    // Return the computed workspace size to the caller
    workspaceSize = workspaceSize_;
    LOG_INFO(opName_ + " setup end");
    return atb::NO_ERROR;
}
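The hooks called by Setup and Execute wrap aclnn's two-phase calling convention. First, aclnnXxxGetWorkspaceSize is called to obtain the workspace size and an executor. Later, aclnnXxx is called with the workspace and that executor. The sketch below models this protocol with the hypothetical functions FakeOpGetWorkspaceSize and FakeOp, which are not real aclnn APIs, and it omits the stream argument. Two details are assumptions made for illustration: the 4-bytes-per-element workspace requirement and the single-use executor.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for aclOpExecutor: remembers what phase 1 prepared.
struct FakeExecutor {
    uint64_t numElements;
};

// Phase 1 (modeled on aclnnXxxGetWorkspaceSize, called from Setup): prepare host-side
// state, report the workspace size and hand back an executor.
// Assumption: the fake operator needs 4 bytes of scratch space per element.
int FakeOpGetWorkspaceSize(uint64_t numElements, uint64_t *workspaceSize, FakeExecutor **executor) {
    if (workspaceSize == nullptr || executor == nullptr) {
        return -1;
    }
    *workspaceSize = numElements * 4;
    *executor = new FakeExecutor{numElements};
    return 0;
}

// Phase 2 (modeled on aclnnXxx, called from Execute; the stream argument is omitted):
// consume the executor together with the workspace the caller allocated.
int FakeOp(void *workspace, uint64_t workspaceSize, FakeExecutor *executor) {
    if (executor == nullptr) {
        return -1;
    }
    const uint64_t required = executor->numElements * 4;
    delete executor;  // single-use executor in this sketch
    if (workspaceSize < required || (required > 0 && workspace == nullptr)) {
        return -1;
    }
    return 0;
}
```

In the integration above, the first call belongs in SetAclnnWorkspaceExecutor (invoked by Setup) and the second in ExecuteAclnnOp (invoked by Execute). This is presumably why state such as workspaceSize_ is kept as a member between the two interfaces.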
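Implementing the Execute interface
The main job of the Execute interface is to launch the operator task using the device memory addresses provided by the user. When integrating an Aclnn operator into ATB Plugin, this step mainly updates the input and output addresses and then calls the aclnn API to launch the operator.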
atb::Status AclnnBaseOperation::Execute(
    const atb::VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize, atb::Context *context)
{
    LOG_INFO(opName_ + " execute start");
    if (!context) {
        LOG_ERROR(opName_ + " execute fail, context param is null");
        return atb::ERROR_INVALID_PARAM;
    }
    // Get the stream on which the task is executed
    aclrtStream stream = context->GetExecuteStream();
    if (!stream) {
        LOG_ERROR(opName_ + " execute fail, execute stream in context is null");
        return atb::ERROR_INVALID_PARAM;
    }
    // Update the addresses of the data passed in
    int ret = UpdateAclnnVariantPack(variantPack);
    if (ret != 0) {
        LOG_ERROR(opName_ + " call UpdateAclnnVariantPack fail, error: " + std::to_string(ret));
        return atb::ERROR_CANN_ERROR;
    }
    LOG_INFO("Input workspaceSize " + std::to_string(workspaceSize) + " localCache workspaceSize " +
             std::to_string(workspaceSize_));
    ret = ExecuteAclnnOp(workspace, stream); // call the aclnn API to launch the operator
    if (ret != 0) {
        LOG_ERROR(opName_ + " call ExecuteAclnnOp fail, error: " + std::to_string(ret));
        return atb::ERROR_CANN_ERROR;
    }
    LOG_INFO(opName_ + " execute end");
    return atb::NO_ERROR;
}
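From the caller's side, the two interfaces are used in sequence. Setup reports the workspace size, the caller allocates that much memory, and Execute receives the buffer. The sketch below models this flow with a hypothetical ScratchOp class. Host memory from std::vector stands in for device memory; real code would allocate device memory, for example with aclrtMalloc. The size check in Execute is an assumed extra guard. The Execute implementation above only logs both sizes.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical operator whose Setup/Execute mirror the shape of the ATB interfaces
// (the atb::VariantPack and atb::Context parameters are omitted).
class ScratchOp {
public:
    int Setup(uint64_t numElements, uint64_t &workspaceSize) {
        workspaceSize_ = numElements * sizeof(float);  // assumed workspace requirement
        workspaceSize = workspaceSize_;
        return 0;
    }
    int Execute(uint8_t *workspace, uint64_t workspaceSize) {
        // Assumed guard: the buffer must be at least as large as Setup reported.
        if (workspaceSize < workspaceSize_ || (workspaceSize_ > 0 && workspace == nullptr)) {
            return -1;
        }
        if (workspaceSize_ > 0) {
            std::fill(workspace, workspace + workspaceSize_, uint8_t{0});  // stand-in for kernel work
        }
        return 0;
    }

private:
    uint64_t workspaceSize_ = 0;  // cached by Setup, like workspaceSize_ in AclnnBaseOperation
};

// Caller side: Setup -> allocate the reported size -> Execute.
int RunOnce(ScratchOp &op, uint64_t numElements) {
    uint64_t workspaceSize = 0;
    if (op.Setup(numElements, workspaceSize) != 0) {
        return -1;
    }
    // Host memory stands in for device memory here (real code would allocate on the device).
    std::vector<uint8_t> workspace(workspaceSize);
    return op.Execute(workspace.empty() ? nullptr : workspace.data(), workspaceSize);
}
```

Passing a null pointer when the reported size is zero mirrors the common case of operators that need no workspace at all.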
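InferShape interface
This interface performs shape inference. It derives the shape of each output tensor from the shapes of the input tensors and the operator parameters. Its implementation depends only on the characteristics and interface of the operator itself. In the GELU example, the output takes the input's format, dtype and shape, and only 2-D and 3-D inputs are handled.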
atb::Status GeluOperation::InferShape(
    const atb::SVector<atb::TensorDesc> &inTensorDesc, atb::SVector<atb::TensorDesc> &outTensorDesc) const
{
    LOG_INFO(opName_ + " InferShape start");
    outTensorDesc.at(0).format = inTensorDesc.at(0).format;
    outTensorDesc.at(0).dtype = inTensorDesc.at(0).dtype;
    outTensorDesc.at(0).shape.dimNum = inTensorDesc.at(0).shape.dimNum;
    if (inTensorDesc.at(0).shape.dimNum == DIM3)
    {
        LOG_INFO("[input0 dimNum = 3] CHECK " + opName_ + " input shape: [input0] " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM0]) + ", " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM1]) + ", " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM2]));
        outTensorDesc.at(0).shape.dims[DIM0] = inTensorDesc.at(0).shape.dims[DIM0];
        outTensorDesc.at(0).shape.dims[DIM1] = inTensorDesc.at(0).shape.dims[DIM1];
        outTensorDesc.at(0).shape.dims[DIM2] = inTensorDesc.at(0).shape.dims[DIM2];
    }
    else if (inTensorDesc.at(0).shape.dimNum == DIM2)
    {
        LOG_INFO("[input0 dimNum = 2] CHECK " + opName_ + " input shape: [input0] " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM0]) + ", " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM1]));
        outTensorDesc.at(0).shape.dims[DIM0] = inTensorDesc.at(0).shape.dims[DIM0];
        outTensorDesc.at(0).shape.dims[DIM1] = inTensorDesc.at(0).shape.dims[DIM1];
    }
    else
    {
        LOG_ERROR(opName_ + " invalid dimNum = " + std::to_string(inTensorDesc.at(0).shape.dimNum));
        return atb::ERROR_INVALID_PARAM; // reject unsupported ranks instead of reporting success
    }
    LOG_INFO(opName_ + " InferShape end");
    return atb::NO_ERROR;
}
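GELU is elementwise, so the output descriptor can copy every dimension of the input instead of special-casing 2-D and 3-D inputs. The sketch below is a rank-generic variant built on simplified stand-in types. The Shape/TensorDesc layout (a fixed-size dims array plus dimNum) and kMaxDim = 8 are assumptions meant to resemble atb::Dims. Check the underlying aclnn operator's documentation to confirm which ranks it actually accepts.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified stand-ins (assumption: laid out like atb::Dims / atb::TensorDesc,
// with a fixed-size dims array plus dimNum; kMaxDim = 8 is assumed).
constexpr uint64_t kMaxDim = 8;
struct Shape {
    int64_t dims[kMaxDim];
    uint64_t dimNum;
};
struct TensorDesc {
    int dtype;   // placeholder for the data type enum
    int format;  // placeholder for the format enum
    Shape shape;
};
using Status = int;
constexpr Status NO_ERROR = 0;
constexpr Status ERROR_INVALID_PARAM = 1;

// Rank-generic shape inference for an elementwise operator: dtype, format and every
// dimension are copied from input 0 to output 0.
Status InferShapeElementwise(const std::vector<TensorDesc> &in, std::vector<TensorDesc> &out) {
    if (in.empty() || out.empty()) {
        return ERROR_INVALID_PARAM;
    }
    const TensorDesc &src = in.at(0);
    if (src.shape.dimNum > kMaxDim) {
        return ERROR_INVALID_PARAM;  // reject instead of reading past the dims array
    }
    TensorDesc &dst = out.at(0);
    dst.dtype = src.dtype;
    dst.format = src.format;
    dst.shape.dimNum = src.shape.dimNum;
    for (uint64_t i = 0; i < src.shape.dimNum; ++i) {
        dst.shape.dims[i] = src.shape.dims[i];
    }
    return NO_ERROR;
}
```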
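GetInputNum and GetOutputNum interfaces
These two interfaces return the number of input tensors and the number of output tensors the operator requires, respectively. Like InferShape, they depend only on the characteristics and interface of the operator itself.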
uint32_t GeluOperation::GetInputNum() const
{
    return 1; // number of GELU inputs
}
uint32_t GeluOperation::GetOutputNum() const
{
    return 1; // number of GELU outputs
}
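The declared counts matter most for operators with more than one input. The sketch below uses a hypothetical two-input AddOperation, which is not part of the original example. It shows GetName, GetInputNum and GetOutputNum together with an assumed CheckVariantPack helper. The helper compares the tensors actually passed in with the declared counts, a check that could run in Setup before any aclnn tensors are created. Tensor and VariantPack are simplified stand-ins.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Simplified stand-ins for atb::Tensor and atb::VariantPack (assumption).
struct Tensor {};
struct VariantPack {
    std::vector<Tensor> inTensors;
    std::vector<Tensor> outTensors;
};

// Hypothetical two-input, one-output operator (for example, an elementwise add).
class AddOperation {
public:
    std::string GetName() const { return "AddOperation"; }
    uint32_t GetInputNum() const { return 2; }   // two input tensors
    uint32_t GetOutputNum() const { return 1; }  // one output tensor

    // Assumed helper: true only if the pack carries exactly the declared number of tensors.
    bool CheckVariantPack(const VariantPack &pack) const {
        return pack.inTensors.size() == GetInputNum() && pack.outTensors.size() == GetOutputNum();
    }
};
```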