昇腾社区首页
中文
注册
开发者
下载

npu_tagged_event_wait

支持的型号

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件

Atlas A3 训练系列产品/Atlas A3 推理系列产品

功能说明

与torch.npu.Event.wait(torch.cuda.Event.wait的NPU形式,参见《PyTorch 原生API支持度》中的“torch.cuda)方法类似,用于让当前流等待指定的事件完成。

当算子下发时,系统会获取用户设置的所属流标签信息,并在该流上下发一个wait任务。该任务会使调用它的流暂停执行,直到指定的事件被记录(即record完成)。

注意wait与record需匹配使用,在Device上执行时,wait及之后的算子需要等待与之匹配的record执行完成后才能执行。

函数原型

def npu_tagged_event_wait(event)

参数说明

参数

输入/输出

说明

是否必选

event

输入

通过npu_create_tagged_event接口创建出来的event。

返回值说明

约束说明

  • 本接口只在reduce-overhead模式下生效,其他模式不建议使用。
  • 其他约束与torch.npu.Event.wait保持一致,此处不再赘述。

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch, os
import torch_npu
import torchair
from torchair.configs.compiler_config import CompilerConfig
from torchair.core.utils import logger

# 创建一个tag标识为"66"的event对象
GLOBAL_EVENT = torchair.ops.npu_create_tagged_event(tag="66")

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, in1, in2, in3, in4):
        global GLOBAL_EVENT
        add_result = torch.add(in1, in2)
        # 插入一个event_record用于device上两条流之间的同步,对于GLOBAL_EVENT的wait后的任务需要等record执行完毕才能执行
        torchair.ops.npu_tagged_event_record(GLOBAL_EVENT)
        with torchair.scope.npu_stream_switch('1'):
            # torch.mm算子(mm_result)等待torch.add算子(add_result)执行完再执行
            torchair.ops.npu_tagged_event_wait(GLOBAL_EVENT)
            mm_result = torch.mm(in3, in4)
        mm1 = torch.mm(in3, in4)
        add2 = torch.add(in3, in4)
        return add_result, mm_result, mm1, add2
model = Model()
config = CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)

# 调用compile编译
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
in1 = torch.randn(1000, 1000, dtype = torch.float16).npu()
in2 = torch.randn(1000, 1000, dtype = torch.float16).npu()
in3 = torch.randn(1000, 1000, dtype = torch.float16).npu()
in4 = torch.randn(1000, 1000, dtype = torch.float16).npu()
result = model(in1, in2, in3, in4)
print(f"Result:\n{result}\n")