aclnn_operation_base.cpp

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#include "aclnn/aclnn_operation_base.h"
#include "utils/log.h"

AclnnBaseOperation::AclnnBaseOperation(const std::string &opName) : opName_(opName)
{}

// Destructor: clears the cached aclnn executor handle.
// NOTE(review): this only nulls the pointer and does not call aclDestroyAclOpExecutor; confirm that the
// executor's lifetime is managed by the aclnn runtime (i.e. it is not marked repeatable) or elsewhere.
AclnnBaseOperation::~AclnnBaseOperation()
{
    this->aclExecutor_ = nullptr;
}

std::string AclnnBaseOperation::GetName() const
{
    return opName_;
}

/**
 * @brief Prepares the operation for a call: builds the aclnn tensors for this variant pack and
 *        obtains the aclnn executor plus the required workspace size via subclass hooks.
 * @param variantPack   Input/output tensors for this invocation, passed to CreateAclnnVariantPack.
 * @param workspaceSize [out] Set to workspaceSize_ as computed by SetAclnnWorkspaceExecutor
 *                      (presumably in bytes; confirm against the subclass implementation).
 *                      Left unchanged if either hook fails.
 * @param context       Not used by this base implementation.
 * @return atb::NO_ERROR on success; atb::ERROR_INVALID_PARAM if either subclass hook returns non-zero.
 */
atb::Status AclnnBaseOperation::Setup(
    const atb::VariantPack &variantPack, uint64_t &workspaceSize, atb::Context *context)
{
    LOG_INFO(opName_ + " setup start");

    // Subclass hook: create the input/output aclnn tensors and store them on this object.
    int ret = CreateAclnnVariantPack(variantPack);
    if (ret != 0) {
        LOG_ERROR(opName_ + " call CreateAclnnVariantPack fail, error: " + std::to_string(ret));
        return atb::ERROR_INVALID_PARAM;
    }

    // Subclass hook: obtain the aclnn executor and the workspace size it needs.
    ret = SetAclnnWorkspaceExecutor();
    if (ret != 0) {
        LOG_ERROR(opName_ + " call SetAclnnWorkspaceExecutor fail, error: " + std::to_string(ret));
        return atb::ERROR_INVALID_PARAM;
    }

    // Report the computed workspace size back to the caller.
    workspaceSize = workspaceSize_;
    LOG_INFO(opName_ + " setup end");
    return atb::NO_ERROR;
}

/**
 * @brief Runs the prepared aclnn operator on the context's execute stream.
 *
 * Refreshes the device addresses of the tensors in @p variantPack, then launches the operator
 * through the subclass hook ExecuteAclnnOp.
 *
 * @param variantPack   Input/output tensors for this invocation.
 * @param workspace     Device workspace buffer handed to ExecuteAclnnOp.
 * @param workspaceSize Size of @p workspace; must be at least the size reported by Setup().
 * @param context       Execution context supplying the stream; must not be null.
 * @return atb::NO_ERROR on success.
 *         atb::ERROR_INVALID_PARAM if the context or its stream is null, or the workspace is smaller
 *         than the size computed in Setup().
 *         atb::ERROR_CANN_ERROR if updating tensor addresses or launching the operator fails.
 */
atb::Status AclnnBaseOperation::Execute(
    const atb::VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize, atb::Context *context)
{
    LOG_INFO(opName_ + " execute start");
    if (!context) {
        LOG_ERROR(opName_ + " execute fail, context param is null");
        return atb::ERROR_INVALID_PARAM;
    }

    aclrtStream stream = context->GetExecuteStream();
    if (!stream) {
        LOG_ERROR(opName_ + " execute fail, execute stream in context is null");
        return atb::ERROR_INVALID_PARAM;
    }

    LOG_INFO(opName_ + " input workspaceSize " + std::to_string(workspaceSize) + " localCache workspaceSize " +
             std::to_string(workspaceSize_));
    // Reject an undersized workspace before touching any state: the operator was planned in Setup()
    // for workspaceSize_ and would otherwise write past the end of the caller's buffer.
    if (workspaceSize < workspaceSize_) {
        LOG_ERROR(opName_ + " execute fail, input workspaceSize " + std::to_string(workspaceSize) +
                  " is smaller than required workspaceSize " + std::to_string(workspaceSize_));
        return atb::ERROR_INVALID_PARAM;
    }

    // Point the executor at the device addresses supplied for this call.
    int ret = UpdateAclnnVariantPack(variantPack);
    if (ret != 0) {
        LOG_ERROR(opName_ + " call UpdateAclnnVariantPack fail, error: " + std::to_string(ret));
        return atb::ERROR_CANN_ERROR;
    }

    // Subclass hook: invoke the aclnn API.
    ret = ExecuteAclnnOp(workspace, stream);
    if (ret != 0) {
        LOG_ERROR(opName_ + " call ExecuteAclnnOp fail, error: " + std::to_string(ret));
        return atb::ERROR_CANN_ERROR;
    }
    LOG_INFO(opName_ + " execute end");

    return atb::NO_ERROR;
}

