WhereOperation

功能

三目运算。

输入张量为cond, x, y,输出张量 z = cond ? x : y。

算子上下文

图1 WhereOperation算子上下文

算子功能实现

定义

struct WhereParam {};

输入

参数

维度

数据类型

格式

描述

cond

[dim_0,dim_1,... ,dim_n]

int8

ND

输出tensor1,条件变量。

x

[x_dim_0,x_dim_1,... ,x_dim_n]

float16

ND

输入tensor2。

y

[y_dim_0,y_dim_1,... ,y_dim_n]

float16

ND

输入tensor3。

输出

参数

维度

数据类型

格式

描述

z

[z_dim_0,z_dim_1,... ,z_dim_n]

float16

ND

输出tensor

规格约束

输入cond的元素只能是0或者1。输出z的维度为输入x与y广播后的结果。要求cond, x, y必须是可广播的。

接口调用示例

输入:

cond = [[1, 0], 
              [0, 1]]
x = [[1, 2],
       [3, 4]]
y = [[10, 20],
       [30, 40]]

输出:

z = [[ 1, 20],
       [30, 4]]