昇腾社区首页
EN
注册

定义简单计算算子(SoftmaxV2)

下面以一个简单的SoftmaxV2为例,介绍如何进行算子定义。

SoftmaxV2算子原型定义:

1
2
3
4
5
REG_OP(SoftmaxV2)
    .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT16, DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT16, DT_FLOAT}))
    .ATTR(axes, ListInt, {-1})
    .OP_END_FACTORY_REG(SoftmaxV2)

从SoftmaxV2算子原型可以看到,SoftmaxV2算子有一个必选输入,输入名称为x。创建SoftmaxV2算子实例:

1
2
auto softmax = op::SoftmaxV2("Softmax")     //创建算子实例,传算子名称(例如Softmax)作为入参
    .set_input_x(matmul2);                //设置算子输入为matmul2