昇腾社区首页
中文
注册

DistributedModelParallel(TorchRec

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。

功能描述

将传入的Module变为分布式的Module,并执行分表计划。

函数原型

1
2
class DistributedModelParallel:
    def __init__(**kwargs):

参数说明

参数名

类型

可选/必选

说明

module

nn.Module

必选

需要并行的模型。包含HashEmbeddingBagCollection的module列表。

device

torch.device

必选

设备。

取值范围:

torch.device("npu"):npu设备。

plan

ShardingPlan

必选

分表计划。

用户需保证传入的必须是EmbeddingShardingPlanner.collective_plan返回的结果。

sharders

List[ModuleSharder[nn.Module]]

必选

Sharder的列表。

仅支持传入get_default_hybrid_sharders()

env

ShardingEnv

可选

仅支持默认值为None,不支持用户自定义。

init_data_parallel

bool

可选

仅支持默认值为True,不支持用户自定义。

init_parameters

bool

可选

仅支持默认值为True,不支持用户自定义。

data_parallel_wrapper

torchrec.distributed.DataParallelWrapper

可选

仅支持默认值为None,不支持用户自定义。

使用示例

1
2
3
4
from torchrec.distributed.model_parallel import DistributedModelParallel
ddp_model = DistributedModelParallel(
 test_model, device=torch.device("npu"), plan=plan, sharders=get_default_hybrid_sharders(host_env=host_env)
)

参考资源

接口调用流程及示例可参见迁移与训练