Initialization

This API is an open-source TorchRec API, not an external API of the Rec SDK Torch. This section describes the parameter ranges supported for the TorchRec APIs that are called when the Rec SDK Torch is used.

Function

This API converts the input module into a distributed module and applies the table sharding plan.

Prototype

class DistributedModelParallel:
    def __init__(self, **kwargs):

Parameters

| Parameter | Data Type | Mandatory/Optional | Description |
|---|---|---|---|
| module | nn.Module | Mandatory | Model to be parallelized. List of modules that contain HashEmbeddingBagCollection. |
| device | torch.device | Optional | Device. When the NPU is used, the value is torch.device("npu"), that is, the NPU device. The default value is torch.device("cpu"). |
| plan | ShardingPlan | Optional | Table sharding plan. Ensure that the input is the result returned by EmbeddingShardingPlanner.collective_plan. |
| sharders | List[ModuleSharder[nn.Module]] | Optional | Sharder list. When the NPU is used, only get_default_hybrid_sharders() can be passed. |
| env | ShardingEnv | Optional | When the device is torch.device("npu"), the value can only be the default value None and cannot be customized. |
| init_data_parallel | bool | Optional | When the device is torch.device("npu"), the value can only be the default value True and cannot be customized. |
| init_parameters | bool | Optional | When the device is torch.device("npu"), the value can only be the default value True and cannot be customized. |
| data_parallel_wrapper | torchrec.distributed.DataParallelWrapper | Optional | When the device is torch.device("npu"), the value can only be the default value None and cannot be customized. |
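The NPU restrictions on env, init_data_parallel, init_parameters, and data_parallel_wrapper can be checked before the constructor is called. The following is a minimal sketch of such a pre-check in plain Python. The helper name find_npu_violations and the NPU_FIXED_DEFAULTS mapping are hypothetical and are not part of TorchRec or the Rec SDK Torch; the device is passed as a type string (for example "npu") so that the sketch runs without torch installed.

```python
from typing import Any, Dict, List

# Hypothetical mapping (not a TorchRec or Rec SDK Torch object): parameters
# that, per the parameter descriptions, must keep their default on NPU.
NPU_FIXED_DEFAULTS: Dict[str, Any] = {
    "env": None,
    "init_data_parallel": True,
    "init_parameters": True,
    "data_parallel_wrapper": None,
}


def find_npu_violations(device_type: str, kwargs: Dict[str, Any]) -> List[str]:
    """Return one message per keyword argument that overrides a fixed NPU default.

    device_type is the device type string, e.g. "npu" or "cpu".
    kwargs are the keyword arguments intended for DistributedModelParallel.
    """
    if device_type != "npu":
        # The restrictions only apply when the device is the NPU.
        return []
    violations = []
    for name, default in NPU_FIXED_DEFAULTS.items():
        # Identity comparison is safe here because the defaults are None/True.
        if name in kwargs and kwargs[name] is not default:
            violations.append(
                f"{name} must keep its default value {default!r} on NPU"
            )
    return violations
```

For example, calling find_npu_violations("npu", {"init_parameters": False}) returns a one-element list naming init_parameters, whereas the same arguments with "cpu" return an empty list. The check does not cover plan or sharders, whose constraints depend on objects produced elsewhere in the training script.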

Sample

import torch
from torchrec.distributed.model_parallel import DistributedModelParallel

# test_model, plan, host_env, and get_default_hybrid_sharders come from the surrounding training script.
ddp_model = DistributedModelParallel(
    test_model, device=torch.device("npu"), plan=plan, sharders=get_default_hybrid_sharders(host_env=host_env)
)

See Also

For details about the API call sequence and an example, see Porting and Training.