vcmpsel
Description
Compares lhs with rhs element-wise using the comparison given by operation, which is one of eq, ne, lt, gt, le, and ge (==, !=, <, >, <=, and >=, respectively). Where the comparison is true, the corresponding value of slhs is returned; otherwise, the corresponding value of srhs is returned.
Each operation is defined below, where a is an element of lhs, b is the corresponding element of rhs, c is the corresponding element of slhs, d is the corresponding element of srhs, and res is the corresponding element of the result tensor:
- lt: res = c (a < b) or d (a >= b)
- gt: res = c (a > b) or d (a <= b)
- le: res = c (a <= b) or d (a > b)
- ge: res = c (a >= b) or d (a < b)
- eq: res = c (a == b) or d (a != b)
- ne: res = c (a != b) or d (a == b)
Defaults for omitted parameters:
- If rhs is None, the elements in lhs are compared with the floating-point number 2.0.
- If slhs is None, the value of lhs is returned when the expression is true.
- If srhs is None and rhs is a tensor, the value of rhs is returned when the expression is not true.
- If srhs is None and rhs is a scalar, the floating-point number 0.0 is returned when the expression is not true.
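To make the selection rules and defaults concrete, the following is a minimal pure-Python model of the documented behavior on 1-D lists of floats. It is a sketch under stated assumptions, not the TBE implementation: vcmpsel_ref and _at are hypothetical names, only 1-D inputs of equal length are handled, and when rhs is None the sketch assumes the 0.0 fallback for srhs applies, because the default rhs (2.0) is a scalar.

```python
import operator
from typing import List, Optional, Sequence, Union

Operand = Union[Sequence[float], float]

# Map each documented operation name to the matching Python comparison.
_OPS = {
    "eq": operator.eq,
    "ne": operator.ne,
    "lt": operator.lt,
    "gt": operator.gt,
    "le": operator.le,
    "ge": operator.ge,
}


def _at(x: Operand, i: int) -> float:
    """Return element i of a 1-D 'tensor' (list), or the scalar itself."""
    return x if isinstance(x, (int, float)) else x[i]


def vcmpsel_ref(lhs: Sequence[float], rhs: Optional[Operand] = None,
                operation: str = "lt", slhs: Optional[Operand] = None,
                srhs: Optional[Operand] = None) -> List[float]:
    """Hypothetical pure-Python model of vcmpsel semantics (1-D only)."""
    if operation not in _OPS:
        raise ValueError(f"unsupported operation: {operation!r}")
    # Documented defaults: rhs -> 2.0, slhs -> lhs,
    # srhs -> rhs if rhs is a tensor, else 0.0.
    # Assumption: a defaulted rhs (2.0) counts as a scalar for srhs.
    if rhs is None:
        rhs = 2.0
    if slhs is None:
        slhs = lhs
    if srhs is None:
        srhs = 0.0 if isinstance(rhs, (int, float)) else rhs
    cmp = _OPS[operation]
    # res = c where cmp(a, b) holds, else d
    return [_at(slhs, i) if cmp(a, _at(rhs, i)) else _at(srhs, i)
            for i, a in enumerate(lhs)]
```

For example, vcmpsel_ref([1.0, 3.0]) returns [1.0, 0.0]: each element is compared with 2.0 using lt, kept where the comparison is true, and replaced with 0.0 otherwise.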
Prototype
vcmpsel(lhs, rhs=None, operation='lt', slhs=None, srhs=None)
Parameters
- lhs: a tvm.tensor for the left operand.
- rhs: a tvm.tensor or scalar for the right operand. Defaults to None.
- operation: the comparison to perform, one of eq, ne, lt, gt, le, or ge. Defaults to lt.
- slhs: a tvm.tensor or scalar for the value returned when the comparison expression is true. Defaults to None.
- srhs: a tvm.tensor or scalar for the value returned when the comparison expression is not true. Defaults to None.
- The data types of all parameters must be the same. Supported data types:
  - Atlas 200/300/500 Inference Product: float16
  - Atlas Training Series Product: float16 and float32
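The data-type rule can be expressed as a pre-check a caller might run before building the compute. This is a hypothetical helper, not part of the tbe package: check_vcmpsel_dtypes and SUPPORTED_DTYPES are illustrative names, the product keys are free-form labels copied from the list above, and dtypes are passed as strings like the dtype argument of tvm.placeholder.

```python
from typing import Dict, FrozenSet, Sequence

# Supported dtypes per product, as listed in this section.
# The keys are illustrative labels, not official identifiers.
SUPPORTED_DTYPES: Dict[str, FrozenSet[str]] = {
    "Atlas 200/300/500 Inference Product": frozenset({"float16"}),
    "Atlas Training Series Product": frozenset({"float16", "float32"}),
}


def check_vcmpsel_dtypes(product: str, dtypes: Sequence[str]) -> str:
    """Hypothetical pre-check: all dtypes equal and supported on product.

    Returns the common dtype, or raises ValueError.
    """
    if not dtypes:
        raise ValueError("no operand dtypes given")
    # Rule 1: every parameter must share one data type.
    if len(set(dtypes)) != 1:
        raise ValueError(f"operand dtypes differ: {sorted(set(dtypes))}")
    dtype = dtypes[0]
    # Rule 2: that data type must be supported on the target product.
    allowed = SUPPORTED_DTYPES.get(product)
    if allowed is None:
        raise ValueError(f"unknown product: {product!r}")
    if dtype not in allowed:
        raise ValueError(f"{dtype} is not supported on {product}")
    return dtype
```

Checking equality before product support keeps the error message specific: a mixed float16/float32 call is reported as a mismatch rather than as an unsupported type.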
Returns
res_tensor: a tvm.tensor for the result tensor
Restrictions
None
Applicability
Example
```python
from tbe import tvm
from tbe import dsl

shape = (1024, 1024)
input_dtype = "float16"

# data1 and data2 are compared; data3 and data4 supply the selected values
data1 = tvm.placeholder(shape, name="data1", dtype=input_dtype)
data2 = tvm.placeholder(shape, name="data2", dtype=input_dtype)
data3 = tvm.placeholder(shape, name="data3", dtype=input_dtype)
data4 = tvm.placeholder(shape, name="data4", dtype=input_dtype)

# res = data3 where data1 > data2, otherwise data4
res = dsl.vcmpsel(data1, data2, 'gt', data3, data4)
```
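Because slhs defaults to lhs and, with a scalar rhs, srhs defaults to 0.0, the rules above suggest that a ReLU could be written as dsl.vcmpsel(data1, 0.0, 'gt'). This reading follows from the documented defaults but has not been verified against the TBE implementation. The pure-Python sketch below (relu_via_vcmpsel_defaults is a hypothetical name) only models what that call would compute on a flat list.

```python
from typing import List, Sequence


def relu_via_vcmpsel_defaults(x: Sequence[float]) -> List[float]:
    """Hypothetical model of vcmpsel(x, 0.0, 'gt') under the documented defaults."""
    rhs = 0.0   # scalar right operand
    slhs = x    # slhs is None, so lhs is selected where the test is true
    srhs = 0.0  # srhs is None and rhs is a scalar, so 0.0 is selected otherwise
    return [slhs[i] if a > rhs else srhs for i, a in enumerate(x)]
```

Under this reading, each element ends up equal to max(a, 0.0), which is why the default-only call form behaves like a ReLU.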