aclnn_gelu_operation.h

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#ifndef ACLNN_GELU_OPERATION_H
#define ACLNN_GELU_OPERATION_H

#include "aclnn/aclnn_operation_base.h"

struct AclnnGeluParam
{
    int64_t geluApproximate = -1; // gelu_v2计算的入参,指定高斯近似算法,0: "none", 1: "tanh" , -1: 不使用gelu_v2
};

class GeluOperation : public AclnnBaseOperation
{
public:
    GeluOperation(const std::string &name, AclnnGeluParam param);
    atb::Status InferShape(
        const atb::SVector<atb::TensorDesc> &inTensorDesc, atb::SVector<atb::TensorDesc> &outTensorDesc) const override;
    uint32_t GetInputNum() const override;
    uint32_t GetOutputNum() const override;

    atb::Status CreateAclnnVariantPack(const atb::VariantPack &variantPack) override;
    atb::Status SetAclnnWorkspaceExecutor() override;
    atb::Status ExecuteAclnnOp(uint8_t *workspace, aclrtStream &stream) override;

private:
    atb::Status CreateAclnnInTensor(const atb::VariantPack &variantPack);
    atb::Status CreateAclnnOutTensor(const atb::VariantPack &variantPack);
    std::shared_ptr<AclnnTensor> CreateAclnnTensor(atb::Tensor atbTensor, size_t tensorIdx);

    AclnnGeluParam param_;
};

#endif