昇腾社区首页
中文
注册

定义数据类型转换算子(Cast)

通过算子原型构建Graph时,要求前后算子的dtype必须一致,上一个算子的输出dtype如果和下一层算子的输入dtype不匹配时需要插入Cast算子。

例如下面示例中,addn算子要求输入float32,但是greater算子的输出是bool类型,在数据类型发生变换的情况下,需要通过插入cast算子进行数据类型转换。

Cast算子原型定义:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
REG_OP(Cast)
    .INPUT(x, TensorType({DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT32, DT_UINT8,
                          DT_INT64, DT_UINT64, DT_INT16, DT_UINT16, DT_DOUBLE, DT_COMPLEX64,
                          DT_COMPLEX128, DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32, DT_BF16, DT_UINT1}))
    .OUTPUT(y, TensorType({DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT32, DT_UINT8,
                           DT_INT64, DT_UINT64, DT_INT16, DT_UINT16, DT_DOUBLE, DT_COMPLEX64,
                           DT_COMPLEX128, DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32,
                           DT_BF16, DT_COMPLEX32}))
    .REQUIRED_ATTR(dst_type, Int)
    .OP_END_FACTORY_REG(Cast)

Cast算子原型定义可以看到,有一个必选属性dst_type,表示转换后的数据类型,设置为0表示转换后的数据类型为float32,值和数据类型对应关系请参见DataType

1
2
3
4
5
6
7
auto greater = op::Greater("greater").set_input_x1(const1).set_input_x2(const2);
auto cast = op::Cast("cast").set_input_x(greater)
                                .set_attr_dst_type(0);   
auto addn = op::AddN("addn").create_dynamic_input_x(3)
	                        .set_dynamic_input_x(0,cast)
	                        .set_dynamic_input_x(1,data).set_dynamic_input_x(2,data)
	                        .set_attr_N(3);