昇腾社区首页
中文
注册

接口调用介绍

图1 接口调用流程

以下步骤省略了具体实现,如需完整代码,请参考RecSDK-Torch Little Demo样例。关键步骤如下:

  1. 定义Batch。
    将本次训练需要的所有特征整合为一个Batch类,并实现to()、pin_memory()、record_stream()方法。
    @dataclass
    class Batch(Pipelineable):
          ......
  2. 创建数据集。
    实现一个返回1中创建的Batch类型的Dataset。
    class RandomRecDataset(IterableDataset[Batch]):
        ......
  3. 初始化分布式变量。
    ......
    device = torch.device("npu")
    dist.init_process_group(backend="hccl")
    host_gp = dist.new_group(backend="gloo")
    host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp)
  4. 定义模型。
    将稀疏表层和Dense层的模型整合为一个Module。该Module的输入必须为1中创建的Batch类。返回为模型的loss和输出。
    class TestModel(torch.nn.Module):
        def __init__(self, ):
            super().__init__()
            table_configs = ......
            self.ebc = HashEmbeddingBagCollection(device="npu", tables=table_configs)
         
        def forward(self, batch: Batch):
            return loss, result
  5. 定义稀疏表的优化器。
    test_model = TestModel(...)
    
     # Optimizer
     embedding_optimizer = torch.optim.Adagrad
     optimizer_kwargs = {"lr": 0.001, "eps": 0.1}
     apply_optimizer_in_backward(
         embedding_optimizer,
         test_model.ebc.parameters(),
         optimizer_kwargs=optimizer_kwargs,
     )
    
  6. 对稀疏表做分表。
    创建sharder,并使用EmbeddingShardingPlanner创建分表计划,将分表计划和sharder传入DistributedModelParallel中获得分布式模型。
    hybrid_sharder = get_default_hybrid_sharders(host_env=host_env)
    constraints = {......}
    planner = EmbeddingShardingPlanner(......)
    plan = planner.collective_plan(test_model, hybrid_sharder, dist.GroupMember.WORLD)
    logging.info(plan)
    ddp_model = DistributedModelParallel(
     test_model, device=torch.device("npu"), plan=plan, sharders=hybrid_sharder
    )
  7. 整合优化器。
    分离dense和sparse的参数,并组合成一个新的优化器。
    # Optimizer filter
    dense_optimizer = KeyedOptimizerWrapper(
     dict(in_backward_optimizer_filter(ddp_model.named_parameters())),
     lambda params: torch.optim.Adagrad(params, lr=0.1),
    )
    optimizer = CombinedOptimizer([ddp_model.fused_optimizer, dense_optimizer])
  8. 创建pipeline。
    pipeline = HybridTrainPipelineSparseDist(
     ddp_model, optimizer, device, execute_all_batches=True
    )
  9. 使用pipeline进行训练。
    batched_iterator = iter(data_loader)
    for i in range(...):
     pipeline.progress(batched_iterator)