WhereOperation
产品支持情况
硬件型号  | 
是否支持  | 
特殊说明  | 
|---|---|---|
√  | 
-  | 
|
√  | 
-  | 
|
√  | 
-  | 
|
√  | 
不支持bf16数据类型。  | 
|
√  | 
-  | 
功能说明
三目运算。
输入张量为cond, x, y,输出张量z。
- 输入:
cond(条件张量):定义了每个元素的选择条件。
x:当条件为 1时,选择的第一个输入张量。
y:当条件为0时,选择的第二个输入张量。
 - 输出:输出张量的每个元素根据对应位置的条件从x或y中选取,返回一个和x、y形状相同的张量。
z = (cond == 1 ? x : y)
 
算子上下文

定义
1 2 3  | struct WhereParam { uint8_t rsv[8] = {0}; };  | 
参数列表
成员名称  | 
类型  | 
默认值  | 
描述  | 
|---|---|---|---|
rsv[8]  | 
uint8_t  | 
{0}  | 
预留参数。  | 
输入
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
cond  | 
[dim_0, dim_1, ..., dim_n]  | 
int8  | 
ND  | 
输出tensor1,条件变量。  | 
x  | 
[x_dim_0, x_dim_1, ..., x_dim_n]  | 
float16  | 
ND  | 
输入tensor2。  | 
y  | 
[y_dim_0, y_dim_1, ..., y_dim_n]  | 
float16  | 
ND  | 
输入tensor3。  | 
输出
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
z  | 
[z_dim_0, z_dim_1, ..., z_dim_n]  | 
float16  | 
ND  | 
输出tensor  | 
约束说明
输入cond的元素只能是0或者1。输出z的维度为输入x与y广播后的结果。要求cond, x, y必须是可广播的。
接口调用示例
输入:
cond = [[1, 0], 
              [0, 1]]
x = [[1, 2],
       [3, 4]]
y = [[10, 20],
       [30, 40]]
输出:
z = [[ 1, 20],
       [30, 4]]