EmbeddingShardingPlanner.collective_plan(TorchRec)

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。
功能描述
搜索最合适的分表计划。
函数原型
1 2 3 4 5 | def collective_plan( module: nn.Module, sharders: Optional[List[ModuleSharder[nn.Module]]] = None, pg: Optional[dist.ProcessGroup] = None, ) |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
module |
nn.Module |
必选 |
包含HashEmbeddingBagCollection的module列表。 |
sharders |
List[ModuleSharder[nn.Module]] |
必选 |
Sharder的列表。 仅支持传入get_default_hybrid_sharders()的结果。 |
pg |
dist.ProcessGroup |
必选 |
传入dist.GroupMember.WORLD。 |
使用示例
1 2 3 | from torchrec.distributed.planner import EmbeddingShardingPlanner planner = EmbeddingShardingPlanner(XXX) plan = planner.collective_plan(test_model, hybrid_sharders, dist.GroupMember.WORLD) |
参考资源
接口调用流程及示例可参见迁移与训练。
父主题: 分表接口