get_default_hybrid_sharders
功能描述
获取分表器。
函数原型
1 | def get_default_hybrid_sharders(host_env: ShardingEnv) -> List[ModuleSharder[nn.Module]]: |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
host_env |
ShardingEnv |
必选 |
传入host连接需要的通讯域。参考3的创建方法。 |
返回值
成功:返回默认的分表器。
失败:抛出异常
使用示例
from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders from torchrec.distributed.types import ShardingEnv import torch.distributed as dist host_gp = dist.new_group(backend="gloo") world_size = dist.get_world_size() rank = dist.get_rank() host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp) hybrid_sharder = get_default_hybrid_sharders(host_env=host_env)
父主题: 分表接口