npu_tagged_event_wait
支持的型号
功能说明
与torch.npu.Event.wait(torch.cuda.Event.wait的NPU形式,参见《PyTorch 原生API支持度》中的“torch.cuda”)方法类似,用于让当前流等待指定的事件完成。
当算子下发时,系统会获取用户设置的所属流标签信息,并在该流上下发一个wait任务。该任务会使调用它的流暂停执行,直到指定的事件被记录(即record完成)。
注意wait与record需匹配使用,在Device上执行时,wait及之后的算子需要等待与之匹配的record执行完成后才能执行。
函数原型
def npu_tagged_event_wait(event)
参数说明
参数 |
输入/输出 |
说明 |
是否必选 |
|---|---|---|---|
event |
输入 |
通过npu_create_tagged_event接口创建出来的event。 |
是 |
返回值说明
无
约束说明
- 本接口只在reduce-overhead模式下生效,其他模式不建议使用。
- 其他约束与torch.npu.Event.wait保持一致,此处不再赘述。
调用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | import torch, os import torch_npu import torchair from torchair.configs.compiler_config import CompilerConfig from torchair.core.utils import logger # 创建一个tag标识为"66"的event对象 GLOBAL_EVENT = torchair.ops.npu_create_tagged_event(tag="66") class Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self, in1, in2, in3, in4): global GLOBAL_EVENT add_result = torch.add(in1, in2) # 插入一个event_record用于device上两条流之间的同步,对于GLOBAL_EVENT的wait后的任务需要等record执行完毕才能执行 torchair.ops.npu_tagged_event_record(GLOBAL_EVENT) with torchair.scope.npu_stream_switch('1'): # torch.mm算子(mm_result)等待torch.add算子(add_result)执行完再执行 torchair.ops.npu_tagged_event_wait(GLOBAL_EVENT) mm_result = torch.mm(in3, in4) mm1 = torch.mm(in3, in4) add2 = torch.add(in3, in4) return add_result, mm_result, mm1, add2 model = Model() config = CompilerConfig() config.mode = "reduce-overhead" npu_backend = torchair.get_npu_backend(compiler_config=config) # 调用compile编译 model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True) in1 = torch.randn(1000, 1000, dtype = torch.float16).npu() in2 = torch.randn(1000, 1000, dtype = torch.float16).npu() in3 = torch.randn(1000, 1000, dtype = torch.float16).npu() in4 = torch.randn(1000, 1000, dtype = torch.float16).npu() result = model(in1, in2, in3, in4) print(f"Result:\n{result}\n") |
父主题: torchair.ops