aclnn_operation_base.cpp
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 | #include "aclnn/aclnn_operation_base.h" #include "utils/log.h" AclnnBaseOperation::AclnnBaseOperation(const std::string &opName) : opName_(opName) {} AclnnBaseOperation::~AclnnBaseOperation() { aclExecutor_ = nullptr; } std::string AclnnBaseOperation::GetName() const { return opName_; } atb::Status AclnnBaseOperation::Setup( const atb::VariantPack &variantPack, uint64_t &workspaceSize, atb::Context *context) { LOG_INFO(opName_ + " setup start"); // 调用子类,创建输入输出tensor,并存入VariantPack int ret = CreateAclnnVariantPack(variantPack); if (ret != 0) { LOG_ERROR(opName_ + " call CreateAclnnVariantPack fail, error: " + std::to_string(ret)); return atb::ERROR_INVALID_PARAM; } // 调用子类,获取Executor和Workspace ret = SetAclnnWorkspaceExecutor(); if (ret != 0) { LOG_ERROR( opName_ + " call CreateAclnnVaSetAclnnWorkspaceExecutorriantPack fail, error: " + std::to_string(ret)); return atb::ERROR_INVALID_PARAM; } // 返回计算出的workspaceSize workspaceSize = workspaceSize_; LOG_INFO(opName_ + " setup end"); return ret; } atb::Status AclnnBaseOperation::Execute( const atb::VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize, atb::Context *context) { LOG_INFO(opName_ + " execute start"); if (!context) { LOG_ERROR(opName_ + " execute fail, context param is null"); return atb::ERROR_INVALID_PARAM; } aclrtStream stream = context->GetExecuteStream(); if (!stream) { LOG_ERROR(opName_ + " execute fail, execute stream in context is null"); return atb::ERROR_INVALID_PARAM; } // 更新数据传入的地址 int ret = UpdateAclnnVariantPack(variantPack); if (ret != 0) { LOG_ERROR(opName_ + " call UpdateAclnnVariantPack fail, error: " + std::to_string(ret)); return atb::ERROR_CANN_ERROR; } LOG_INFO("Input workspaceSize " + std::to_string(workspaceSize) + " localCache workspaceSize " + std::to_string(workspaceSize_)); ret = ExecuteAclnnOp(workspace, stream); // 调用aclnn接口 if (ret != 0) { LOG_ERROR(opName_ + " call ExecuteAclnnOp fail, error: " + std::to_string(ret)); return atb::ERROR_CANN_ERROR; } LOG_INFO(opName_ + " execute start"); return ret; } atb::Status AclnnBaseOperation::UpdateAclnnVariantPack(const atb::VariantPack &variantPack) { // 更新inTensor的device地址 for (size_t i = 0; i < aclInTensors_.size(); ++i) { int ret = -1; if (!aclInTensors_[i]->needUpdateTensorDataPtr) { continue; } aclInTensors_[i]->atbTensor = variantPack.inTensors.at(i); ret = aclSetInputTensorAddr(aclExecutor_, aclInTensors_[i]->tensorIdx, aclInTensors_[i]->tensor, aclInTensors_[i]->atbTensor.deviceData); if (ret != 0) { LOG_ERROR( "inTensor " + std::to_string(i) + " call UpdateAclTensorDataPtr fail, error: " + std::to_string(ret)); return atb::ERROR_CANN_ERROR; } } // 更新outTensor的device地址 for (size_t i = 0; i < aclOutTensors_.size(); ++i) { int ret = -1; if (!aclOutTensors_[i]->needUpdateTensorDataPtr) { continue; } aclOutTensors_[i]->atbTensor = variantPack.outTensors.at(i); ret = aclSetOutputTensorAddr(aclExecutor_, aclOutTensors_[i]->tensorIdx, aclOutTensors_[i]->tensor, aclOutTensors_[i]->atbTensor.deviceData); if (ret != 0) { LOG_ERROR( "outTensor " + std::to_string(i) + " call UpdateAclTensorDataPtr fail, error: " + std::to_string(ret)); return atb::ERROR_CANN_ERROR; } } return atb::NO_ERROR; } |
父主题: 用例源码