Initialization
This API is an open-source TorchRec API, not an external API of Rec SDK Torch. This section describes the parameter values that Rec SDK Torch supports when it calls this TorchRec API.
Function
This API wraps the input module into a distributed module and applies the table sharding plan.
Prototype
```python
class DistributedModelParallel:
    def __init__(self, **kwargs): ...
```
Parameters
| Parameter | Data Type | Mandatory/Optional | Description |
|---|---|---|---|
| module | nn.Module | Mandatory | Model to be parallelized. The model contains one or more HashEmbeddingBagCollection modules. |
| device | torch.device | Optional | Device on which the model runs. When the NPU is used, set it to torch.device("npu"). The default value is torch.device("cpu"). |
| plan | ShardingPlan | Optional | Table sharding plan. The value must be the plan returned by EmbeddingShardingPlanner.collective_plan. |
| sharders | List[ModuleSharder[nn.Module]] | Optional | Sharder list. When the NPU is used, only the list returned by get_default_hybrid_sharders() can be passed. |
| env | ShardingEnv | Optional | Sharding environment. When device is torch.device("npu"), only the default value None is supported; custom values are not allowed. |
| init_data_parallel | bool | Optional | Whether to initialize the data-parallel modules during construction. When device is torch.device("npu"), only the default value True is supported; custom values are not allowed. |
| init_parameters | bool | Optional | Whether to initialize the parameters of modules that are still on the meta device. When device is torch.device("npu"), only the default value True is supported; custom values are not allowed. |
| data_parallel_wrapper | torchrec.distributed.DataParallelWrapper | Optional | Wrapper applied to the data-parallel modules. When device is torch.device("npu"), only the default value None is supported; custom values are not allowed. |
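The NPU restrictions in the parameter table can be checked before DistributedModelParallel is constructed. The following is a minimal sketch, assuming the keyword arguments are first collected in a plain dictionary. The names check_npu_kwargs, is_npu, and NPU_FIXED_DEFAULTS are hypothetical and are not part of TorchRec or Rec SDK Torch. The sketch does not validate sharders or plan, because those rules depend on Rec SDK Torch objects.

```python
from typing import Any, Dict, List

# Hypothetical helper (not part of TorchRec or Rec SDK Torch): it only encodes the
# NPU restrictions listed in the parameter table.
NPU_FIXED_DEFAULTS: Dict[str, Any] = {
    "env": None,
    "init_data_parallel": True,
    "init_parameters": True,
    "data_parallel_wrapper": None,
}


def is_npu(device: Any) -> bool:
    # Accepts a device object or a string such as "npu" or "npu:0".
    return str(device).split(":")[0] == "npu"


def check_npu_kwargs(kwargs: Dict[str, Any]) -> List[str]:
    """Return human-readable violations of the documented NPU restrictions."""
    problems: List[str] = []
    if "module" not in kwargs:
        problems.append("module is mandatory")
    # device defaults to CPU; the fixed-default rules apply only on the NPU.
    if not is_npu(kwargs.get("device", "cpu")):
        return problems
    for name, required in NPU_FIXED_DEFAULTS.items():
        # None and True are singletons, so an identity check is exact.
        if name in kwargs and kwargs[name] is not required:
            problems.append(f"{name} must keep its default value {required!r} on the NPU")
    # The sharders rule (get_default_hybrid_sharders()) and the plan rule
    # (EmbeddingShardingPlanner.collective_plan) are not checked in this sketch.
    return problems


if __name__ == "__main__":
    ok = {"module": object(), "device": "npu", "init_parameters": True}
    bad = {"module": object(), "device": "npu", "init_data_parallel": False}
    print(check_npu_kwargs(ok))  # []
    print(check_npu_kwargs(bad))
```

Passing an explicit default (for example init_parameters=True) is accepted, because the table only forbids values other than the default.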
Sample
```python
import torch
from torchrec.distributed.model_parallel import DistributedModelParallel

# test_model, plan, host_env, and get_default_hybrid_sharders come from earlier setup steps.
ddp_model = DistributedModelParallel(
    test_model, device=torch.device("npu"), plan=plan,
    sharders=get_default_hybrid_sharders(host_env=host_env),
)
```
See Also
For the API call sequence and a complete example, see Porting and Training.