DistributedModelParallel(TorchRec)

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。
功能描述
将传入的Module变为分布式的Module,并执行分表计划。
函数原型
1 2 | class DistributedModelParallel: def __init__(**kwargs): |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
module |
nn.Module |
必选 |
需要并行的模型。包含HashEmbeddingBagCollection的module列表。 |
device |
torch.device |
必选 |
设备。 取值范围: torch.device("npu"):npu设备。 |
plan |
ShardingPlan |
必选 |
分表计划。 用户需保证传入的必须是EmbeddingShardingPlanner.collective_plan返回的结果。 |
sharders |
List[ModuleSharder[nn.Module]] |
必选 |
Sharder的列表。 仅支持传入get_default_hybrid_sharders() |
env |
ShardingEnv |
可选 |
仅支持默认值为None,不支持用户自定义。 |
init_data_parallel |
bool |
可选 |
仅支持默认值为True,不支持用户自定义。 |
init_parameters |
bool |
可选 |
仅支持默认值为True,不支持用户自定义。 |
data_parallel_wrapper |
torchrec.distributed.DataParallelWrapper |
可选 |
仅支持默认值为None,不支持用户自定义。 |
使用示例
1 2 3 4 | from torchrec.distributed.model_parallel import DistributedModelParallel ddp_model = DistributedModelParallel( test_model, device=torch.device("npu"), plan=plan, sharders=get_default_hybrid_sharders(host_env=host_env) ) |
参考资源
接口调用流程及示例可参见迁移与训练。
父主题: 分表接口