
Graph composition example

import torch_atb

def graph_build():
    builder = torch_atb.Builder("Graph")

    # Declare two graph inputs and add an elementwise-add node: add_out = x + y
    x = builder.add_input("x")
    y = builder.add_input("y")
    elewise_add = torch_atb.ElewiseParam()
    elewise_add.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
    layer1 = builder.add_node([x, y], elewise_add)
    add_out = layer1.get_output(0)

    # Flatten add_out to shape [1, n] and register the view under the name "add_out_"
    z = builder.add_input("z")
    builder.reshape(add_out, lambda shape: [1, shape[0] * shape[1]], "add_out_")

    # Multiply the flattened tensor with z and mark the product as the graph output
    elewise_mul = torch_atb.ElewiseParam()
    elewise_mul.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_MUL
    layer2 = builder.add_node(["add_out_", z], elewise_mul)
    builder.mark_output(layer2.get_output(0))
    return builder.build()

if __name__ == "__main__":
    graph = graph_build()
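
Once built, the graph is executed with forward, with inputs passed in the order they were declared. A minimal sketch of such a call, following the forward pattern used in the larger example below; the (2, 3) shape and float16 dtype are illustrative assumptions, and z must match the [1, 6] shape produced by the reshape:

import torch
import torch_atb

graph = graph_build()
x = torch.randn((2, 3), dtype=torch.float16).npu()
y = torch.randn((2, 3), dtype=torch.float16).npu()
z = torch.randn((1, 6), dtype=torch.float16).npu()  # matches the flattened [1, 2*3] shape
result = graph.forward([x, y, z])  # declaration order: x, y, z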

A computation graph that has already been built can be added to the current graph as a child node, as the following example shows.

Usage example
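
The example below first builds a standalone graph, Graph_0, chaining self-attention, an elementwise add, layernorm, a linear layer, a tanh, and a second linear layer. It then creates BiggerGraph, embeds Graph_0 as a single child node, and adds one final elementwise add on top of the sub-graph's output.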

import torch_atb
import torch

s = 128 # Sequence Length
h = 16 # Number of Heads
d_k = 64 # Head Dimension
d_v = 64 # Value Dimension (vHiddenSize)
output_dim = 64
output_dim_1 = 128

def single_graph_build(n):
    print(f"------------ single graph {n} build begin ------------")
    graph = torch_atb.Builder("Graph")

    query = graph.add_input("query")
    key = graph.add_input("key")
    value = graph.add_input("value")
    seqLen = graph.add_input("seqLen")
    self_attention_param = torch_atb.SelfAttentionParam()
    self_attention_param.head_num = h
    self_attention_param.kv_head_num = h
    self_attention_param.calc_type = torch_atb.SelfAttentionParam.CalcType.PA_ENCODER
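    # PA_ENCODER selects the encoder (prefill) attention path, which consumes
    # the per-sequence lengths passed in through seqLen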

    # inputs: query/key/value (float16), seqLen (int32)
    # output: float16, shape (s, h, d_k)
    self_attention = graph.add_node([query, key, value, seqLen], self_attention_param)
    self_attention_out = self_attention.get_output(0)

    input_0 = graph.add_input("input_0")
    elewise_add_param = torch_atb.ElewiseParam()
    elewise_add_param.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD

    elewise_add_0 = graph.add_node([self_attention_out, input_0], elewise_add_param)
    elewise_add_0_out = elewise_add_0.get_output(0)

    gamma = graph.add_input("gamma") # layernorm weight, applied elementwise (Hadamard product)
    beta = graph.add_input("beta") # layernorm bias
    layernorm_param = torch_atb.LayerNormParam()
    layernorm_param.layer_type = torch_atb.LayerNormParam.LayerNormType.LAYER_NORM_NORM
    layernorm_param.norm_param.begin_norm_axis = 0
    layernorm_param.norm_param.begin_params_axis = 0

    # x, gamma, beta, float16 -> float16
    layernorm_0 = graph.add_node([elewise_add_0_out, gamma, beta], layernorm_param)
    layernorm_0_out = layernorm_0.get_output(0)

    weight_0 = graph.add_input("weight_0") # weight in linear
    bias_0 = graph.add_input("bias_0") # bias in linear
    linear_param = torch_atb.LinearParam()

    # inputs: x, weight, bias (float16) -> float16
    linear_0 = graph.add_node([layernorm_0_out, weight_0, bias_0], linear_param)
    linear_0_out = linear_0.get_output(0)

    elewise_tanh_param = torch_atb.ElewiseParam()
    elewise_tanh_param.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_TANH

    elewise_tanh = graph.add_node([linear_0_out], elewise_tanh_param)
    elewise_tanh_out = elewise_tanh.get_output(0)

    weight_1 = graph.add_input("weight_1")
    bias_1 = graph.add_input("bias_1")

    # inputs: x, weight, bias (float16) -> float16
    linear_1 = graph.add_node([elewise_tanh_out, weight_1, bias_1], linear_param)
    linear_1_out = linear_1.get_output(0)

    graph.mark_output(linear_1_out)
    Graph = graph.build()
    print(f"----------- single graph {n} build success -----------")
    return Graph

def get_inputs():
    torch.manual_seed(233)

    print("------------ generate inputs begin ------------")
    query = torch.randn((s, h, d_k), dtype=torch.float16).npu()
    key = torch.randn((s, h, d_k), dtype=torch.float16).npu()
    value = torch.randn((s, h, d_k), dtype=torch.float16).npu()
    seqLen = torch.tensor([s], dtype=torch.int32)  # per-sequence lengths stay on the CPU (no .npu())
    # (s, h, d_k) == (128, 16, 64)

    input_0 = torch.randn((h, d_k), dtype=torch.float16).npu()  # broadcast against (s, h, d_k)

    gamma = torch.randn((s, h, d_k), dtype=torch.float16).npu()
    beta = torch.zeros((s, h, d_k), dtype=torch.float16).npu()
    # (s, h, d_k) == (128, 16, 64)

    weight_0 = torch.randn((output_dim_1, output_dim), dtype=torch.float16).npu()
    bias_0 = torch.randn((output_dim_1,), dtype=torch.float16).npu()
    # linear_0 output: (s, h, output_dim_1) == (128, 16, 128)

    weight_1 = torch.randn((output_dim_1, output_dim_1), dtype=torch.float16).npu()
    bias_1 = torch.randn((output_dim_1,), dtype=torch.float16).npu()
    # linear_1 output: (s, h, output_dim_1) == (128, 16, 128)

    inputs = [query, key, value, seqLen, input_0, gamma, beta, weight_0, bias_0, weight_1, bias_1]
    print("------------ generate inputs success ------------")
    return inputs

def run():
    Graph_0 = single_graph_build(0)

    print("------------ bigger graph build begin ------------")
    bigger_graph = torch_atb.Builder("BiggerGraph")

    query = bigger_graph.add_input("query")
    key = bigger_graph.add_input("key")
    value = bigger_graph.add_input("value")
    seqLen = bigger_graph.add_input("seqLen")
    input_0 = bigger_graph.add_input("input_0")
    gamma = bigger_graph.add_input("gamma")
    beta = bigger_graph.add_input("beta")
    weight_0 = bigger_graph.add_input("weight_0")
    bias_0 = bigger_graph.add_input("bias_0")
    weight_1 = bigger_graph.add_input("weight_1")
    bias_1 = bigger_graph.add_input("bias_1")

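    # A built Graph object can be passed to add_node in place of an operation
    # param: Graph_0 is embedded here as a single child node of bigger_graph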
    node_graph0 = bigger_graph.add_node([query, key, value, seqLen, input_0, gamma, beta, 
                                            weight_0, bias_0, weight_1, bias_1], Graph_0)
    node_graph0_out = node_graph0.get_output(0)

    x = bigger_graph.add_input("x")
    elewise_add = torch_atb.ElewiseParam()
    elewise_add.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
    layer = bigger_graph.add_node([x, node_graph0_out], elewise_add)

    bigger_graph.mark_output(layer.get_output(0))

    BiggerGraph = bigger_graph.build()

    print("------------ bigger graph build success ------------")
    print(repr(BiggerGraph))
    print("------------ bigger graph forward begin ------------")
    inputs = get_inputs()
    x = torch.ones((s, h, output_dim_1), dtype=torch.float16).npu()  # matches Graph_0's output shape
    inputs.append(x)
    result = BiggerGraph.forward(inputs)
    print("------------ bigger graph forward success ------------")


if __name__ == "__main__":
    run()
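
Note that BiggerGraph.forward receives twelve tensors: the eleven inputs declared for the embedded sub-graph, followed by x, in the same order the inputs were declared on bigger_graph.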