昇腾社区首页
中文
注册
开发者
下载

初始化

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。

功能描述

将传入的Module变为分布式的Module,并执行分表计划。

函数原型

1
2
class DistributedModelParallel:
    def __init__(**kwargs):

参数说明

参数名

类型

可选/必选

说明

module

nn.Module

必选

需要并行的模型。包含HashEmbeddingBagCollection的module列表。

device

torch.device

可选

设备。

使用NPU时取值为torch.device("npu"),即npu设备,默认为torch.device("cpu")。

plan

ShardingPlan

可选

分表计划。

用户需保证传入的必须是EmbeddingShardingPlanner.collective_plan返回的结果。

sharders

List[ModuleSharder[nn.Module]]

可选

Sharder的列表。

使用NPU时仅支持传入get_default_hybrid_sharders() 。

env

ShardingEnv

可选

device为torch.device("npu")时仅支持默认值为None,不支持用户自定义。

init_data_parallel

bool

可选

device为torch.device("npu")时仅支持默认值为True,不支持用户自定义。

init_parameters

bool

可选

device为torch.device("npu")时仅支持默认值为True,不支持用户自定义。

data_parallel_wrapper

torchrec.distributed.DataParallelWrapper

可选

device为torch.device("npu")时仅支持默认值为None,不支持用户自定义。

使用示例

1
2
3
4
from torchrec.distributed.model_parallel import DistributedModelParallel
ddp_model = DistributedModelParallel(
 test_model, device=torch.device("npu"), plan=plan, sharders=get_default_hybrid_sharders(host_env=host_env)
)

参考资源

接口调用流程及示例可参见迁移与训练