API Call Overview

Figure 1 API calling process

The following steps omit implementation details. For the complete code, see the Rec SDK Torch Little Demo. The key steps are as follows:

  1. Define a batch.
    Collect all features required for training into a batch class that implements the to(), pin_memory(), and record_stream() methods, so that the training pipeline can transfer batches to the device.
    from dataclasses import dataclass

    @dataclass
    class Batch(Pipelineable):
        ......
  2. Create a dataset.
    Implement a dataset that yields instances of the Batch class defined in step 1.
    class RandomRecDataset(IterableDataset[Batch]):
        ......
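  3. Initialize the distributed environment.
    Create an HCCL process group for communication between NPUs and a Gloo process group for host (CPU) communication, then wrap the Gloo group in a ShardingEnv.
    import torch
    import torch_npu  # registers the "npu" device and the HCCL backend
    import torch.distributed as dist
    ......
    device = torch.device("npu")
    dist.init_process_group(backend="hccl")
    host_gp = dist.new_group(backend="gloo")
    host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp)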
  4. Define a model.
    Combine the sparse table (embedding) layer and the dense layers into one module. The module's forward() takes the Batch defined in step 1 as input and returns the loss and the model output.
    class TestModel(torch.nn.Module):
        def __init__(self, ......):
            super().__init__()
            table_configs = ......
            self.ebc = HashEmbeddingBagCollection(device="npu", tables=table_configs)

        def forward(self, batch: Batch):
            ......
            return loss, result
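  5. Define an optimizer for the sparse tables.
    Use apply_optimizer_in_backward to attach an Adagrad optimizer to the embedding table parameters, so that these parameters are updated during the backward pass.
    test_model = TestModel(...)

    # Optimizer for the sparse tables
    embedding_optimizer = torch.optim.Adagrad
    optimizer_kwargs = {"lr": 0.001, "eps": 0.1}
    apply_optimizer_in_backward(
        embedding_optimizer,
        test_model.ebc.parameters(),
        optimizer_kwargs=optimizer_kwargs,
    )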
  6. Shard the sparse tables.
    Create a sharder, use EmbeddingShardingPlanner to generate a sharding plan, and pass the plan and the sharder to DistributedModelParallel to obtain the distributed model.
    import logging

    hybrid_sharder = get_default_hybrid_sharders(host_env=host_env)
    constraints = {......}
    planner = EmbeddingShardingPlanner(......)
    # collective_plan() ensures that every rank uses the same plan
    plan = planner.collective_plan(test_model, hybrid_sharder, dist.GroupMember.WORLD)
    logging.info(plan)
    ddp_model = DistributedModelParallel(
        test_model, device=torch.device("npu"), plan=plan, sharders=hybrid_sharder
    )
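  7. Combine the optimizers.
    Use in_backward_optimizer_filter to select the dense parameters (those without an in-backward optimizer from step 5), wrap them in a separate Adagrad optimizer, and combine it with the fused sparse optimizer of the distributed model into a single CombinedOptimizer.
    # Dense parameters only: the sparse ones are already handled by the fused optimizer
    dense_optimizer = KeyedOptimizerWrapper(
        dict(in_backward_optimizer_filter(ddp_model.named_parameters())),
        lambda params: torch.optim.Adagrad(params, lr=0.1),
    )
    optimizer = CombinedOptimizer([ddp_model.fused_optimizer, dense_optimizer])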
  8. Create a pipeline.
    Wrap the distributed model, the combined optimizer, and the device in a training pipeline.
    pipeline = HybridTrainPipelineSparseDist(
        ddp_model, optimizer, device, execute_all_batches=True
    )
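  9. Use the pipeline for training.
    Each call to progress() takes the next batch from the iterator and advances training by one step.
    # data_loader is built from the dataset created in step 2 (construction omitted)
    batched_iterator = iter(data_loader)
    for i in range(...):
        pipeline.progress(batched_iterator)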
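Step 2 leaves the dataset itself unspecified. As a rough, stdlib-only illustration of what a random recommendation dataset can yield, the sketch below generates batches whose sparse features use a jagged layout (all IDs of a feature concatenated, plus a per-sample length list), similar in spirit to torchrec's KeyedJaggedTensor. ToyBatch, random_batches, and all field names are hypothetical; they use plain lists instead of tensors, so the to(), pin_memory(), and record_stream() methods from step 1 are not modeled.

```python
import random
from dataclasses import dataclass
from typing import Dict, Iterator, List


@dataclass
class ToyBatch:
    # Hypothetical stand-in for the Batch class of step 1 (plain lists, no tensors).
    dense: List[List[float]]              # [batch_size][num_dense] dense features
    sparse_values: Dict[str, List[int]]   # per feature: all IDs in the batch, concatenated
    sparse_lengths: Dict[str, List[int]]  # per feature: number of IDs for each sample
    labels: List[float]


def random_batches(num_batches: int, batch_size: int, num_dense: int,
                   feature_names: List[str], max_ids: int, id_range: int,
                   seed: int = 0) -> Iterator[ToyBatch]:
    rng = random.Random(seed)
    for _ in range(num_batches):
        values: Dict[str, List[int]] = {}
        lengths: Dict[str, List[int]] = {}
        for name in feature_names:
            # Each sample gets a variable number of IDs (0..max_ids), i.e. a jagged feature.
            lens = [rng.randint(0, max_ids) for _ in range(batch_size)]
            lengths[name] = lens
            values[name] = [rng.randrange(id_range) for _ in range(sum(lens))]
        yield ToyBatch(
            dense=[[rng.random() for _ in range(num_dense)] for _ in range(batch_size)],
            sparse_values=values,
            sparse_lengths=lengths,
            labels=[float(rng.randint(0, 1)) for _ in range(batch_size)],
        )
```

For example, `next(random_batches(1, 4, 13, ["user_id", "item_id"], 3, 10_000))` yields one batch of four samples. A torch IterableDataset such as RandomRecDataset would additionally have to convert these lists to tensors and wrap them in the Batch class.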
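The code in step 3 omits where world_size and rank come from. One common option, assumed here rather than taken from the demo, is to read the environment variables that launchers such as torchrun export for each process (RANK, WORLD_SIZE, LOCAL_RANK); dist.init_process_group() with its default env:// initialization reads the same variables. read_dist_env is a hypothetical helper.

```python
import os
from typing import Mapping, Tuple


def read_dist_env(env: Mapping[str, str] = os.environ) -> Tuple[int, int, int]:
    """Return (rank, world_size, local_rank); the defaults describe a single-process run."""
    rank = int(env.get("RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    # LOCAL_RANK identifies the process (and usually the device) on this host;
    # fall back to the global rank when it is not set.
    local_rank = int(env.get("LOCAL_RANK", str(rank)))
    return rank, world_size, local_rank
```

Calling `rank, world_size, local_rank = read_dist_env()` before building the ShardingEnv keeps the values consistent with what the launcher passed to each process.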
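Step 4 places the embedding tables in a HashEmbeddingBagCollection. The core computation of an embedding bag is a per-sample lookup followed by pooling. The stdlib-only sketch below shows sum pooling over a jagged input (concatenated IDs plus per-sample lengths), with the table stored as a dict keyed by raw ID to hint at hash-based lookup. embedding_bag_sum is a hypothetical helper, and mapping unknown IDs to a zero vector is a simplification made here, not a statement about how HashEmbeddingBagCollection handles new IDs.

```python
from typing import Dict, List


def embedding_bag_sum(table: Dict[int, List[float]], dim: int,
                      values: List[int], lengths: List[int]) -> List[List[float]]:
    """Sum-pool the embedding rows of each sample's IDs (unknown IDs contribute zeros here)."""
    out: List[List[float]] = []
    offset = 0
    for n in lengths:
        pooled = [0.0] * dim
        # The IDs of this sample are the next n entries of the concatenated values list.
        for id_ in values[offset:offset + n]:
            row = table.get(id_, [0.0] * dim)
            for j in range(dim):
                pooled[j] += row[j]
        out.append(pooled)
        offset += n
    return out
```

For example, with rows for IDs 7 and 42, `embedding_bag_sum(table, 2, [7, 42, 42], [1, 0, 2])` returns one pooled vector per sample; the sample with length 0 gets an all-zero vector.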
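Step 5 configures Adagrad with lr=0.001 and eps=0.1 for the embedding tables. The stdlib-only sketch below applies the update rule documented for torch.optim.Adagrad (with the default lr_decay=0, weight_decay=0, and initial_accumulator_value=0) to a toy embedding table, updating only the rows whose IDs received a gradient. It illustrates the arithmetic only, not how the fused in-backward optimizer is implemented; adagrad_step and sparse_adagrad are hypothetical names.

```python
import math
from typing import Dict, List


def adagrad_step(param: List[float], grad: List[float], state_sum: List[float],
                 lr: float = 0.001, eps: float = 0.1) -> None:
    """One Adagrad step on a single embedding row, updated in place."""
    for i, g in enumerate(grad):
        state_sum[i] += g * g                                 # accumulate squared gradients
        param[i] -= lr * g / (math.sqrt(state_sum[i]) + eps)  # scaled update


def sparse_adagrad(table: List[List[float]], state: List[List[float]],
                   row_grads: Dict[int, List[float]],
                   lr: float = 0.001, eps: float = 0.1) -> None:
    # Only rows whose IDs appeared in the batch carry a gradient; other rows stay unchanged.
    for row_id, grad in row_grads.items():
        adagrad_step(table[row_id], grad, state[row_id], lr, eps)
```

With a table of ones and a gradient of 2.0 on one weight, the first step gives 1 - 0.001 * 2 / (2 + 0.1), roughly 0.999048. Because the first update for a gradient g is lr * |g| / (|g| + eps), a relatively large eps such as 0.1 damps the steps of IDs whose gradients are small.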
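Step 6 delegates the placement decision to EmbeddingShardingPlanner. To make the idea of a sharding plan concrete, the toy sketch below assigns each whole table to one rank (table-wise sharding), largest table first, always onto the currently least-loaded rank. This greedy heuristic is an illustrative stand-in, not the planner's actual algorithm, which typically also weighs other sharding types, the constraints passed to it, and estimated cost; greedy_table_wise_plan is a hypothetical name, and 4 bytes per element assumes fp32 weights.

```python
from typing import Dict, Tuple


def greedy_table_wise_plan(tables: Dict[str, Tuple[int, int]], world_size: int,
                           bytes_per_elem: int = 4) -> Dict[str, int]:
    """Map each table name to a rank; tables are given as name -> (num_rows, embedding_dim)."""
    load = [0] * world_size  # bytes already placed on each rank
    plan: Dict[str, int] = {}
    # Largest tables first; ties broken by name so the result is deterministic.
    for name, (rows, dim) in sorted(tables.items(),
                                    key=lambda kv: (-kv[1][0] * kv[1][1], kv[0])):
        rank = min(range(world_size), key=lambda r: (load[r], r))  # least-loaded rank
        plan[name] = rank
        load[rank] += rows * dim * bytes_per_elem
    return plan
```

For tables of 1000, 500, 400, and 100 rows (all with dimension 16) on two ranks, the largest table goes to rank 0 and the other three to rank 1, leaving 64,000 bytes on each rank.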
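Step 7 relies on in_backward_optimizer_filter to keep the sparse parameters out of the dense optimizer, so that no parameter is updated twice. The stdlib-only sketch below imitates that split with a hypothetical marker attribute, _in_backward_optimizers (modeled on the attribute torchrec's apply_optimizer_in_backward sets, but treat the name as an assumption), and a toy combined optimizer that steps each wrapped optimizer. ToyParam, ToySGD, ToyCombined, and dense_only are hypothetical stand-ins, not the torchrec classes.

```python
from typing import Iterable, Iterator, List, Tuple


class ToyParam:
    def __init__(self, value: float, grad: float, in_backward: bool = False):
        self.value = value
        self.grad = grad
        if in_backward:
            # Hypothetical marker: this parameter is already updated during backward.
            self._in_backward_optimizers = ["adagrad"]


def dense_only(named: Iterable[Tuple[str, ToyParam]]) -> Iterator[Tuple[str, ToyParam]]:
    # Yield only the parameters that carry no in-backward marker.
    for name, p in named:
        if not hasattr(p, "_in_backward_optimizers"):
            yield name, p


class ToySGD:
    def __init__(self, params: List[ToyParam], lr: float):
        self.params, self.lr = params, lr

    def step(self) -> None:
        for p in self.params:
            p.value -= self.lr * p.grad


class ToyCombined:
    # Steps every wrapped optimizer; each one owns a disjoint set of parameters.
    def __init__(self, optimizers: List[ToySGD]):
        self.optimizers = optimizers

    def step(self) -> None:
        for opt in self.optimizers:
            opt.step()
```

Building the dense optimizer from `dict(dense_only(named_params))` leaves any marked embedding parameter untouched by `ToyCombined.step()`, which is the property the filter in step 7 is there to guarantee.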
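Steps 8 and 9 drive training through progress(). The toy sketch below models the general idea behind such pipelines, assumed here to resemble torchrec's TrainPipelineSparseDist: each progress() call first tops up a small queue of in-flight batches (where a real pipeline would start the device copy and sparse communication for upcoming batches) and then trains on the oldest one. Once the iterator is exhausted, the queued batches are still processed, mirroring execute_all_batches=True, and StopIteration is raised only when nothing is left. ToyPrefetchPipeline is hypothetical and runs everything sequentially.

```python
from collections import deque
from typing import Callable, Deque, Generic, Iterator, TypeVar

B = TypeVar("B")  # batch type
R = TypeVar("R")  # result of one training step


class ToyPrefetchPipeline(Generic[B, R]):
    def __init__(self, train_step: Callable[[B], R], depth: int = 2):
        self.train_step = train_step
        self.depth = depth  # number of batches kept in flight
        self.in_flight: Deque[B] = deque()

    def progress(self, batches: Iterator[B]) -> R:
        # Top up the queue; a real pipeline would start transferring these batches here.
        while len(self.in_flight) < self.depth:
            try:
                self.in_flight.append(next(batches))
            except StopIteration:
                break  # data exhausted: keep draining what is already queued
        if not self.in_flight:
            raise StopIteration  # every batch has been trained on
        return self.train_step(self.in_flight.popleft())


pipeline = ToyPrefetchPipeline(lambda batch: batch * 10, depth=2)
batched_iterator = iter(range(5))
results = []
while True:
    try:
        results.append(pipeline.progress(batched_iterator))
    except StopIteration:
        break
```

With a finite data loader, catching StopIteration as above is one way to end an epoch without knowing the batch count in advance, instead of the fixed range(...) used in step 9.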