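Integrating Aclnn Operators into the ATB Plugin
Overview
Integrating an Aclnn operator into the ATB Plugin involves two steps:
- Create a custom class that inherits from the ATB Operation class.
- Implement the required interfaces of the Operation class using aclnn interfaces.
This section focuses on the second step.
Figure 1 The custom Operation class inherits from the ATB Operation class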


| Interface | Must the custom Operation class implement it? |
|---|---|
| Setup | Yes |
| Execute | Yes |
| InferShape | Yes |
| GetInputNum | Yes |
| GetOutputNum | Yes |
| GetName | Yes |
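
The six interfaces in the table can be pictured as a class skeleton. The following is a minimal sketch, not the real ATB headers: everything in the `mock` namespace (`Status`, `TensorDesc`, `VariantPack`, `Operation`, `GeluLikeOperation`) is a hypothetical stand-in, and the signatures are simplified (for example, the context parameter is omitted) so that the snippet compiles on its own.

```cpp
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for ATB types; these are NOT the real ATB definitions.
namespace mock {
using Status = int;
constexpr Status NO_ERROR = 0;

struct TensorDesc {
    std::vector<int64_t> dims;
};

struct VariantPack {
    std::vector<TensorDesc> inTensors;
    std::vector<TensorDesc> outTensors;
};

// Stand-in base class exposing the six interfaces listed in the table.
class Operation {
public:
    virtual ~Operation() = default;
    virtual std::string GetName() const = 0;
    virtual uint32_t GetInputNum() const = 0;
    virtual uint32_t GetOutputNum() const = 0;
    virtual Status InferShape(const std::vector<TensorDesc> &in,
                              std::vector<TensorDesc> &out) const = 0;
    virtual Status Setup(const VariantPack &pack, uint64_t &workspaceSize) = 0;
    virtual Status Execute(const VariantPack &pack, uint8_t *workspace,
                           uint64_t workspaceSize) = 0;
};

// A custom operation overrides every interface of the base class.
class GeluLikeOperation : public Operation {
public:
    explicit GeluLikeOperation(std::string name) : opName_(std::move(name)) {}

    std::string GetName() const override { return opName_; }
    uint32_t GetInputNum() const override { return 1; }
    uint32_t GetOutputNum() const override { return 1; }

    // Element-wise operator: the output shape equals the input shape.
    Status InferShape(const std::vector<TensorDesc> &in,
                      std::vector<TensorDesc> &out) const override
    {
        out.at(0) = in.at(0);
        return NO_ERROR;
    }

    // Placeholder bodies; a real plugin would call aclnn interfaces here.
    Status Setup(const VariantPack &, uint64_t &workspaceSize) override
    {
        workspaceSize = 0;
        return NO_ERROR;
    }
    Status Execute(const VariantPack &, uint8_t *, uint64_t) override
    {
        return NO_ERROR;
    }

private:
    std::string opName_;
};
}  // namespace mock
```

Because callers only see the base class, a framework can drive any such operation through an `Operation` reference without knowing which aclnn operator sits behind it.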
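Implementing the Setup Interface
The Setup interface has two main purposes:
- Prepare the host-side data needed before the operator is dispatched, such as the data structures required by interface calls and the tiling data.
- Compute and return the amount of device memory (workspace) the operator needs to execute. When integrating Aclnn into the ATB Plugin, this step mainly creates the tensor structures and calls the GetWorkspaceSize function of the corresponding aclnn operator (named aclnnXxxGetWorkspaceSize, where Xxx is the operator name).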
atb::Status AclnnBaseOperation::Setup(
    const atb::VariantPack &variantPack, uint64_t &workspaceSize, atb::Context *context)
{
    LOG_INFO(opName_ + " setup start");
    // Call into the subclass to create the input and output tensors and store them in the VariantPack.
    int ret = CreateAclnnVariantPack(variantPack);
    if (ret != 0)
    {
        LOG_ERROR(opName_ + " call CreateAclnnVariantPack fail, error: " + std::to_string(ret));
        return atb::ERROR_INVALID_PARAM;
    }
    // Call into the subclass, which calls the aclnn interface to obtain the executor and workspace size.
    ret = SetAclnnWorkspaceExecutor();
    if (ret != 0)
    {
        LOG_ERROR(opName_ + " call SetAclnnWorkspaceExecutor fail, error: " + std::to_string(ret));
        return atb::ERROR_INVALID_PARAM;
    }
    // Return the computed workspace size to the caller.
    workspaceSize = workspaceSize_;
    LOG_INFO(opName_ + " setup end");
    return atb::NO_ERROR;
}
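The base-class Setup above delegates the operator-specific work to two subclass hooks. The following is a minimal sketch of that delegation under simplifying assumptions: `BaseOp`, `FixedWorkspaceOp`, the hook names, and the status constants in the `hooks` namespace are hypothetical stand-ins, and the hook bodies only record a fixed size instead of calling a real aclnn interface.

```cpp
#include <cstdint>

// Hypothetical stand-ins; these are NOT ATB or aclnn definitions.
namespace hooks {
using Status = int;
constexpr Status NO_ERROR = 0;
constexpr Status ERROR_INVALID_PARAM = 1;

class BaseOp {
public:
    virtual ~BaseOp() = default;

    // Fixed sequence: build tensors, query the workspace size, report it.
    Status Setup(uint64_t &workspaceSize)
    {
        if (CreateVariantPack() != 0) {
            return ERROR_INVALID_PARAM;
        }
        if (SetWorkspaceExecutor() != 0) {
            return ERROR_INVALID_PARAM;
        }
        workspaceSize = workspaceSize_;
        return NO_ERROR;
    }

protected:
    // Operator-specific steps supplied by each subclass.
    virtual int CreateVariantPack() = 0;
    virtual int SetWorkspaceExecutor() = 0;

    uint64_t workspaceSize_ = 0;  // cached for the later Execute call
};

// Example subclass: reports a fixed workspace size and can simulate a failure.
class FixedWorkspaceOp : public BaseOp {
public:
    explicit FixedWorkspaceOp(uint64_t bytes, bool failCreate = false)
        : bytes_(bytes), failCreate_(failCreate) {}

protected:
    int CreateVariantPack() override { return failCreate_ ? -1 : 0; }

    // A real subclass would call the operator's GetWorkspaceSize interface here.
    int SetWorkspaceExecutor() override
    {
        workspaceSize_ = bytes_;
        return 0;
    }

private:
    uint64_t bytes_;
    bool failCreate_;
};
}  // namespace hooks
```

In this sketch, the output parameter is written only after both hooks succeed, so a failed Setup leaves the caller's value untouched.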
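Implementing the Execute Interface
The Execute interface dispatches the operator task using the device memory address supplied by the caller. When integrating an Aclnn operator into the ATB Plugin, this step calls the Aclnn interfaces to update the inputs and dispatch the operator.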
atb::Status AclnnBaseOperation::Execute(
    const atb::VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize, atb::Context *context)
{
    LOG_INFO(opName_ + " execute start");
    if (!context) {
        LOG_ERROR(opName_ + " execute fail, context param is null");
        return atb::ERROR_INVALID_PARAM;
    }
    // Get the stream on which the operator will execute.
    aclrtStream stream = context->GetExecuteStream();
    if (!stream) {
        LOG_ERROR(opName_ + " execute fail, execute stream in context is null");
        return atb::ERROR_INVALID_PARAM;
    }
    // Update the addresses of the data passed in.
    int ret = UpdateAclnnVariantPack(variantPack);
    if (ret != 0) {
        LOG_ERROR(opName_ + " call UpdateAclnnVariantPack fail, error: " + std::to_string(ret));
        return atb::ERROR_CANN_ERROR;
    }
    LOG_INFO("Input workspaceSize " + std::to_string(workspaceSize) + " localCache workspaceSize " +
             std::to_string(workspaceSize_));
    ret = ExecuteAclnnOp(workspace, stream); // Call the aclnn interface to dispatch the operator.
    if (ret != 0) {
        LOG_ERROR(opName_ + " call ExecuteAclnnOp fail, error: " + std::to_string(ret));
        return atb::ERROR_CANN_ERROR;
    }
    LOG_INFO(opName_ + " execute end");
    return atb::NO_ERROR;
}
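From the caller's side, Setup and Execute form a two-phase sequence: query the workspace size, allocate a buffer of that size, then execute with it. The following is a minimal sketch of that flow under stated assumptions: `TwoPhaseOp`, `RunOnce`, and the sizing rule are hypothetical, and a host-side `std::vector` stands in for device memory, which a real caller would allocate on the device instead.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-ins; these are NOT ATB or aclnn definitions.
namespace sketch {
using Status = int;
constexpr Status NO_ERROR = 0;
constexpr Status ERROR_INVALID_PARAM = 1;

class TwoPhaseOp {
public:
    // Phase 1: host-side preparation; reports the required workspace in bytes.
    Status Setup(uint64_t elementCount, uint64_t &workspaceSize)
    {
        workspaceSize_ = elementCount * sizeof(float);  // made-up sizing rule
        workspaceSize = workspaceSize_;
        isSetup_ = true;
        return NO_ERROR;
    }

    // Phase 2: validates the caller-provided buffer, then "dispatches" the work.
    Status Execute(uint8_t *workspace, uint64_t workspaceSize)
    {
        if (!isSetup_) {
            return ERROR_INVALID_PARAM;  // Setup must run first
        }
        if (workspaceSize < workspaceSize_) {
            return ERROR_INVALID_PARAM;  // buffer smaller than Setup requested
        }
        if (workspaceSize_ > 0 && workspace == nullptr) {
            return ERROR_INVALID_PARAM;
        }
        ++launchCount_;  // stands in for the actual operator launch
        return NO_ERROR;
    }

    int LaunchCount() const { return launchCount_; }

private:
    uint64_t workspaceSize_ = 0;
    bool isSetup_ = false;
    int launchCount_ = 0;
};

// Caller-side flow: Setup, allocate the workspace, Execute.
inline Status RunOnce(TwoPhaseOp &op, uint64_t elementCount)
{
    uint64_t workspaceSize = 0;
    Status st = op.Setup(elementCount, workspaceSize);
    if (st != NO_ERROR) {
        return st;
    }
    std::vector<uint8_t> workspace(workspaceSize);  // stand-in for device memory
    return op.Execute(workspace.data(), workspaceSize);
}
}  // namespace sketch
```

Checking the supplied size against the size cached during Setup is a design choice of this sketch; the original Execute only logs both values.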
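The InferShape Interface
This interface performs shape inference: it derives the shapes of the output tensors from the shapes of the input tensors and the operator parameters. Its implementation depends only on the characteristics and interface of the operator itself.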
atb::Status GeluOperation::InferShape(
    const atb::SVector<atb::TensorDesc> &inTensorDesc, atb::SVector<atb::TensorDesc> &outTensorDesc) const
{
    LOG_INFO(opName_ + " InferShape start");
    // GELU is element-wise, so the output inherits the format, dtype, and rank of the input.
    outTensorDesc.at(0).format = inTensorDesc.at(0).format;
    outTensorDesc.at(0).dtype = inTensorDesc.at(0).dtype;
    outTensorDesc.at(0).shape.dimNum = inTensorDesc.at(0).shape.dimNum;
    if (inTensorDesc.at(0).shape.dimNum == DIM3)
    {
        LOG_INFO("[input0 dimNum = 3] CHECK " + opName_ + " input shape: [input0] " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM0]) + ", " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM1]) + ", " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM2]));
        outTensorDesc.at(0).shape.dims[DIM0] = inTensorDesc.at(0).shape.dims[DIM0];
        outTensorDesc.at(0).shape.dims[DIM1] = inTensorDesc.at(0).shape.dims[DIM1];
        outTensorDesc.at(0).shape.dims[DIM2] = inTensorDesc.at(0).shape.dims[DIM2];
    }
    else if (inTensorDesc.at(0).shape.dimNum == DIM2)
    {
        LOG_INFO("[input0 dimNum = 2] CHECK " + opName_ + " input shape: [input0] " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM0]) + ", " +
                 std::to_string(inTensorDesc.at(0).shape.dims[DIM1]));
        outTensorDesc.at(0).shape.dims[DIM0] = inTensorDesc.at(0).shape.dims[DIM0];
        outTensorDesc.at(0).shape.dims[DIM1] = inTensorDesc.at(0).shape.dims[DIM1];
    }
    else
    {
        // Only 2-D and 3-D inputs are handled here; report the failure instead of returning success.
        LOG_ERROR(opName_ + " invalid dimNum = " + std::to_string(inTensorDesc.at(0).shape.dimNum));
        return atb::ERROR_INVALID_PARAM;
    }
    LOG_INFO(opName_ + " InferShape end");
    return atb::NO_ERROR;
}
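For an element-wise operator, the per-rank branches above can also be expressed as a single loop over the dimensions. The following is a minimal sketch of that alternative, not the implementation used by the plugin: the `Dims` and `TensorDesc` structs, the `MAX_DIM` limit of 8, and `InferElementwiseShape` are hypothetical stand-ins defined here so the snippet compiles without ATB headers.

```cpp
#include <cstdint>

// Hypothetical stand-ins; these are NOT the real ATB definitions.
namespace shapes {
constexpr uint64_t MAX_DIM = 8;  // assumed upper bound on tensor rank

struct Dims {
    int64_t dims[MAX_DIM];
    uint64_t dimNum;
};

struct TensorDesc {
    int dtype;
    int format;
    Dims shape;
};

// Copies dtype, format, and every dimension from input to output.
// Returns false (leaving `out` unchanged) when the rank exceeds MAX_DIM.
inline bool InferElementwiseShape(const TensorDesc &in, TensorDesc &out)
{
    if (in.shape.dimNum > MAX_DIM) {
        return false;
    }
    out.dtype = in.dtype;
    out.format = in.format;
    out.shape.dimNum = in.shape.dimNum;
    for (uint64_t i = 0; i < in.shape.dimNum; ++i) {
        out.shape.dims[i] = in.shape.dims[i];
    }
    return true;
}
}  // namespace shapes
```

The loop handles any rank up to the assumed limit, whereas the branch-based version accepts only the ranks it explicitly lists.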
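The GetInputNum and GetOutputNum Interfaces
These two interfaces return the number of input tensors and the number of output tensors the operator requires, respectively. Like InferShape, they depend only on the characteristics and interface of the operator itself.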
uint32_t GeluOperation::GetInputNum() const
{
    return 1; // Number of GELU input tensors
}
uint32_t GeluOperation::GetOutputNum() const
{
    return 1; // Number of GELU output tensors
}