collective_plan

The APIs in this category are open-source TorchRec APIs, not external APIs of Rec SDK Torch. This section describes the parameter ranges that are supported when these TorchRec APIs are called with Rec SDK Torch.

Function

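Searches for the most appropriate sharding plan for the embedding tables in the module. The search runs only on rank 0, and the resulting plan is broadcast to all ranks in the process group, so every rank receives the same plan.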
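The rank-0-then-broadcast behavior can be pictured with a small pure-Python simulation. This is a hypothetical sketch, not TorchRec code: simulate_collective_plan and expensive_plan_search are illustrative names, the plan contents are made up, and the broadcast is modeled with copy.deepcopy instead of a real process-group collective. It shows why the search runs once and every rank ends up with an identical plan.

```python
import copy
from typing import Any, Callable, Dict, List


def simulate_collective_plan(world_size: int, plan_fn: Callable[[], Any]) -> List[Any]:
    """Hypothetical model of rank-0 planning followed by a broadcast.

    Rank 0 runs plan_fn once; every rank (including rank 0) then receives
    its own copy of that result, as a broadcast would deliver it.
    """
    plan_on_rank0 = plan_fn()  # the planning search runs only on rank 0
    # Model the broadcast: each rank gets an independent but equal copy.
    return [copy.deepcopy(plan_on_rank0) for _ in range(world_size)]


calls = 0


def expensive_plan_search() -> Dict[str, str]:
    """Hypothetical stand-in for the planner's search; counts its invocations."""
    global calls
    calls += 1
    return {"table_0": "row_wise", "table_1": "table_wise"}


plans = simulate_collective_plan(4, expensive_plan_search)
print(calls)                               # 1
print(all(p == plans[0] for p in plans))   # True
```

In a real job the copies travel over the process group, so all ranks must call collective_plan together; otherwise the broadcast cannot complete.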

Prototype

def collective_plan(
    self,
    module: nn.Module,
    sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
    pg: Optional[dist.ProcessGroup] = None,
) -> ShardingPlan
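Both sharders and pg default to None. In the upstream TorchRec implementation, a None pg falls back to the default WORLD process group (which must already be initialized), and a None sharders falls back to TorchRec's default sharders. The pure-Python sketch below mirrors that resolution order using injected stand-ins; resolve_collective_plan_defaults and its keyword arguments are hypothetical names chosen for illustration. On NPU devices, pass both arguments explicitly instead of relying on these defaults.

```python
from typing import Any, Callable, List, Optional, Tuple


def resolve_collective_plan_defaults(
    sharders: Optional[List[Any]],
    pg: Optional[Any],
    *,
    dist_initialized: bool,
    world_group: Any,
    default_sharders: Callable[[], List[Any]],
) -> Tuple[List[Any], Any]:
    """Hypothetical mirror of how None arguments are filled in before planning."""
    if pg is None:
        # Without an explicit group, the default (WORLD) group is required.
        if not dist_initialized:
            raise RuntimeError(
                "default process group not initialized; "
                "call torch.distributed.init_process_group() first"
            )
        pg = world_group
    if sharders is None:
        # Fall back to the injected default sharder factory.
        sharders = default_sharders()
    return sharders, pg


sharders, pg = resolve_collective_plan_defaults(
    None, None,
    dist_initialized=True,
    world_group="WORLD",
    default_sharders=lambda: ["default_sharder"],
)
print(sharders, pg)  # ['default_sharder'] WORLD
```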

Parameters

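Parameter | Data Type | Mandatory/Optional | Description

module | nn.Module | Mandatory | Model whose embedding tables are to be sharded. When NPU devices are used, a module that contains HashEmbeddingBagCollection must be passed.

sharders | List[ModuleSharder[nn.Module]] | Optional | List of sharders. If None, the default TorchRec sharders are used. When NPU devices are used, only the result of get_default_hybrid_sharders() can be passed.

pg | dist.ProcessGroup | Optional | Process group over which the plan computed on rank 0 is broadcast. If None, the default (WORLD) process group is used. When NPU devices are used, dist.GroupMember.WORLD must be passed.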
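Before calling collective_plan on NPU devices, the module and pg constraints above can be checked up front. The following is a hypothetical helper, not part of TorchRec or Rec SDK Torch: check_npu_collective_plan_args is an illustrative name, matching HashEmbeddingBagCollection by class name is an assumption, and the sharders constraint is not checked because get_default_hybrid_sharders() output has no marker this sketch can rely on.

```python
from typing import Any


def check_npu_collective_plan_args(module: Any, pg: Any, world_group: Any) -> None:
    """Hypothetical pre-flight check of the NPU constraints on module and pg."""
    # nn.Module.modules() yields the module itself and every submodule recursively.
    names = {type(m).__name__ for m in module.modules()}
    if "HashEmbeddingBagCollection" not in names:
        raise ValueError("module must contain a HashEmbeddingBagCollection")
    # On NPU devices the WORLD group itself is expected, so compare by identity.
    if pg is not world_group:
        raise ValueError("pg must be dist.GroupMember.WORLD on NPU devices")
```

With a real model, the call would look like check_npu_collective_plan_args(test_model, pg, dist.GroupMember.WORLD), run just before collective_plan.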

Sample

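import torch.distributed as dist
from torchrec.distributed.planner import EmbeddingShardingPlanner
planner = EmbeddingShardingPlanner(XXX)  # replace XXX with the planner arguments for your setup
# test_model contains HashEmbeddingBagCollection; hybrid_sharders is the result of get_default_hybrid_sharders()
plan = planner.collective_plan(test_model, hybrid_sharders, dist.GroupMember.WORLD)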

See Also

For details about the API call sequence and examples, see Porting and Training.