// aclnn_operation_base.h
#ifndef ACLNN_OPERATION_BASE_H
#define ACLNN_OPERATION_BASE_H

#include <cstdint>
#include <memory>
#include <string>

#include <acl/acl.h>
#include <aclnn/acl_meta.h>
#include <atb/atb_infer.h>
#include <atb/types.h>
#include <atb/utils.h>
#include "atb/infer_op_params.h"

// A thin wrapper around atb::Tensor that pairs it with the corresponding aclTensor handle
// and the bookkeeping fields declared below.
struct AclnnTensor
{
    // The wrapped ATB tensor.
    atb::Tensor atbTensor;
    // The associated aclTensor handle; null until one is assigned.
    aclTensor *tensor{nullptr};
    // Index of the aclTensor within the aclExecutor; -1 until one is assigned.
    int tensorIdx{-1};
    // Flag for whether the tensor's data pointer needs updating.
    // NOTE(review): presumably read when tensor addresses are refreshed - confirm in the implementation.
    bool needUpdateTensorDataPtr{false};
    // Strides recorded for this tensor; empty by default.
    atb::SVector<int64_t> strides{};
};

// Base class for aclnn-backed operations. It derives from atb::Operation so that these
// operations are called through the same interface as ATB operators.
//
// Subclasses supply the three pure-virtual hooks declared below (CreateAclnnVariantPack,
// SetAclnnWorkspaceExecutor, ExecuteAclnnOp).
// NOTE(review): presumably Setup()/Execute() drive those hooks - the implementation is not
// visible in this header, confirm in the .cpp.
class AclnnBaseOperation : public atb::Operation
{
public:
    // Constructs the operation with the given name.
    // NOTE(review): presumably stored in opName_ - confirm in the implementation.
    explicit AclnnBaseOperation(const std::string &opName);
    ~AclnnBaseOperation() override;

    // Returns the operation name.
    std::string GetName() const override;

    // Mirrors the ATB interface: obtains the workspace size, reported through workspaceSize.
    atb::Status Setup(const atb::VariantPack &variantPack, uint64_t &workspaceSize, atb::Context *context) override;

    // Mirrors the ATB interface: executes the operator using the supplied workspace buffer.
    atb::Status Execute(const atb::VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize,
                        atb::Context *context) override;

    // Creates the input aclnn tensors from the given variant pack.
    virtual atb::Status CreateAclnnVariantPack(const atb::VariantPack &variantPack) = 0;

    // Computes the workspace size.
    // NOTE(review): judging by the name, this presumably also prepares aclExecutor_ - confirm.
    virtual atb::Status SetAclnnWorkspaceExecutor() = 0;

    // Executes the aclnn op with the given workspace on the given stream.
    virtual atb::Status ExecuteAclnnOp(uint8_t *workspace, aclrtStream &stream) = 0;

    // Updates the addresses of the aclnn input and output tensors.
    atb::Status UpdateAclnnVariantPack(const atb::VariantPack &variantPack);

    // Name of this operation.
    std::string opName_;
    // aclnn executor handle; null until one is assigned.
    aclOpExecutor *aclExecutor_ = nullptr;
    // Wrapped input tensors.
    atb::SVector<std::shared_ptr<AclnnTensor>> aclInTensors_;
    // Wrapped output tensors.
    atb::SVector<std::shared_ptr<AclnnTensor>> aclOutTensors_;
    // Workspace size in use by this operation. Zero-initialized so it is never read as an
    // indeterminate value before it has been computed.
    uint64_t workspaceSize_ = 0;
    // Workspace block identifier; -1 until one is assigned.
    int workspaceBlockId_ = -1;
};

#endif