定义动态多输入算子(AddN)
某些算子的输入个数不固定,为动态多输入算子,例如AddN,下面介绍如何定义这类算子。
AddN算子原型定义:
1 2 3 4 5 | REG_OP(AddN) .DYNAMIC_INPUT(x, TensorType({NumberType(), DT_VARIANT})) .OUTPUT(y, TensorType({NumberType(), DT_VARIANT})) .REQUIRED_ATTR(N, Int) .OP_END_FACTORY_REG(AddN) |
通过AddN算子原型定义可以看到,该算子为动态多输入算子,我们通过“create_dynamic_input_输入名称”创建动态输入,通过“set_dynamic_input_输入名称”设置动态输入。
1 2 3 4 5 6 | auto data = op::Data().set_attr_index(0); auto addn = op::AddN("addn") .create_dynamic_input_x(2) // 创建动态输入x,包括2个输入,并且把这两个输入作为算子最后的输入 .set_dynamic_input_x(0,data) // 设置第1个输入,0表示输入索引,默认从0开始,data表示输入value .set_dynamic_input_x(1,data) // 设置第2个输入,1表示输入索引,默认从0开始,data表示输入value .set_attr_N(2); // 设置属性N的值为2,表示该算子有2个输入 |
也可以通过“create_dynamic_input_byindex_输入名称”创建动态输入,但是和“create_dynamic_input_输入名称”不能同时使用,两者的区别是:“create_dynamic_input_输入名称”默认把创建的动态输入作为算子最后的输入,而“create_dynamic_input_byindex_输入名称”可以指定动态输入的索引位置,例如:
1 2 3 4 5 | auto addn = op::AddN("addn") .create_dynamic_input_byindex_x(2,0) // 创建动态输入x,包括2个输入,并且把这两个输入插入到索引0和索引1的位置,0表示动态输入索引的起始位置 .set_dynamic_input_x(0,data1) // 设置第1个输入,0表示输入索引,默认从0开始,data1表示输入value .set_dynamic_input_x(1,data2) // 设置第2个输入,1表示输入索引,默认从0开始,data2表示输入value .set_attr_N(2); // 设置属性N的值为2,表示该算子有2个输入 |
父主题: 各类算子表达