WhereOperation
产品支持情况
硬件型号 |
是否支持 |
特殊说明 |
---|---|---|
√ |
- |
|
√ |
- |
|
√ |
- |
|
√ |
不支持bf16数据类型。 |
|
√ |
- |
功能说明
三目运算。
输入张量为cond, x, y,输出张量 z = cond ? x : y。
- 输入:
cond(条件张量):定义了每个元素的选择条件。
x:当条件为 1时,选择的第一个输入张量。
y:当条件为0时,选择的第二个输入张量。
- 输出:输出张量的每个元素根据对应位置的条件从x或y中选取,返回一个和x、y形状相同的张量。 z = (cond == 1 ? x : y)
算子上下文
定义
1 2 3 | struct WhereParam { uint8_t rsv[8] = {0}; }; |
参数列表
成员名称 |
类型 |
默认值 |
描述 |
---|---|---|---|
rsv[8] |
uint8_t |
{0} |
预留参数。 |
输入
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
cond |
[dim_0, dim_1, ..., dim_n] |
int8 |
ND |
输出tensor1,条件变量。 |
x |
[x_dim_0, x_dim_1, ..., x_dim_n] |
float16 |
ND |
输入tensor2。 |
y |
[y_dim_0, y_dim_1, ..., y_dim_n] |
float16 |
ND |
输入tensor3。 |
输出
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
z |
[z_dim_0, z_dim_1, ..., z_dim_n] |
float16 |
ND |
输出tensor |
约束说明
输入cond的元素只能是0或者1。输出z的维度为输入x与y广播后的结果。要求cond, x, y必须是可广播的。
接口调用示例
输入:
cond = [[1, 0], [0, 1]] x = [[1, 2], [3, 4]] y = [[10, 20], [30, 40]]
输出:
z = [[ 1, 20], [30, 4]]