定义数据类型转换算子(Cast)
通过算子原型构建Graph时,要求前后算子的dtype必须一致,上一个算子的输出dtype如果和下一层算子的输入dtype不匹配时需要插入Cast算子。
例如下面示例中,addn算子要求输入float32,但是greater算子的输出是bool类型,在数据类型发生变换的情况下,需要通过插入cast算子进行数据类型转换。
Cast算子原型定义:
1 2 3 4 5 6 7 8 9 10 |
REG_OP(Cast) .INPUT(x, TensorType({DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT32, DT_UINT8, DT_INT64, DT_UINT64, DT_INT16, DT_UINT16, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32, DT_BF16, DT_UINT1})) .OUTPUT(y, TensorType({DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT32, DT_UINT8, DT_INT64, DT_UINT64, DT_INT16, DT_UINT16, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32, DT_BF16, DT_COMPLEX32})) .REQUIRED_ATTR(dst_type, Int) .OP_END_FACTORY_REG(Cast) |
Cast算子原型定义可以看到,有一个必选属性dst_type,表示转换后的数据类型,设置为0表示转换后的数据类型为float32,值和数据类型对应关系请参见DataType。
1 2 3 4 5 6 7 |
auto greater = op::Greater("greater").set_input_x1(const1).set_input_x2(const2); auto cast = op::Cast("cast").set_input_x(greater) .set_attr_dst_type(0); auto addn = op::AddN("addn").create_dynamic_input_x(3) .set_dynamic_input_x(0,cast) .set_dynamic_input_x(1,data).set_dynamic_input_x(2,data) .set_attr_N(3); |
父主题: 各类算子表达