
Training Scenario Overview

Building a Network with Rec SDK Torch

You can follow the steps in the Quick Start to build a model and train it.

Migrating from Open-Source TorchRec

If you have already built a network on TorchRec, replace its interfaces with the corresponding Rec SDK Torch interfaces according to the mapping in Table 1.

Table 1 Interface mapping

| TorchRec interface | Rec SDK Torch interface | Function |
| --- | --- | --- |
| EmbeddingBagConfig | HashEmbeddingBagConfig | Sparse table configuration |
| EmbeddingBagCollection | HashEmbeddingBagCollection | Creates the sparse tables |
| get_default_sharders | get_default_hybrid_sharders | Returns the sharders |
| TrainPipelineSparseDist | HybridTrainPipelineSparseDist | Training pipeline with overlapped sparse-table lookup |
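
Neither example below shows the table-configuration step from the first row of the mapping, so here is a minimal sketch. It assumes HashEmbeddingBagConfig accepts the same fields as TorchRec's EmbeddingBagConfig (name, embedding_dim, num_embeddings, feature_names); check the Rec SDK Torch API reference for the exact signature.

    # Minimal sketch of the config replacement, assuming HashEmbeddingBagConfig
    # mirrors TorchRec's EmbeddingBagConfig fields (an assumption, not the
    # documented signature).
    # TorchRec:
    #   table = EmbeddingBagConfig(name="t1", embedding_dim=16,
    #                              num_embeddings=1000, feature_names=["f1"])
    # Rec SDK Torch:
    table = HashEmbeddingBagConfig(
        name="t1",              # sparse table name
        embedding_dim=16,       # width of each embedding vector
        num_embeddings=1000,    # number of rows in the table
        feature_names=["f1"],   # features that look up this table
    )
    ebc = HashEmbeddingBagCollection(tables=[table])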

Interface examples:

  • TorchRec example:

    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader
    from torchrec import EmbeddingBagCollection
    from torchrec import TrainPipelineSparseDist
    from torchrec import EmbeddingBagCollectionSharder
    from torchrec.datasets.random import RandomRecDataset
    from torchrec.datasets.utils import Batch
    from torchrec.distributed.model_parallel import get_default_sharders

    class TestModel(torch.nn.Module):
        def __init__(self, *args):
            super().__init__()
            # Rec SDK Torch uses HashEmbeddingBagCollection here instead
            self.ebc = EmbeddingBagCollection()  # table configs elided in this sketch

        def forward(self, batch: Batch):
            pass

    def invoke_main():
        # helper assumed to return this process's (rank, world_size)
        rank, world_size = get_distribute_env()
        device = torch.device("npu")
        dist.init_process_group(backend="hccl")

        dataset = RandomRecDataset(BATCH_SIZE, BATCH_NUM, FEAT_NAMES, ID_RANGES)
        data_loader = DataLoader(
            dataset,
        )
        test_model = TestModel(TABLE_NAMES, FEAT_NAMES, EMBED_DIMS, NUM_EMBEDS)
        ...
        # sharder and model-parallel wrapping (see the sketch after the examples)
        ...
        # Rec SDK Torch uses get_default_hybrid_sharders here instead
        sharders = get_default_sharders()
        ...
        # optimizer creation (see the sketch after the examples)
        ...
        # Rec SDK Torch uses HybridTrainPipelineSparseDist here instead
        pipeline = TrainPipelineSparseDist()  # model, optimizer, device args elided
        batched_iterator = iter(data_loader)
        for i in range(20):
            pipeline.progress(batched_iterator)
    
  • Rec SDK Torch example:

    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader
    from hybrid_torchrec import HybridTrainPipelineSparseDist
    from hybrid_torchrec import HybridEmbeddingBagCollectionSharder
    # import path for HashEmbeddingBagCollection assumed to match the
    # hybrid_torchrec exports above
    from hybrid_torchrec import HashEmbeddingBagCollection
    from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders
    from torchrec.datasets.random import RandomRecDataset
    from torchrec.datasets.utils import Batch
    from torchrec.distributed.types import ShardingEnv

    class TestModel(torch.nn.Module):
        def __init__(self, *args):
            super().__init__()
            # the corresponding TorchRec interface is EmbeddingBagCollection
            self.ebc = HashEmbeddingBagCollection()  # table configs elided in this sketch

        def forward(self, batch: Batch):
            pass

    def invoke_main():
        # helper assumed to return this process's (rank, world_size)
        rank, world_size = get_distribute_env()
        device = torch.device("npu")
        dist.init_process_group(backend="hccl")
        # Rec SDK Torch additionally creates a host-side (CPU) process group
        host_gp = dist.new_group(backend="gloo")
        host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp)
        dataset = RandomRecDataset(BATCH_SIZE, BATCH_NUM, FEAT_NAMES, ID_RANGES)
        data_loader = DataLoader(
            dataset,
        )
        test_model = TestModel(TABLE_NAMES, FEAT_NAMES, EMBED_DIMS, NUM_EMBEDS)
        ...
        # sharder and model-parallel wrapping (see the sketch after the examples)
        ...
        # the corresponding TorchRec interface is get_default_sharders
        hybrid_sharders = get_default_hybrid_sharders(host_env=host_env)
        ...
        # optimizer creation (see the sketch after the examples)
        ...
        # the corresponding TorchRec interface is TrainPipelineSparseDist
        pipeline = HybridTrainPipelineSparseDist()  # model, optimizer, device args elided
        batched_iterator = iter(data_loader)
        for i in range(20):
            pipeline.progress(batched_iterator)
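
Both examples elide the sharder/model-parallel wrapping and the optimizer creation. For reference, the sketch below shows how those steps typically look in open-source TorchRec; DistributedModelParallel, the optimizer wiring, and the pipeline arguments come from TorchRec's API, and the Rec SDK Torch equivalents may differ, so treat this as an illustration rather than the SDK's exact usage.

    # A sketch of the elided steps, using open-source TorchRec APIs.
    import torch
    from torchrec.distributed.model_parallel import DistributedModelParallel

    # Shard the model across devices using the sharders obtained earlier
    # (get_default_sharders for TorchRec, get_default_hybrid_sharders for
    # Rec SDK Torch).
    model = DistributedModelParallel(
        module=test_model,
        device=device,
        sharders=sharders,
    )
    # A plain dense optimizer over the wrapped model's parameters; sparse
    # tables are usually updated by fused optimizers configured on the sharder.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # The pipeline takes the sharded model, the optimizer, and the device.
    pipeline = TrainPipelineSparseDist(model, optimizer, device)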