下载
EN
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助
昇腾小AI

完整代码示例

迁移到昇腾AI处理器后的推理脚本：

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import npu_device
from npu_device.compat.v1.npu_init import *
npu_device.compat.enable_v1()
from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig

import tensorflow as tf
import numpy as np
import os
import time
import argparse

# np.random.seed(10)

def load_graph(frozen_graph):
    """Read a frozen (serialized, binary) GraphDef and import it into a new graph.

    Args:
        frozen_graph: Path to the serialized GraphDef file (e.g. a ``.pb``),
            opened through ``tf.io.gfile`` in binary mode.

    Returns:
        A freshly created ``tf.Graph`` containing the imported nodes.
    """
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(frozen_graph, "rb") as pb_file:
        serialized = pb_file.read()
    graph_def.ParseFromString(serialized)

    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original node names (no "import/" prefix), so
        # callers can fetch tensors such as "Input:0" directly.
        tf.import_graph_def(graph_def, name="")
    return graph

def _npu_session_config(precision_mode="allow_mix_precision"):
    """Build a tf.compat.v1 session config that enables the NpuOptimizer rewrite pass."""
    config_proto = tf.compat.v1.ConfigProto()
    custom_op = config_proto.graph_options.rewrite_options.custom_optimizers.add()
    custom_op.name = "NpuOptimizer"
    custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes(precision_mode)
    # Grappler's remapping pass is switched off, as in the original NPU adaptation.
    config_proto.graph_options.rewrite_options.remapping = RewriterConfig.OFF
    return npu_config_proto(config_proto=config_proto)


def NetworkRun(modelPath, inputPath, outputPath, *, input_shape=(1, 224, 224, 3),
               input_tensor="Input:0", output_tensor="Identity:0",
               precision_mode="allow_mix_precision"):
    """Run a frozen graph on the NPU for every ``.bin`` file in a directory.

    Each ``*.bin`` file in *inputPath* is processed in sorted name order. The
    file is read as raw float32 data, reshaped to *input_shape* and fed to
    *input_tensor*. The fetched *output_tensor* value is written to
    *outputPath* as ``cpu_out_<file>`` via ``ndarray.tofile``. The per-file
    ``sess.run`` time is printed in milliseconds.

    Args:
        modelPath: Path to the frozen GraphDef (``.pb``) file.
        inputPath: Directory containing the ``.bin`` input files.
        outputPath: Existing directory that receives the output files.
        input_shape: Shape each input file is reshaped to.
        input_tensor: Name of the tensor the input is fed into.
        output_tensor: Name of the tensor that is fetched.
        precision_mode: Value of the NpuOptimizer ``precision_mode`` option.

    Raises:
        ValueError: If an input file's element count does not match
            *input_shape*; raised by ``ndarray.reshape``.
    """
    graph = load_graph(modelPath)
    input_nodes = graph.get_tensor_by_name(input_tensor)
    output_nodes = graph.get_tensor_by_name(output_tensor)
    tf_config = _npu_session_config(precision_mode)

    # List inputs before opening the (expensive) NPU session so a bad
    # directory fails fast; sorting the filtered names keeps the original order.
    bin_files = sorted(name for name in os.listdir(inputPath) if name.endswith(".bin"))

    with tf.compat.v1.Session(config=tf_config, graph=graph) as sess:
        for file in bin_files:
            input_img = np.fromfile(os.path.join(inputPath, file),
                                    dtype="float32").reshape(input_shape)
            # perf_counter is monotonic; stop the clock right after sess.run so
            # printing the output tensor is not counted as inference time.
            # NOTE(review): the first iteration may also include one-time
            # graph setup on the device.
            t0 = time.perf_counter()
            out = sess.run(output_nodes, feed_dict={input_nodes: input_img})
            t1 = time.perf_counter()
            print('out---', out)
            # The "cpu_out_" prefix is kept unchanged, presumably because
            # comparison tooling expects these names, even though outputs
            # come from the NPU.
            out.tofile(os.path.join(outputPath, "cpu_out_" + file))
            print("{}, Inference time: {:.3f} ms".format(file, (t1 - t0) * 1000))

if __name__ == '__main__':
    # Command-line entry point: run the frozen model over every .bin file in
    # --input and write the results into --output.
    parser = argparse.ArgumentParser(
        description="Run a frozen TensorFlow graph on the NPU over a directory of .bin inputs.")
    parser.add_argument("--model", type=str, default="./resnet50_tf2.pb",
                        help="path to the frozen .pb graph")
    parser.add_argument("--input", type=str, default="./input_bin/",
                        help="directory containing float32 .bin input files")
    parser.add_argument("--output", type=str, default="./npu_output/",
                        help="directory for output .bin files (created if missing)")
    args = parser.parse_args()
    # makedirs also creates missing parent directories; unlike os.mkdir it
    # does not fail if the directory already exists.
    os.makedirs(args.output, exist_ok=True)
    NetworkRun(args.model, args.input, args.output)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词