昇腾社区首页
中文
注册

EmbeddingShardingPlanner.collective_plan(TorchRec

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。

功能描述

搜索最合适的分表计划。

函数原型

1
2
3
4
5
def collective_plan(
 module: nn.Module,
 sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
 pg: Optional[dist.ProcessGroup] = None,
)

参数说明

参数名

类型

可选/必选

说明

module

nn.Module

必选

包含HashEmbeddingBagCollection的module列表。

sharders

List[ModuleSharder[nn.Module]]

必选

Sharder的列表。

仅支持传入get_default_hybrid_sharders()的结果。

pg

dist.ProcessGroup

必选

传入dist.GroupMember.WORLD。

使用示例

1
2
3
from torchrec.distributed.planner import EmbeddingShardingPlanner
planner = EmbeddingShardingPlanner(XXX)
plan = planner.collective_plan(test_model, hybrid_sharders, dist.GroupMember.WORLD)

参考资源

接口调用流程及示例可参见迁移与训练