Graph Building Example
import torch_atb

def graph_build():
    builder = torch_atb.Builder("Graph")
    x = builder.add_input("x")
    y = builder.add_input("y")
    # Elementwise add: add_out = x + y
    elewise_add = torch_atb.ElewiseParam()
    elewise_add.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
    layer1 = builder.add_node([x, y], elewise_add)
    add_out = layer1.get_output(0)
    z = builder.add_input("z")
    # Register a reshaped view of add_out under the name "add_out_":
    # (m, n) -> (1, m * n)
    builder.reshape(add_out, lambda shape: [1, shape[0] * shape[1]], "add_out_")
    # Elementwise multiply: out = add_out_ * z
    elewise_mul = torch_atb.ElewiseParam()
    elewise_mul.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_MUL
    layer2 = builder.add_node(["add_out_", z], elewise_mul)
    builder.mark_output(layer2.get_output(0))
    Graph = builder.build()

if __name__ == "__main__":
    graph_build()
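The example above only builds the graph; it never runs it. The sketch below is a hedged illustration of executing it. It assumes graph_build() is modified to end with "return Graph", and that the built graph's forward method accepts a list of NPU tensors in the order the inputs were declared (x, y, z), as the forward call in the nested-graph example further below does. Because add_out is reshaped to [1, m * n] before the multiply, z must have that shape.

import torch

# Assumption: graph_build() above has been changed to end with "return Graph".
graph = graph_build()

m, n = 2, 3
x = torch.randn((m, n), dtype=torch.float16).npu()
y = torch.randn((m, n), dtype=torch.float16).npu()
# add_out is reshaped to (1, m * n) before the multiply, so z must match.
z = torch.randn((1, m * n), dtype=torch.float16).npu()

# Inputs are passed in declaration order: x, y, z.
result = graph.forward([x, y, z])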
A computation graph that has already been built can be added as a child node of the current computation graph, as sketched below.
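For orientation before the full example, here is a minimal sketch of the nesting pattern, assuming only the Builder API already shown above: the built graph object is passed to add_node in place of an operation param, and the builder treats it as a single child node.

import torch_atb

# Build a small sub-graph first.
sub = torch_atb.Builder("Sub")
a = sub.add_input("a")
b = sub.add_input("b")
add_param = torch_atb.ElewiseParam()
add_param.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
sub.mark_output(sub.add_node([a, b], add_param).get_output(0))
SubGraph = sub.build()

# Nest the built sub-graph as one child node of an outer graph.
outer = torch_atb.Builder("Outer")
p = outer.add_input("p")
q = outer.add_input("q")
sub_node = outer.add_node([p, q], SubGraph)  # graph object instead of an op param
outer.mark_output(sub_node.get_output(0))
OuterGraph = outer.build()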
Usage Example
import torch_atb
import torch

s = 128           # Sequence Length
h = 16            # Number of Heads
d_k = 64          # Head Dimension
d_v = 64          # Value Dimension (vHiddenSize)
output_dim = 64
output_dim_1 = 128

def single_graph_build(n):
    print(f"------------ single graph {n} build begin ------------")
    graph = torch_atb.Builder("Graph")
    query = graph.add_input("query")
    key = graph.add_input("key")
    value = graph.add_input("value")
    seqLen = graph.add_input("seqLen")
    self_attention_param = torch_atb.SelfAttentionParam()
    self_attention_param.head_num = 16
    self_attention_param.kv_head_num = 16
    self_attention_param.calc_type = torch_atb.SelfAttentionParam.CalcType.PA_ENCODER
    # float16: query, key, value
    # int32: seqLen
    # -> float16 (s, 16, d_k)
    self_attention = graph.add_node([query, key, value, seqLen], self_attention_param)
    self_attention_out = self_attention.get_output(0)
    input_0 = graph.add_input("input_0")
    elewise_add_param = torch_atb.ElewiseParam()
    elewise_add_param.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
    elewise_add_0 = graph.add_node([self_attention_out, input_0], elewise_add_param)
    elewise_add_0_out = elewise_add_0.get_output(0)
    gamma = graph.add_input("gamma")  # weight in layernorm (Hadamard product)
    beta = graph.add_input("beta")    # bias in layernorm
    layernorm_param = torch_atb.LayerNormParam()
    layernorm_param.layer_type = torch_atb.LayerNormParam.LayerNormType.LAYER_NORM_NORM
    layernorm_param.norm_param.begin_norm_axis = 0
    layernorm_param.norm_param.begin_params_axis = 0
    # x, gamma, beta: float16 -> float16
    layernorm_0 = graph.add_node([elewise_add_0_out, gamma, beta], layernorm_param)
    layernorm_0_out = layernorm_0.get_output(0)
    weight_0 = graph.add_input("weight_0")  # weight in linear
    bias_0 = graph.add_input("bias_0")      # bias in linear
    linear_param = torch_atb.LinearParam()
    # x, weight, bias: float16 -> float16
    linear_0 = graph.add_node([layernorm_0_out, weight_0, bias_0], linear_param)
    linear_0_out = linear_0.get_output(0)
    elewise_tanh_param = torch_atb.ElewiseParam()
    elewise_tanh_param.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_TANH
    elewise_tanh = graph.add_node([linear_0_out], elewise_tanh_param)
    elewise_tanh_out = elewise_tanh.get_output(0)
    weight_1 = graph.add_input("weight_1")
    bias_1 = graph.add_input("bias_1")
    # x, weight, bias: float16 -> float16
    linear_1 = graph.add_node([elewise_tanh_out, weight_1, bias_1], linear_param)
    linear_1_out = linear_1.get_output(0)
    graph.mark_output(linear_1_out)
    Graph = graph.build()
    print(f"----------- single graph {n} build success -----------")
    return Graph

def get_inputs():
    torch.manual_seed(233)
    print("------------ generate inputs begin ------------")
    query = (torch.randn((s, 16, d_k), dtype=torch.float16)).npu()
    key = (torch.randn((s, 16, d_k), dtype=torch.float16)).npu()
    value = (torch.randn((s, 16, d_k), dtype=torch.float16)).npu()
    seqLen = torch.tensor([s], dtype=torch.int32)
    # (s, 16, d_k) == (128, 16, 64)
    input_0 = (torch.randn((16, d_k), dtype=torch.float16)).npu()
    gamma = (torch.randn((s, 16, d_k), dtype=torch.float16)).npu()
    beta = (torch.zeros((s, 16, d_k), dtype=torch.float16)).npu()
    # (s, 16, d_k) == (128, 16, 64)
    weight_0 = (torch.randn((output_dim_1, output_dim), dtype=torch.float16)).npu()
    bias_0 = (torch.randn((output_dim_1,), dtype=torch.float16)).npu()
    # (s, 16, output_dim_1) == (128, 16, 128)
    weight_1 = (torch.randn((output_dim_1, output_dim_1), dtype=torch.float16)).npu()
    bias_1 = (torch.randn((output_dim_1,), dtype=torch.float16)).npu()
    # (s, 16, output_dim_1) == (128, 16, 128)
    inputs = [query, key, value, seqLen, input_0, gamma, beta,
              weight_0, bias_0, weight_1, bias_1]
    print("------------ generate inputs success ------------")
    return inputs

def run():
    Graph_0 = single_graph_build(0)
    print("------------ bigger graph build begin ------------")
    bigger_graph = torch_atb.Builder("BiggerGraph")
    query = bigger_graph.add_input("query")
    key = bigger_graph.add_input("key")
    value = bigger_graph.add_input("value")
    seqLen = bigger_graph.add_input("seqLen")
    input_0 = bigger_graph.add_input("input_0")
    gamma = bigger_graph.add_input("gamma")
    beta = bigger_graph.add_input("beta")
    weight_0 = bigger_graph.add_input("weight_0")
    bias_0 = bigger_graph.add_input("bias_0")
    weight_1 = bigger_graph.add_input("weight_1")
    bias_1 = bigger_graph.add_input("bias_1")
    # The built Graph_0 is passed in place of an operation param and
    # becomes a single child node of BiggerGraph.
    node_graph0 = bigger_graph.add_node(
        [query, key, value, seqLen, input_0, gamma, beta,
         weight_0, bias_0, weight_1, bias_1], Graph_0)
    node_graph0_out = node_graph0.get_output(0)
    x = bigger_graph.add_input("x")
    elewise_add = torch_atb.ElewiseParam()
    elewise_add.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
    layer = bigger_graph.add_node([x, node_graph0_out], elewise_add)
    bigger_graph.mark_output(layer.get_output(0))
    BiggerGraph = bigger_graph.build()
    print("------------ bigger graph build success ------------")
    print(repr(BiggerGraph))  # print the graph's string representation
    print("------------ bigger graph forward begin ------------")
    inputs = get_inputs()
    x = (torch.ones((128, 16, 128), dtype=torch.float16)).npu()
    inputs.append(x)
    result = BiggerGraph.forward(inputs)
    print("------------ bigger graph forward success ------------")

if __name__ == "__main__":
    run()
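Note the input ordering in the forward call: BiggerGraph receives its tensors in the order the inputs were declared on the builder, so the eleven inputs consumed by the nested Graph_0 come first and x is appended last. x has shape (128, 16, 128), matching the (s, 16, output_dim_1) output of the nested graph it is added to.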
Parent topic: Graph Building APIs