昇腾社区首页
中文
注册

WhereOperation

产品支持情况

硬件型号

是否支持

特殊说明

Atlas A3 推理系列产品/Atlas A3 训练系列产品

-

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

-

Atlas 训练系列产品

-

Atlas 推理系列产品

不支持bf16数据类型。

Atlas 200I/500 A2 推理产品

-

功能说明

三目运算。

输入张量为cond, x, y,输出张量 z = cond ? x : y。

  • 输入

    cond(条件张量):定义了每个元素的选择条件。

    x:当条件为 1时,选择的第一个输入张量。

    y:当条件为0时,选择的第二个输入张量。

  • 输出:输出张量的每个元素根据对应位置的条件从x或y中选取,返回一个和x、y形状相同的张量。 z = (cond == 1 ? x : y)

算子上下文

定义

1
2
3
struct WhereParam {
    uint8_t rsv[8] = {0};
};

参数列表

成员名称

类型

默认值

描述

rsv[8]

uint8_t

{0}

预留参数。

输入

参数

维度

数据类型

格式

描述

cond

[dim_0, dim_1, ..., dim_n]

int8

ND

输出tensor1,条件变量。

x

[x_dim_0, x_dim_1, ..., x_dim_n]

float16

ND

输入tensor2。

y

[y_dim_0, y_dim_1, ..., y_dim_n]

float16

ND

输入tensor3。

输出

参数

维度

数据类型

格式

描述

z

[z_dim_0, z_dim_1, ..., z_dim_n]

float16

ND

输出tensor

约束说明

输入cond的元素只能是0或者1。输出z的维度为输入x与y广播后的结果。要求cond, x, y必须是可广播的。

接口调用示例

输入:

cond = [[1, 0], 
              [0, 1]]
x = [[1, 2],
       [3, 4]]
y = [[10, 20],
       [30, 40]]

输出:

z = [[ 1, 20],
       [30, 4]]