This section focuses on the second point of the approach outlined above and walks through integrating an Aclnn operator into ATB Plugin in detail.
Interface name | Must the custom Operation class implement it?
---|---
Setup | Yes
Execute | Yes
InferShape | Yes
GetInputNum | Yes
GetOutputNum | Yes
GetName | Yes
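To make the structure concrete, here is a minimal sketch of what the custom Operation class declaration might look like, using Gelu as the example. The base class name AclnnBaseOperation and the hook methods match the implementation code shown later in this section; the header file names and exact signatures are assumptions, not taken from the original source.

```cpp
// A minimal sketch of a custom Operation class declaration (assumed header layout)
#include <string>
#include <acl/acl.h>                // aclrtStream
#include "atb/atb_infer.h"          // atb::Status, atb::SVector, atb::TensorDesc, atb::VariantPack
#include "aclnn_base_operation.h"   // assumed header for the AclnnBaseOperation base class

class GeluOperation : public AclnnBaseOperation {
public:
    explicit GeluOperation(const std::string &name);

    // Interfaces listed in the table above
    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDesc,
                           atb::SVector<atb::TensorDesc> &outTensorDesc) const override;
    uint32_t GetInputNum() const override;
    uint32_t GetOutputNum() const override;
    std::string GetName() const override;

    // Hooks called back by AclnnBaseOperation::Setup / Execute
    // (signatures inferred from the call sites shown below)
    atb::Status CreateAclnnVariantPack(const atb::VariantPack &variantPack) override;
    atb::Status SetAclnnWorkspaceExecutor() override;
    atb::Status UpdateAclnnVariantPack(const atb::VariantPack &variantPack) override;
    atb::Status ExecuteAclnnOp(uint8_t *workspace, aclrtStream stream) override;
};
```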
The Setup interface serves two main purposes: first, it creates the input and output tensors required by the Aclnn operator from the incoming VariantPack; second, it calls the Aclnn API to obtain the Executor and compute the workspace size needed for execution, returning that size to the caller through the workspaceSize parameter.
```cpp
atb::Status AclnnBaseOperation::Setup(
    const atb::VariantPack &variantPack, uint64_t &workspaceSize, atb::Context *context)
{
    LOG_INFO(opName_ + " setup start");

    // Call into the subclass: create the input/output tensors and store them in the VariantPack
    int ret = CreateAclnnVariantPack(variantPack);
    if (ret != 0) {
        LOG_ERROR(opName_ + " call CreateAclnnVariantPack fail, error: " + std::to_string(ret));
        return atb::ERROR_INVALID_PARAM;
    }

    // Call into the subclass, which calls the aclnn API to obtain the Executor and workspace size
    ret = SetAclnnWorkspaceExecutor();
    if (ret != 0) {
        LOG_ERROR(opName_ + " call SetAclnnWorkspaceExecutor fail, error: " + std::to_string(ret));
        return atb::ERROR_INVALID_PARAM;
    }

    // Return the computed workspaceSize to the caller
    workspaceSize = workspaceSize_;

    LOG_INFO(opName_ + " setup end");
    return ret;
}
```
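The CreateAclnnVariantPack and SetAclnnWorkspaceExecutor calls above are hooks implemented by the concrete operator subclass. As a rough illustration, the sketch below shows what SetAclnnWorkspaceExecutor might look like for Gelu: it calls the first stage of the two-stage Aclnn API, aclnnGeluGetWorkspaceSize, to obtain the Executor and workspace size. The member names aclInTensors_, aclOutTensors_, and aclExecutor_ are assumptions introduced for illustration; only workspaceSize_ appears in the code above.

```cpp
// A minimal sketch of the subclass hook, assuming the following members exist:
//   aclInTensors_ / aclOutTensors_ : std::vector<aclTensor *> filled by CreateAclnnVariantPack
//   aclExecutor_                   : aclOpExecutor * cached for the Execute phase
// aclnnGeluGetWorkspaceSize is declared in "aclnnop/aclnn_gelu.h"
atb::Status GeluOperation::SetAclnnWorkspaceExecutor()
{
    LOG_INFO(opName_ + " SetAclnnWorkspaceExecutor start");
    // First stage of the two-stage Aclnn API: compute the workspace size and obtain the Executor
    int ret = aclnnGeluGetWorkspaceSize(
        aclInTensors_.at(0),   // input tensor
        aclOutTensors_.at(0),  // output tensor
        &workspaceSize_,       // out: required workspace size in bytes
        &aclExecutor_);        // out: operator executor reused at execute time
    LOG_INFO(opName_ + " SetAclnnWorkspaceExecutor end, ret: " + std::to_string(ret));
    return ret;
}
```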
The main job of the Execute interface is to launch the corresponding operator task using the device memory addresses provided by the caller. In the flow of integrating an Aclnn operator into ATB Plugin, this step updates the input/output addresses and then calls the Aclnn API to launch the operator.
```cpp
atb::Status AclnnBaseOperation::Execute(
    const atb::VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize, atb::Context *context)
{
    LOG_INFO(opName_ + " execute start");
    if (!context) {
        LOG_ERROR(opName_ + " execute fail, context param is null");
        return atb::ERROR_INVALID_PARAM;
    }

    // Get the execution stream
    aclrtStream stream = context->GetExecuteStream();
    if (!stream) {
        LOG_ERROR(opName_ + " execute fail, execute stream in context is null");
        return atb::ERROR_INVALID_PARAM;
    }

    // Update the device addresses of the input/output tensors
    int ret = UpdateAclnnVariantPack(variantPack);
    if (ret != 0) {
        LOG_ERROR(opName_ + " call UpdateAclnnVariantPack fail, error: " + std::to_string(ret));
        return atb::ERROR_CANN_ERROR;
    }

    LOG_INFO("Input workspaceSize " + std::to_string(workspaceSize) +
        " localCache workspaceSize " + std::to_string(workspaceSize_));

    // Call the aclnn API to launch the operator task
    ret = ExecuteAclnnOp(workspace, stream);
    if (ret != 0) {
        LOG_ERROR(opName_ + " call ExecuteAclnnOp fail, error: " + std::to_string(ret));
        return atb::ERROR_CANN_ERROR;
    }

    LOG_INFO(opName_ + " execute end");
    return ret;
}
```
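UpdateAclnnVariantPack and ExecuteAclnnOp are likewise subclass hooks. Below is a minimal sketch of ExecuteAclnnOp for Gelu, again under the assumption that aclExecutor_ holds the Executor cached by SetAclnnWorkspaceExecutor; it calls the second stage of the two-stage Aclnn API, aclnnGelu, to launch the task on the given stream.

```cpp
// A minimal sketch of the subclass hook, assuming aclExecutor_ was cached during Setup
atb::Status GeluOperation::ExecuteAclnnOp(uint8_t *workspace, aclrtStream stream)
{
    LOG_INFO(opName_ + " ExecuteAclnnOp start");
    // Second stage of the two-stage Aclnn API: launch the operator on the execution stream
    int ret = aclnnGelu(workspace, workspaceSize_, aclExecutor_, stream);
    LOG_INFO(opName_ + " ExecuteAclnnOp end, ret: " + std::to_string(ret));
    return ret;
}
```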
The main job of the InferShape interface is shape inference: it derives the shape of the output tensors from the shapes of the input tensors and the operator parameters. Its implementation depends only on the characteristics and interface of the operator itself.
```cpp
atb::Status GeluOperation::InferShape(
    const atb::SVector<atb::TensorDesc> &inTensorDesc, atb::SVector<atb::TensorDesc> &outTensorDesc) const
{
    LOG_INFO(opName_ + " InferShape start");
    outTensorDesc.at(0).format = inTensorDesc.at(0).format;
    outTensorDesc.at(0).dtype = inTensorDesc.at(0).dtype;
    outTensorDesc.at(0).shape.dimNum = inTensorDesc.at(0).shape.dimNum;

    if (inTensorDesc.at(0).shape.dimNum == DIM3) {
        LOG_INFO("[input0 dimNum = 3] CHECK " + opName_ + " input shape: [input0] " +
            std::to_string(inTensorDesc.at(0).shape.dims[DIM0]) + ", " +
            std::to_string(inTensorDesc.at(0).shape.dims[DIM1]) + ", " +
            std::to_string(inTensorDesc.at(0).shape.dims[DIM2]));
        outTensorDesc.at(0).shape.dims[DIM0] = inTensorDesc.at(0).shape.dims[DIM0];
        outTensorDesc.at(0).shape.dims[DIM1] = inTensorDesc.at(0).shape.dims[DIM1];
        outTensorDesc.at(0).shape.dims[DIM2] = inTensorDesc.at(0).shape.dims[DIM2];
    } else if (inTensorDesc.at(0).shape.dimNum == DIM2) {
        LOG_INFO("[input0 dimNum = 2] CHECK " + opName_ + " input shape: [input0] " +
            std::to_string(inTensorDesc.at(0).shape.dims[DIM0]) + ", " +
            std::to_string(inTensorDesc.at(0).shape.dims[DIM1]));
        outTensorDesc.at(0).shape.dims[DIM0] = inTensorDesc.at(0).shape.dims[DIM0];
        outTensorDesc.at(0).shape.dims[DIM1] = inTensorDesc.at(0).shape.dims[DIM1];
    } else {
        LOG_ERROR(opName_ + " invalid dimNum = " + std::to_string(inTensorDesc.at(0).shape.dimNum));
    }

    LOG_INFO(opName_ + " InferShape end");
    return atb::NO_ERROR;
}
```
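The DIM0 through DIM3 constants used above are not part of the ATB or Aclnn APIs; they are helper constants defined by the operator implementation. One possible definition, shown here only as an assumption, is:

```cpp
// Helper constants for dimension indices / dimension counts (assumed definition,
// typically placed in a shared header of the operator plugin)
constexpr size_t DIM0 = 0;
constexpr size_t DIM1 = 1;
constexpr size_t DIM2 = 2;
constexpr size_t DIM3 = 3;
```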
GetInputNum and GetOutputNum return the number of input tensors and the number of output tensors the operator requires. Like InferShape, they depend only on the characteristics and interface of the operator itself.
```cpp
uint32_t GeluOperation::GetInputNum() const
{
    return 1;  // number of gelu inputs
}

uint32_t GeluOperation::GetOutputNum() const
{
    return 1;  // number of gelu outputs
}
```
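The table at the top of this section also lists GetName, which simply returns the operator's name. A minimal sketch, assuming opName_ is the name string saved at construction time (the same member already used in the logging code above):

```cpp
// Return the operator name; opName_ is the name saved when the operation was constructed
std::string GeluOperation::GetName() const
{
    return opName_;
}
```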