昇腾社区首页
中文
注册

get_default_hybrid_sharders

功能描述

获取分表器。

函数原型

1
def get_default_hybrid_sharders(host_env: ShardingEnv) -> List[ModuleSharder[nn.Module]]:

参数说明

参数名

类型

可选/必选

说明

host_env

ShardingEnv

必选

传入host连接需要的通讯域。参考3的创建方法。

返回值

成功:返回默认的分表器。

失败:抛出异常

使用示例

from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders
from torchrec.distributed.types import ShardingEnv
import torch.distributed as dist

host_gp = dist.new_group(backend="gloo")
world_size = dist.get_world_size()
rank = dist.get_rank()
host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp)
hybrid_sharder = get_default_hybrid_sharders(host_env=host_env)