自定义Operation开发

进行迁移分析后,大部分算子功能可以由加速库提供的原生Operation或者原生Operation组图完成。如果一些功能原生Operation无法支持,用户可以通过PluginOperation机制,开发自定义的PluginOperation,满足对应的功能。最后再使用原生Operation和片定义PluginOperation共同完成整个模型的计算。

Operation类原型,用户通过继承Operation类,实现自定义的PluginOperation。

  1. 首先定义自定义PluginOperation的参数。

    struct XxOpParam{
        int    paraA;
        int    paraB;
    };

  2. 定义XxPluginOperation类,继承Operation。

    class Operation {
    public:
        Operation() = default;
        virtual ~Operation() = default;
        virtual Status InferShape(const SVector<TensorDesc> &inTensorDescs, SVector<TensorDesc> &outTensorDescs) const = 0;
        virtual uint32_t GetInputNum() const = 0;
        virtual uint32_t GetOutputNum() const = 0;
        virtual Status Setup(const VariantPack &variantPack, uint64_t &workspaceSize) = 0;
        virtual Status Execute(const VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize, const aclrtStream stream) = 0;
    };
    重载Operation的各个接口,接口说明如下。
    • InferShape接口根据InTensorDesc信息,输出OutTensorDesc信息
    • GetInputNum/GetOutputNum获取输入输出tensor个数。
    • Setup:根据InTensor/outTensorDesc信息,计算需要的workspace
    • Execute:Operation算子执行