ShardingEnv (TorchRec)
This API is an open-source TorchRec API, not an external API of Rec SDK Torch. This section describes the parameter value ranges that are supported when Rec SDK Torch calls this TorchRec API.
Function
Stores the distributed environment parameters (world size, rank, and process group) used for sharding.
Prototype
```python
import torch.distributed as dist

class ShardingEnv:
    def __init__(self, **kwargs): ...
    @classmethod
    def from_process_group(cls, pg: dist.ProcessGroup) -> "ShardingEnv": ...
```
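The prototype offers two construction paths: passing the parameters directly, or deriving them from an existing process group. The following pure-Python stand-in sketches how the two relate without importing TorchRec. `FakeProcessGroup` and `ShardingEnvSketch` are hypothetical names, and the assumption that `from_process_group` reads the world size and rank from the group itself should be checked against the installed TorchRec version.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeProcessGroup:
    """Hypothetical stand-in for dist.ProcessGroup exposing size() and rank()."""
    world: int
    my_rank: int

    def size(self) -> int:
        return self.world

    def rank(self) -> int:
        return self.my_rank


class ShardingEnvSketch:
    """Hypothetical mirror of ShardingEnv: it only stores the distributed parameters."""

    def __init__(self, world_size: int, rank: int, pg: Any = None,
                 output_dtensor: bool = False) -> None:
        self.world_size = world_size
        self.rank = rank
        self.process_group = pg
        self.output_dtensor = output_dtensor

    @classmethod
    def from_process_group(cls, pg: FakeProcessGroup) -> "ShardingEnvSketch":
        # Assumption: world size and rank are taken from the group itself.
        return cls(world_size=pg.size(), rank=pg.rank(), pg=pg)


env = ShardingEnvSketch.from_process_group(FakeProcessGroup(world=8, my_rank=3))
print(env.world_size, env.rank)  # 8 3
```

Deriving both values from the group keeps `world_size` and `rank` consistent with the communication group that is actually used.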
Parameters
| Parameter | Type | Mandatory/Optional | Description |
|---|---|---|---|
| world_size | int | Mandatory | Number of devices in use. Value range: [1, 8] |
| rank | int | Mandatory | ID of the current device. Value range: [0, world_size - 1] |
| pg | dist.ProcessGroup | Mandatory | Process group used for distributed communication. Value range: the backend must be hccl or gloo. NOTICE: In PyTorch, the backend_name of an hccl process group is custom. |
| output_dtensor | bool | Optional | Default value: False. User-defined values are not supported. |
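The value ranges in the parameter table can be checked before constructing the object. The following is a minimal sketch; `check_sharding_env_args` is a hypothetical helper, not part of TorchRec or Rec SDK Torch. It takes the backend name that the caller passed when creating the group, because, per the notice above, an hccl group's backend_name is reported as custom by PyTorch.

```python
VALID_BACKENDS = {"hccl", "gloo"}


def check_sharding_env_args(world_size: int, rank: int, backend: str,
                            output_dtensor: bool = False) -> None:
    """Hypothetical helper: raise ValueError if an argument falls outside
    the documented ranges for ShardingEnv under Rec SDK Torch."""
    if not 1 <= world_size <= 8:
        raise ValueError(f"world_size must be in [1, 8], got {world_size}")
    if not 0 <= rank <= world_size - 1:
        raise ValueError(f"rank must be in [0, {world_size - 1}], got {rank}")
    if backend not in VALID_BACKENDS:
        raise ValueError(f"backend must be hccl or gloo, got {backend!r}")
    if output_dtensor:
        # Only the default value False is supported.
        raise ValueError("output_dtensor must keep its default value False")


check_sharding_env_args(world_size=8, rank=7, backend="hccl")  # passes silently
```

Failing early with a readable message is usually easier to debug than an error raised later from inside a collective operation.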
Example
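```python
import torch.distributed as dist
from torchrec.distributed.types import ShardingEnv

# Assumes dist.init_process_group() has already been called on every device.
world_size = dist.get_world_size()
rank = dist.get_rank()
host_gp = dist.new_group(backend="gloo")  # gloo group for host (CPU) side communication
host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp)
```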
See Also
For details about the API call sequence and examples, see Porting and Training.