collective_plan
The APIs in this class are open-source TorchRec APIs, not external APIs of Rec SDK Torch. This section describes the parameter ranges supported for the TorchRec APIs that are called when Rec SDK Torch is used.
Function
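Searches for the most suitable sharding plan for the embedding tables. In open-source TorchRec, the plan is computed on rank 0 and then broadcast to the other ranks in the process group.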
Prototype
    def collective_plan(
        module: nn.Module,
        sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
        pg: Optional[dist.ProcessGroup] = None,
    )
Parameters
| Parameter | Data Type | Mandatory/Optional | Description |
|---|---|---|---|
| module | nn.Module | Mandatory | When NPU devices are used, the module list that contains HashEmbeddingBagCollection must be passed. |
| sharders | List[ModuleSharder[nn.Module]] | Optional | List of sharders. When NPU devices are used, only the result of get_default_hybrid_sharders() can be passed. |
| pg | dist.ProcessGroup | Optional | Process group. When NPU devices are used, dist.GroupMember.WORLD is passed. |
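The NPU constraints listed for the parameters above can be checked before calling collective_plan. The following is a minimal sketch, not part of TorchRec or Rec SDK Torch: the helper name check_npu_plan_args and the world_pg argument are hypothetical, and matching HashEmbeddingBagCollection by class name is an assumption made so that the sketch runs without importing the SDK. It relies only on the modules() iterator that nn.Module provides.

```python
from typing import Any, List, Optional


def check_npu_plan_args(
    module: Any,
    sharders: Optional[List[Any]],
    pg: Any,
    world_pg: Any,
) -> List[str]:
    """Hypothetical pre-flight check; returns a list of problems (empty if none)."""
    problems: List[str] = []

    # nn.Module.modules() yields the module itself and all of its submodules.
    # Assumption: the collection is recognized by its class name.
    class_names = {type(m).__name__ for m in module.modules()}
    if "HashEmbeddingBagCollection" not in class_names:
        problems.append("module does not contain a HashEmbeddingBagCollection")

    # On NPU, sharders should be the result of get_default_hybrid_sharders().
    # This sketch can only verify that something was passed.
    if not sharders:
        problems.append("sharders is empty; pass get_default_hybrid_sharders()")

    # On NPU, pg is expected to be dist.GroupMember.WORLD (passed in as world_pg).
    if pg is not world_pg:
        problems.append("pg is not dist.GroupMember.WORLD")

    return problems
```

In a training script, the call might look like check_npu_plan_args(test_model, hybrid_sharders, pg, dist.GroupMember.WORLD), with the returned messages logged or raised before the planner is invoked.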
Sample
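    import torch.distributed as dist
    from torchrec.distributed.planner import EmbeddingShardingPlanner

    # test_model and hybrid_sharders are assumed to be defined earlier;
    # hybrid_sharders is the result of get_default_hybrid_sharders().
    planner = EmbeddingShardingPlanner(XXX)
    plan = planner.collective_plan(test_model, hybrid_sharders, dist.GroupMember.WORLD)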
See Also
For the API call sequence and an example, see Porting and Training.