npu_weight_prefetch_scope
Description
Identifies the operators whose weight data will be prefetched into a buffer pool and specifies the ID and size of the buffer pool.
A prefetch buffer pool is a dedicated region of Ascend AI Processor memory. Its size is fixed before compilation, and prefetch tasks are controlled based on that size. When a buffer pool is full, prefetch tasks reuse memory starting again from the pool's start address, with timing control applied to the reused memory.
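The wrap-around reuse described above can be pictured as a ring-style allocator over a fixed-size region. This is a toy model under stated assumptions: the class name `RingBufferPool`, the byte-offset bookkeeping, and the way timing control is represented (a list of earlier tasks that a new prefetch must wait for before overwriting their memory) are all illustrative, not the Ascend runtime's actual allocator.

```python
class RingBufferPool:
    """Toy model of a fixed-size pool that wraps to offset 0 when full.

    Hypothetical: not the Ascend runtime allocator. "Timing control" is
    modelled only as the list of earlier tasks whose memory is overwritten.
    """

    def __init__(self, size):
        self.size = size
        self.next_offset = 0
        self.live = []  # (start, end, task_name) regions still occupied

    def allocate(self, task_name, nbytes):
        if nbytes > self.size:
            raise ValueError(f"{task_name}: {nbytes} bytes exceeds pool size {self.size}")
        start = self.next_offset
        if start + nbytes > self.size:
            start = 0  # pool is full: reuse memory from the start address
        end = start + nbytes
        # Earlier tasks overlapping [start, end) must finish before this
        # prefetch may write there (the timing dependency).
        must_wait = [t for (s, e, t) in self.live if s < end and start < e]
        self.live = [(s, e, t) for (s, e, t) in self.live
                     if not (s < end and start < e)]
        self.live.append((start, end, task_name))
        self.next_offset = end
        return start, must_wait


pool = RingBufferPool(100)
for name in ("w1", "w2", "w3", "w4"):
    print(name, pool.allocate(name, 40))
```

With a 100-byte pool and 40-byte prefetches, `w3` no longer fits after `w2`, so it wraps to offset 0 and must wait for `w1`; `w4` then lands at offset 40 and must wait for `w2`.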
For an ultra-large model trained on a cluster, the weights can be distributed across the Ascend AI Processors so that each device stores only 1/N of the weight data (N being the number of Ascend AI Processors participating in training), which reduces the model's memory footprint on each device. Before the compute operators run, the full weight data must be gathered onto the local device. To avoid running out of memory, this read-ahead (prefetched) weight data is stored in buffer pools.
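The memory arithmetic behind sharding is simple: each device keeps roughly total/N bytes resident, while the full weight of the layer about to run is gathered temporarily. The model size, device count, and layer size below are hypothetical numbers chosen only to illustrate the calculation.

```python
GiB = 1024 ** 3
MiB = 1024 ** 2


def shard_bytes(total_weight_bytes, num_devices):
    """Resident weight bytes per device when weights are split 1/N (rounded up)."""
    if num_devices <= 0:
        raise ValueError("num_devices must be positive")
    return -(-total_weight_bytes // num_devices)  # ceiling division


# Hypothetical model: 64 GiB of weights sharded across 8 devices.
resident = shard_bytes(64 * GiB, 8)
print(resident // GiB)  # 8 (GiB resident per device instead of 64)

# Hypothetical layer whose full weight is 200 MiB. Its gathered copy is the
# AllGather output that the buffer pool holds, so it must fit in the pool.
layer_full = 200 * MiB
print(layer_full <= 512 * MiB)  # True: fits in a 512 MB pool
```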
Prototype
def npu_weight_prefetch_scope(buffer_pool_id=0, buffer_pool_size=536870912)
Options
| Option | Input/Output | Description |
|---|---|---|
| buffer_pool_id | Input | An int, indicating the ID of the buffer pool to enable. Defaults to 0. |
| buffer_pool_size | Input | An int, indicating the size (in bytes) of the specified buffer pool. Defaults to 536870912 (512 MB, that is, 512 × 1024 × 1024 bytes). |
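Because `buffer_pool_size` is given in bytes, a small conversion helper keeps sizes readable and avoids miscounting digits. `mib_to_bytes` and `DEFAULT_BUFFER_POOL_SIZE` are hypothetical names used only for this sketch.

```python
MIB = 1024 * 1024  # bytes per MB as used in this document (1024 x 1024)


def mib_to_bytes(mib):
    """Convert a whole number of MB (1024 x 1024 bytes) to a byte count."""
    return mib * MIB


DEFAULT_BUFFER_POOL_SIZE = mib_to_bytes(512)
print(DEFAULT_BUFFER_POOL_SIZE)  # 536870912
print(mib_to_bytes(256))         # 268435456
```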
Returns
None
Restrictions
- The prefetch buffer pool supports only prefetch operators that have a single input and a single output.
- Buffer pools with the same ID must have the same size.
- The buffer pool must be large enough for the largest prefetch operator, including any alignment and padding.
- The prefetch buffer pool is not supported for prefetch operators inside a subgraph or control flow.
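The two size-related restrictions can be checked ahead of time if you know which operators go into which pool. The sketch below assumes a hypothetical 512-byte alignment granularity (the actual alignment and padding rules are not given in this document) and a hypothetical plan format of `(op_name, pool_id, pool_size, weight_bytes)` tuples. It does not check the single-input/single-output or subgraph restrictions.

```python
from collections import defaultdict

# Hypothetical alignment granularity; the real value is not documented here.
ALIGN_BYTES = 512
MIB = 1024 * 1024


def aligned_size(nbytes, align=ALIGN_BYTES):
    """Round nbytes up to the next multiple of align."""
    return -(-nbytes // align) * align


def check_prefetch_plan(plan):
    """Check (op_name, pool_id, pool_size, weight_bytes) entries against the
    size restrictions: same ID means same size, and the pool must fit the
    largest aligned operator. Returns a list of error strings."""
    sizes_by_id = defaultdict(set)
    largest_by_id = defaultdict(int)
    for _op_name, pool_id, pool_size, weight_bytes in plan:
        sizes_by_id[pool_id].add(pool_size)
        largest_by_id[pool_id] = max(largest_by_id[pool_id], aligned_size(weight_bytes))

    errors = []
    for pool_id, sizes in sorted(sizes_by_id.items()):
        if len(sizes) > 1:
            errors.append(f"pool {pool_id}: sizes differ {sorted(sizes)}")
        need = largest_by_id[pool_id]
        if need > min(sizes):
            errors.append(f"pool {pool_id}: needs {need} bytes, smallest size is {min(sizes)}")
    return errors


plan = [
    ("allgather_w1", 0, 512 * MIB, 300 * MIB),
    ("allgather_w2", 1, 256 * MIB, 256 * MIB + 1),  # padding pushes it past 256 MB
]
for err in check_prefetch_plan(plan):
    print(err)
```

The second entry shows why the restriction mentions padding: a weight only one byte larger than 256 MB rounds up past the pool size once alignment is applied.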
Example
```python
from npu_bridge.estimator.npu.npu_scope import npu_weight_prefetch_scope
...
with npu_weight_prefetch_scope():
    # The output memory of AllGather uses the default buffer pool.
    global_weight1 = hcom.allgather(local_weight1)
...
with npu_weight_prefetch_scope(1, 268435456):  # 256 MB: 256 x 1024 x 1024
    # The output memory of AllGather uses the 256 MB buffer pool indexed 1.
    global_weight2 = hcom.allgather(local_weight2)
```