昇腾社区首页
中文
注册

ShardingEnv(TorchRec

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。

功能描述

保存分布式相关参数。

函数原型

1
2
class ShardingEnv:
    def __init__(**kwargs):

参数说明

参数名

类型

可选/必选

说明

world_size

int

必选

使用的卡数。取值范围:[1,8]

rank

int

必选

当前的卡号。取值范围:[0,world_size -1]

pg

dist.ProcessGroup

必选

分布式通讯链接。取值范围:只支持backend为hccl和gloo的链接。

须知:

“hccl”在PyTorch里面的backend_name为custom。

output_dtensor

bool

可选

仅支持默认值为False,不支持用户自定义。

使用示例

1
2
3
4
import torch.distributed as dist
from torchrec.distributed.types import ShardingEnv
host_gp = dist.new_group(backend="gloo")
host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp)

参考资源

接口调用流程及示例可参见迁移与训练