下载
中文
注册

数据类型转换

方法介绍

由于float16数据类型本身存在精度限制,详情请参见Link,所以在进行算子实现时,可将入参中float16类型的数据转换成float32类型进行计算,用高精度数据进行中间计算来提升精度,从而提升最终结果的精度。当中间计算过程较为复杂时,此方法精度提升效果比较明显。

示例:实现res=x * y。

在算子计算之前,若输入数据x、y为fp16的数据类型,则将其转换为fp32类型,计算完成后,将结果数据res转换成原始数据类型fp16,算子代码示例如下所示:

dtype = data_x.dtype
if dtype == "float16":
    data_x = tbe.dsl.cast_to(data_x, "float32")
    data_y = tbe.dsl.cast_to(data_y, "float32")
    res = tbe.dsl.vmul(data_x, data_y)
if dtype == "float16":
    res = tbe.dsl.cast_to(res, "float16")

同时,在进行精度验证的时候,对比数据也要转换成fp32计算。

注意事项

进行数据类型的转换需要注意以下两点:
  • 在最后输出结果时要将数据类型转换成原数据类型。
  • float16转成float32计算会导致运行性能降低,因此如果使用float16的数据类型进行计算的精度在可允许范围内,尽量不要转换数据类型。