def keep_tensors_dtypes(graph, input_tensors)
指定哪些算子保持原有精度。
参数名 |
输入/输出 |
描述 |
---|---|---|
graph |
输入 |
从pb模型导入的图。 |
input_tensors |
输入 |
需要保持精度的算子名称。 |
无。
1 2 3 |
from npu_bridge.estimator.npu import util g=tf.Graph() util.keep_tensors_dtypes(g,("random_uniform_1/sub:0",)) |