自定义Operation开发
进行迁移分析后,大部分算子功能可以由加速库提供的原生Operation或者原生Operation组图完成。如果一些功能原生Operation无法支持,用户可以通过PluginOperation机制,开发自定义的PluginOperation,满足对应的功能。最后再使用原生Operation和片定义PluginOperation共同完成整个模型的计算。
Operation类原型,用户通过继承Operation类,实现自定义的PluginOperation。
- 首先定义自定义PluginOperation的参数。
struct XxOpParam{ int paraA; int paraB; };
- 定义XxPluginOperation类,继承Operation。
class Operation { public: Operation() = default; virtual ~Operation() = default; virtual Status InferShape(const SVector<TensorDesc> &inTensorDescs, SVector<TensorDesc> &outTensorDescs) const = 0; virtual uint32_t GetInputNum() const = 0; virtual uint32_t GetOutputNum() const = 0; virtual Status Setup(const VariantPack &variantPack, uint64_t &workspaceSize) = 0; virtual Status Execute(const VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize, const aclrtStream stream) = 0; };
重载Operation的各个接口,接口说明如下。- InferShape接口:根据InTensorDesc信息,输出OutTensorDesc信息。
- GetInputNum/GetOutputNum:获取输入输出tensor个数。
- Setup:根据InTensor/outTensor的Desc信息,计算需要的workspace。
- Execute:Operation算子执行。
父主题: 模型开发