完整代码示例
迁移到昇腾AI处理器的推理脚本:
"""Run frozen-graph inference on an Ascend AI processor via the TF adapter.

Every ``.bin`` file in the input directory is fed to the frozen graph, and
the raw output tensor is written to the output directory.
"""
import npu_device
from npu_device.compat.v1.npu_init import *
npu_device.compat.enable_v1()
from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
import tensorflow as tf
import numpy as np
import os
import time
import argparse


def load_graph(frozen_graph):
    """Load a frozen GraphDef file and import it into a new ``tf.Graph``.

    Args:
        frozen_graph: Path to the serialized frozen graph (``.pb``) file.

    Returns:
        The ``tf.Graph`` holding the imported nodes. Nodes are imported with
        an empty name prefix, so tensor names match those stored in the file.
    """
    with tf.io.gfile.GFile(frozen_graph, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    return graph


def NetworkRun(modelPath, inputPath, outputPath):
    """Run inference for every ``.bin`` file in ``inputPath``.

    Each input file is read as float32 and reshaped to ``(1, 224, 224, 3)``
    before being fed to the ``Input:0`` tensor. The value of ``Identity:0``
    is printed and dumped to ``outputPath`` as ``cpu_out_<input file name>``.

    Args:
        modelPath: Path to the frozen graph file.
        inputPath: Directory containing the ``.bin`` input files.
        outputPath: Existing directory that receives the output files.
    """
    graph = load_graph(modelPath)
    input_nodes = graph.get_tensor_by_name('Input:0')
    output_nodes = graph.get_tensor_by_name('Identity:0')

    # Adapt the session configuration for the NPU.
    config_proto = tf.compat.v1.ConfigProto()
    custom_op = config_proto.graph_options.rewrite_options.custom_optimizers.add()
    custom_op.name = "NpuOptimizer"
    custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
    config_proto.graph_options.rewrite_options.remapping = RewriterConfig.OFF
    # npu_config_proto is not defined in this file; presumably it is provided
    # by the star import from npu_device.compat.v1.npu_init -- verify.
    tf_config = npu_config_proto(config_proto=config_proto)

    with tf.compat.v1.Session(config=tf_config, graph=graph) as sess:
        # Sorted so files are processed in a deterministic order.
        for file_name in sorted(os.listdir(inputPath)):
            if not file_name.endswith(".bin"):
                continue

            input_img = np.fromfile(
                os.path.join(inputPath, file_name), dtype="float32"
            ).reshape(1, 224, 224, 3)

            # Time only the session run: stop the clock before printing the
            # output so console I/O does not inflate the reported latency.
            t0 = time.perf_counter()
            out = sess.run(output_nodes, feed_dict={input_nodes: input_img})
            t1 = time.perf_counter()

            print('out---', out)
            # NOTE(review): the "cpu_out_" prefix is kept unchanged even
            # though this script targets the NPU, so existing consumers of
            # the output files keep working -- confirm whether it should be
            # renamed.
            out.tofile(os.path.join(outputPath, "cpu_out_" + file_name))
            print("{}, Inference time: {:.3f} ms".format(file_name, (t1 - t0) * 1000))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="./resnet50_tf2.pb")
    parser.add_argument("--input", type=str, default="./input_bin/")
    parser.add_argument("--output", type=str, default="./npu_output/")
    args = parser.parse_args()

    # makedirs also creates missing parent directories and is a no-op when
    # the directory already exists.
    os.makedirs(args.output, exist_ok=True)
    NetworkRun(args.model, args.input, args.output)
父主题: 在线推理