keep_tensors_dtypes
产品支持情况
产品 |
是否支持 |
---|---|
√ |
|
√ |
|
x |
|
x |
|
x |
|
√ |
|
x |

针对
功能说明
指定哪些算子保持原有精度。
函数原型
1 | def keep_tensors_dtypes(graph, input_tensors)
|
参数说明
参数名 |
输入/输出 |
描述 |
---|---|---|
graph |
输入 |
从pb模型导入的图。 |
input_tensors |
输入 |
需要保持精度的算子名称。 |
返回值
无。
使用约束
- 该接口仅适用于在线推理场景。
- 算子精度模式为保持原图精度(即precision_mode指定为must_keep_origin_dtype)时,该接口不生效。
调用示例
1 2 3 | from npu_bridge.estimator.npu import util
g=tf.Graph()
util.keep_tensors_dtypes(g,("random_uniform_1/sub:0",))
|