昇腾社区首页
中文
注册
开发者
下载

EmbeddingShardingPlanner(TorchRec

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。

功能描述

创建分表计划器,用于搜索最合适的分表计划。

函数原型

1
2
class EmbeddingShardingPlanner:
    def __init__(**kwargs):

参数说明

参数名

类型

可选/必选

说明

topology

Topology

必选

参考Topology(TorchRec)的取值范围。

constraints

Dict[str, ParameterConstraints]

必选

参考的取值范围。

batch_size

int

可选

取值范围:[1, 1000000]。

enumerator

torchrec.distributed.planner.types.Enumerator

可选

仅支持默认值为None,不支持用户自定义。

storage_reservation

torchrec.distributed.planner.types.StorageReservation

可选

仅支持默认值为None,不支持用户自定义。

proposer

torchrec.distributed.planner.types.Proposer

可选

仅支持默认值为None,不支持用户自定义。

partitioner

torchrec.distributed.planner.types.Partitioner

可选

仅支持默认值为None,不支持用户自定义。

performance_model

torchrec.distributed.planner.types.PerfModel

可选

仅支持默认值为None,不支持用户自定义。

stats

torchrec.distributed.planner.types.Stats

可选

仅支持默认值为None,不支持用户自定义。

debug

bool

可选

仅支持默认值为True,不支持用户自定义。

callbacks

List[Callable]

可选

仅支持默认值为None,不支持用户自定义。

使用示例

1
2
3
4
5
from torchrec.distributed.planner import EmbeddingShardingPlanner
planner = EmbeddingShardingPlanner(
 topology=topology,
 constraints=constraints,
)

参考资源

接口调用流程及示例可参见迁移与训练