API Call Overview
Figure 1 API calling process
The following steps omit implementation details. For the complete code, see Rec SDK Torch Little Demo. The key steps are as follows:
1. Define a batch. Collect all the features required for training into a single batch class, and implement its to(), pin_memory(), and record_stream() methods.

    @dataclass
    class Batch(Pipelineable):
        ......

2. Create a dataset. Implement a dataset that yields the batch type defined in step 1.

    class RandomRecDataset(IterableDataset[Batch]):
        ......

3. Initialize the distributed environment. Set the device to NPU, initialize the default process group with the HCCL backend, and create a Gloo process group for host-side communication, which is wrapped in a ShardingEnv.

    ......
    device = torch.device("npu")
    dist.init_process_group(backend="hccl")
    host_gp = dist.new_group(backend="gloo")
    host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp)

4. Define a model. Combine the sparse table layer and the dense layers into one module. The input of the module must be the batch class defined in step 1, and the module returns the loss and the model output.

    class TestModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            table_configs = ......
            # Sparse table layer placed on the NPU
            self.ebc = HashEmbeddingBagCollection(device="npu", tables=table_configs)

        def forward(self, batch: Batch):
            ......  # compute the loss and the output from the batch
            return loss, result

5. Define an optimizer for the sparse tables. apply_optimizer_in_backward attaches the optimizer to the embedding parameters so that they are updated during the backward pass.

    test_model = TestModel(...)
    # Optimizer for the sparse (embedding) parameters
    embedding_optimizer = torch.optim.Adagrad
    optimizer_kwargs = {"lr": 0.001, "eps": 0.1}
    apply_optimizer_in_backward(
        embedding_optimizer,
        test_model.ebc.parameters(),
        optimizer_kwargs=optimizer_kwargs,
    )

6. Shard the sparse tables. Create a sharder, use EmbeddingShardingPlanner to generate a table sharding plan, and pass the plan and the sharder to DistributedModelParallel to obtain the distributed model.

    hybrid_sharder = get_default_hybrid_sharders(host_env=host_env)
    constraints = {......}
    planner = EmbeddingShardingPlanner(......)
    plan = planner.collective_plan(test_model, hybrid_sharder, dist.GroupMember.WORLD)
    logging.info(plan)
    ddp_model = DistributedModelParallel(
        test_model, device=torch.device("npu"), plan=plan, sharders=hybrid_sharder
    )

7. Combine the optimizers. Separate the dense parameters from the sparse ones, give the dense parameters their own optimizer, and combine it with the sparse optimizer into a single optimizer.

    # Dense parameters: every parameter not already handled by an in-backward optimizer
    dense_optimizer = KeyedOptimizerWrapper(
        dict(in_backward_optimizer_filter(ddp_model.named_parameters())),
        lambda params: torch.optim.Adagrad(params, lr=0.1),
    )
    optimizer = CombinedOptimizer([ddp_model.fused_optimizer, dense_optimizer])

8. Create a pipeline.

    pipeline = HybridTrainPipelineSparseDist(
        ddp_model, optimizer, device, execute_all_batches=True
    )

9. Use the pipeline for training, calling progress() once per iteration.

    batched_iterator = iter(data_loader)
    for i in range(...):
        pipeline.progress(batched_iterator)
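A training pipeline usually keeps more than one batch in flight, so that copying the next batch to the device can overlap with computation on the current one. The toy sketch that follows is a hypothetical, single-process stand-in (ToyPipeline, train_step, and run are illustrative names, not Rec SDK APIs) that mimics only the calling convention of step 9: each progress(iterator) call executes one batch and raises StopIteration once nothing is left. It assumes that HybridTrainPipelineSparseDist behaves like TorchRec's TrainPipelineSparseDist, which raises StopIteration when the iterator is drained and, with execute_all_batches=True, still executes batches that were already prefetched. Check the demo code to confirm this for your version.

```python
from typing import Any, Callable, Iterator, List


class ToyPipeline:
    """Hypothetical single-process stand-in for a train pipeline.

    Not the Rec SDK implementation: it only mimics the calling convention
    progress(iterator) -> result and raises StopIteration when no batch is left.
    """

    def __init__(self, step_fn: Callable[[Any], Any], execute_all_batches: bool = True) -> None:
        self._step_fn = step_fn
        self._execute_all_batches = execute_all_batches
        self._in_flight: List[Any] = []  # prefetched batches, oldest first
        self._exhausted = False

    def _fill(self, batches: Iterator[Any]) -> None:
        # Keep two batches queued: the one to run now and the next one, which a
        # real pipeline would be copying to the device in the meantime.
        while not self._exhausted and len(self._in_flight) < 2:
            try:
                self._in_flight.append(next(batches))
            except StopIteration:
                self._exhausted = True

    def progress(self, batches: Iterator[Any]) -> Any:
        self._fill(batches)
        if not self._in_flight:
            raise StopIteration
        if self._exhausted and not self._execute_all_batches:
            # The iterator is drained: drop the batches that were already prefetched.
            self._in_flight.clear()
            raise StopIteration
        return self._step_fn(self._in_flight.pop(0))


def train_step(batch: int) -> int:
    # Placeholder for the forward pass, backward pass, and optimizer step.
    return batch * 2


def run(execute_all_batches: bool, num_batches: int = 5) -> List[int]:
    pipeline = ToyPipeline(train_step, execute_all_batches=execute_all_batches)
    batched_iterator = iter(range(num_batches))
    results = []
    # Stop when the pipeline reports that no batches remain,
    # instead of running a fixed number of iterations.
    while True:
        try:
            results.append(pipeline.progress(batched_iterator))
        except StopIteration:
            break
    return results


print(run(execute_all_batches=True))   # [0, 2, 4, 6, 8]
print(run(execute_all_batches=False))  # [0, 2, 4, 6]
```

Catching StopIteration ends the loop exactly when the data runs out, without hard-coding the iteration count in range(...). When the dataset never ends, a fixed iteration count remains the simpler choice.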