EmbeddingShardingPlanner(TorchRec)
此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。
功能描述
创建分表计划器,用于搜索最合适的分表计划。
函数原型
1 2 | class EmbeddingShardingPlanner: def __init__(**kwargs): |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
|---|---|---|---|
topology |
Topology |
必选 |
参考Topology(TorchRec)的取值范围。 |
constraints |
Dict[str, ParameterConstraints] |
必选 |
参考的取值范围。 |
batch_size |
int |
可选 |
取值范围:[1, 1000000]。 |
enumerator |
torchrec.distributed.planner.types.Enumerator |
可选 |
仅支持默认值为None,不支持用户自定义。 |
storage_reservation |
torchrec.distributed.planner.types.StorageReservation |
可选 |
仅支持默认值为None,不支持用户自定义。 |
proposer |
torchrec.distributed.planner.types.Proposer |
可选 |
仅支持默认值为None,不支持用户自定义。 |
partitioner |
torchrec.distributed.planner.types.Partitioner |
可选 |
仅支持默认值为None,不支持用户自定义。 |
performance_model |
torchrec.distributed.planner.types.PerfModel |
可选 |
仅支持默认值为None,不支持用户自定义。 |
stats |
torchrec.distributed.planner.types.Stats |
可选 |
仅支持默认值为None,不支持用户自定义。 |
debug |
bool |
可选 |
仅支持默认值为True,不支持用户自定义。 |
callbacks |
List[Callable] |
可选 |
仅支持默认值为None,不支持用户自定义。 |
使用示例
1 2 3 4 5 | from torchrec.distributed.planner import EmbeddingShardingPlanner planner = EmbeddingShardingPlanner( topology=topology, constraints=constraints, ) |
参考资源
接口调用流程及示例可参见迁移与训练。