ShardingEnv (TorchRec)

This is an open-source TorchRec API, not an API exposed by the Rec SDK Torch. This section describes the parameter ranges that the Rec SDK Torch supports when it calls these TorchRec APIs.

Function

Holds the distributed environment parameters (world size, rank, and process group) used for sharding.

Prototype

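class ShardingEnv:
    def __init__(self, world_size: int, rank: int, pg: Optional[dist.ProcessGroup] = None, output_dtensor: bool = False) -> None: ...
    @classmethod
    def from_process_group(cls, pg: dist.ProcessGroup) -> "ShardingEnv": ...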
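The following is a minimal pure-Python sketch of how the constructor and from_process_group relate. It assumes that, as in the TorchRec source we have seen, from_process_group reads the world size and rank from the group (via dist.get_world_size(pg) and dist.get_rank(pg)) and passes the group through as pg. FakeProcessGroup and ShardingEnvSketch are hypothetical names; the sketch imports neither torch nor torchrec and is not the real class.

```python
from typing import Optional


class FakeProcessGroup:
    """Hypothetical stand-in for dist.ProcessGroup; it only reports a size and a rank."""

    def __init__(self, size: int, rank: int) -> None:
        self._size = size
        self._rank = rank

    def size(self) -> int:
        return self._size

    def rank(self) -> int:
        return self._rank


class ShardingEnvSketch:
    """Simplified, hypothetical model of torchrec's ShardingEnv; not the real class."""

    def __init__(
        self,
        world_size: int,
        rank: int,
        pg: Optional[FakeProcessGroup] = None,
        output_dtensor: bool = False,
    ) -> None:
        self.world_size = world_size
        self.rank = rank
        self.process_group = pg
        self.output_dtensor = output_dtensor

    @classmethod
    def from_process_group(cls, pg: FakeProcessGroup) -> "ShardingEnvSketch":
        # world_size and rank come from the group itself rather than from the caller.
        return cls(pg.size(), pg.rank(), pg)


env = ShardingEnvSketch.from_process_group(FakeProcessGroup(size=8, rank=3))
print(env.world_size, env.rank)  # 8 3
```

When a process group already exists, building the environment from the group this way keeps world_size and rank consistent with it by construction.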

Parameters

| Parameter | Type | Mandatory/Optional | Description |
|---|---|---|---|
| world_size | int | Mandatory | Number of devices used. Value range: [1, 8]. |
| rank | int | Mandatory | ID (rank) of the current device. Value range: [0, world_size - 1]. |
| pg | dist.ProcessGroup | Mandatory | Process group used for distributed communication. The backend must be hccl or gloo. NOTICE: In PyTorch, the backend_name of hccl is custom. |
| output_dtensor | bool | Optional | Defaults to False. Only the default value is supported; do not set this parameter. |
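The ranges in the parameter table can be checked before a ShardingEnv is constructed. The sketch below is a hypothetical helper, check_sharding_env_args, that is not part of TorchRec or the Rec SDK Torch. It assumes backend is the name the caller used to create the group (for example, the backend= argument to dist.new_group), not the value PyTorch reports, which per the NOTICE is custom for hccl.

```python
# Hypothetical pre-check for the ShardingEnv ranges supported by the Rec SDK Torch.
SUPPORTED_BACKENDS = {"hccl", "gloo"}
MAX_WORLD_SIZE = 8


def check_sharding_env_args(
    world_size: int, rank: int, backend: str, output_dtensor: bool = False
) -> None:
    """Raise ValueError on the first constraint from the parameter table that is violated."""
    if not 1 <= world_size <= MAX_WORLD_SIZE:
        raise ValueError(f"world_size must be in [1, {MAX_WORLD_SIZE}], got {world_size}")
    if not 0 <= rank <= world_size - 1:
        raise ValueError(f"rank must be in [0, {world_size - 1}], got {rank}")
    if backend not in SUPPORTED_BACKENDS:
        raise ValueError(f"pg backend must be hccl or gloo, got {backend!r}")
    if output_dtensor:
        # Only the default value False is supported.
        raise ValueError("output_dtensor must keep its default value False")
```

Calling check_sharding_env_args(world_size, rank, "gloo") just before creating the environment surfaces an out-of-range value early, with an explicit message.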

Example

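import torch.distributed as dist
from torchrec.distributed.types import ShardingEnv
# Assumes dist.init_process_group() has already been called in this process.
world_size = dist.get_world_size()
rank = dist.get_rank()
# Create a gloo-backed group over all ranks and wrap it in a ShardingEnv.
host_gp = dist.new_group(backend="gloo")
host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp)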

See Also

For details about the API call sequence and an example, see Porting and Training.