Training Scenarios

Setting Up a Network Based on Rec SDK Torch

You can set up a model and train it by following the steps in Quick Start.

Porting Based on Open-Source TorchRec

If you have already built a network with TorchRec, replace each TorchRec API with its Rec SDK Torch counterpart listed in Table 1.

Table 1 API mapping

TorchRec API              Rec SDK Torch API                Function Description
EmbeddingBagConfig        HashEmbeddingBagConfig           Sparse table configuration.
EmbeddingBagCollection    HashEmbeddingBagCollection       Create a sparse table.
get_default_sharders      get_default_hybrid_sharders      Obtain a table splitter.
TrainPipelineSparseDist   HybridTrainPipelineSparseDist    Query a sparse table.
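
Because Table 1 maps names one-to-one, a first pass of the port can be done mechanically. The following is a minimal sketch of such a pass, not a tool shipped with Rec SDK Torch: the names API_MAPPING and port_identifiers are hypothetical, and the sketch only renames identifiers. Import paths and any additional arguments required by the Rec SDK Torch APIs still have to be adjusted by hand.

```python
import re

# Identifier mapping copied from Table 1 (TorchRec name -> Rec SDK Torch name).
API_MAPPING = {
    "EmbeddingBagConfig": "HashEmbeddingBagConfig",
    "EmbeddingBagCollection": "HashEmbeddingBagCollection",
    "get_default_sharders": "get_default_hybrid_sharders",
    "TrainPipelineSparseDist": "HybridTrainPipelineSparseDist",
}

# Match whole identifiers only, so that longer names such as
# EmbeddingBagCollectionSharder or already-ported names such as
# HashEmbeddingBagCollection are left untouched.
_PATTERN = re.compile(r"\b(" + "|".join(map(re.escape, API_MAPPING)) + r")\b")


def port_identifiers(source: str) -> str:
    """Return source with every TorchRec identifier from Table 1 renamed."""
    return _PATTERN.sub(lambda match: API_MAPPING[match.group(1)], source)


if __name__ == "__main__":
    print(port_identifiers("pipeline = TrainPipelineSparseDist()"))
    # prints: pipeline = HybridTrainPipelineSparseDist()
```

Matching on word boundaries rather than plain substrings makes the helper safe to run more than once on the same file. Import statements are renamed too, so their module paths need a manual check afterwards.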

API examples:

  • TorchRec example:
    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist
    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
    from torchrec.distributed.model_parallel import get_default_sharders

    class TestModel(torch.nn.Module):
        def __init__(self, *args):
            super().__init__()
            # Rec SDK Torch equivalent: HashEmbeddingBagCollection.
            self.ebc = EmbeddingBagCollection()

        def forward(self, batch: Batch):
            pass

    def invoke_main():
        rank, world_size = get_distribute_env()
        device = torch.device("npu")
        dist.init_process_group(backend="hccl")

        dataset = RandomRecDataset(BATCH_SIZE, BATCH_NUM, FEAT_NAMES, ID_RANGES)
        data_loader = DataLoader(
            dataset,
        )
        test_model = TestModel(TABLE_NAMES, FEAT_NAMES, EMBED_DIMS, NUM_EMBEDS)
        ...
        # Create a sharder.
        ...
        # Rec SDK Torch equivalent: get_default_hybrid_sharders.
        hybrid_sharder = get_default_sharders()
        ...
        # Create an optimizer.
        ...
        # Rec SDK Torch equivalent: HybridTrainPipelineSparseDist.
        pipeline = TrainPipelineSparseDist()
        for i in range(20):
            pipeline.progress(batched_iterator)
    
  • Rec SDK Torch example:
    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader
    from torchrec.distributed.types import ShardingEnv
    from hybrid_torchrec import HashEmbeddingBagCollection
    from hybrid_torchrec.distributed.hybrid_train_pipeline import HybridTrainPipelineSparseDist
    from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders
    ...

    class TestModel(torch.nn.Module):
        def __init__(self, *args):
            super().__init__()
            # Native TorchRec equivalent: EmbeddingBagCollection.
            self.ebc = HashEmbeddingBagCollection()

        def forward(self, batch: Batch):
            pass

    def invoke_main():
        rank, world_size = get_distribute_env()
        device = torch.device("npu")
        dist.init_process_group(backend="hccl")
        # Create a host-side (gloo) process group for Rec SDK Torch.
        host_gp = dist.new_group(backend="gloo")
        host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp)
        dataset = RandomRecDataset(BATCH_SIZE, BATCH_NUM, FEAT_NAMES, ID_RANGES)
        data_loader = DataLoader(
            dataset,
        )
        test_model = TestModel(TABLE_NAMES, FEAT_NAMES, EMBED_DIMS, NUM_EMBEDS)
        ...
        # Create a sharder.
        ...
        # Native TorchRec equivalent: get_default_sharders.
        hybrid_sharder = get_default_hybrid_sharders(host_env=host_env)
        ...
        # Create an optimizer.
        ...
        # Native TorchRec equivalent: TrainPipelineSparseDist.
        pipeline = HybridTrainPipelineSparseDist()
        for i in range(20):
            pipeline.progress(batched_iterator)