初始化
此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。
功能描述
将传入的Module变为分布式的Module,并执行分表计划。
函数原型
1 2 | class DistributedModelParallel: def __init__(**kwargs): |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
|---|---|---|---|
module |
nn.Module |
必选 |
需要并行的模型。包含HashEmbeddingBagCollection的module列表。 |
device |
torch.device |
可选 |
设备。 使用NPU时取值为torch.device("npu"),即npu设备,默认为torch.device("cpu")。 |
plan |
ShardingPlan |
可选 |
分表计划。 用户需保证传入的必须是EmbeddingShardingPlanner.collective_plan返回的结果。 |
sharders |
List[ModuleSharder[nn.Module]] |
可选 |
Sharder的列表。 使用NPU时仅支持传入get_default_hybrid_sharders() 。 |
env |
ShardingEnv |
可选 |
device为torch.device("npu")时仅支持默认值为None,不支持用户自定义。 |
init_data_parallel |
bool |
可选 |
device为torch.device("npu")时仅支持默认值为True,不支持用户自定义。 |
init_parameters |
bool |
可选 |
device为torch.device("npu")时仅支持默认值为True,不支持用户自定义。 |
data_parallel_wrapper |
torchrec.distributed.DataParallelWrapper |
可选 |
device为torch.device("npu")时仅支持默认值为None,不支持用户自定义。 |
使用示例
1 2 3 4 | from torchrec.distributed.model_parallel import DistributedModelParallel ddp_model = DistributedModelParallel( test_model, device=torch.device("npu"), plan=plan, sharders=get_default_hybrid_sharders(host_env=host_env) ) |
参考资源
接口调用流程及示例可参见迁移与训练。