/**
 * @brief Rebinds the executor's tensors to the device addresses in @p variantPack.
 *
 * The i-th entry of aclInTensors_ / aclOutTensors_ is paired with the i-th entry of
 * variantPack.inTensors / variantPack.outTensors. Only entries whose needUpdateTensorDataPtr flag
 * is set are touched. For each such entry the cached atbTensor is replaced, and its deviceData is
 * registered with aclSetInputTensorAddr / aclSetOutputTensorAddr.
 *
 * @param variantPack Tensors supplied for the current call.
 * @return atb::NO_ERROR on success.
 *         atb::ERROR_INVALID_PARAM if a tensor that needs updating has no counterpart in
 *         @p variantPack.
 *         atb::ERROR_CANN_ERROR if an aclSet*TensorAddr call fails.
 */
atb::Status AclnnBaseOperation::UpdateAclnnVariantPack(const atb::VariantPack &variantPack)
{
    // Refresh the device address of each input tensor.
    for (size_t i = 0; i < aclInTensors_.size(); ++i) {
        if (!aclInTensors_[i]->needUpdateTensorDataPtr) {
            continue;
        }
        // Guard explicitly instead of relying on at()'s out-of-range behaviour.
        if (i >= variantPack.inTensors.size()) {
            LOG_ERROR(opName_ + " inTensor " + std::to_string(i) + " is missing, variantPack has only " +
                      std::to_string(variantPack.inTensors.size()) + " inTensors");
            return atb::ERROR_INVALID_PARAM;
        }
        aclInTensors_[i]->atbTensor = variantPack.inTensors.at(i);
        const int ret = aclSetInputTensorAddr(aclExecutor_,
            aclInTensors_[i]->tensorIdx,
            aclInTensors_[i]->tensor,
            aclInTensors_[i]->atbTensor.deviceData);
        if (ret != 0) {
            LOG_ERROR(opName_ + " inTensor " + std::to_string(i) +
                      " call aclSetInputTensorAddr fail, error: " + std::to_string(ret));
            return atb::ERROR_CANN_ERROR;
        }
    }

    // Refresh the device address of each output tensor.
    for (size_t i = 0; i < aclOutTensors_.size(); ++i) {
        if (!aclOutTensors_[i]->needUpdateTensorDataPtr) {
            continue;
        }
        if (i >= variantPack.outTensors.size()) {
            LOG_ERROR(opName_ + " outTensor " + std::to_string(i) + " is missing, variantPack has only " +
                      std::to_string(variantPack.outTensors.size()) + " outTensors");
            return atb::ERROR_INVALID_PARAM;
        }
        aclOutTensors_[i]->atbTensor = variantPack.outTensors.at(i);
        const int ret = aclSetOutputTensorAddr(aclExecutor_,
            aclOutTensors_[i]->tensorIdx,
            aclOutTensors_[i]->tensor,
            aclOutTensors_[i]->atbTensor.deviceData);
        if (ret != 0) {
            LOG_ERROR(opName_ + " outTensor " + std::to_string(i) +
                      " call aclSetOutputTensorAddr fail, error: " + std::to_string(ret));
            return atb::ERROR_CANN_ERROR;
        }
    }
    return atb::NO_ERROR;
}