接口调用介绍
图1 接口调用流程

以下步骤省略了具体实现,如需完整代码,请参考RecSDK-Torch Little Demo样例。关键步骤如下:
- 定义Batch。将本次训练需要的所有特征整合为一个Batch类,并实现to()、pin_memory()、record_stream()方法。
@dataclass class Batch(Pipelineable): ......
- 创建数据集。实现一个返回1中创建的Batch类型的Dataset。
class RandomRecDataset(IterableDataset[Batch]): ......
- 初始化分布式变量。
...... device = torch.device("npu") dist.init_process_group(backend="hccl") host_gp = dist.new_group(backend="gloo") host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp)
- 定义模型。将稀疏表层和Dense层的模型整合为一个Module。该Module的输入必须为1中创建的Batch类。返回为模型的loss和输出。
class TestModel(torch.nn.Module): def __init__(self, ): super().__init__() table_configs = ...... self.ebc = HashEmbeddingBagCollection(device="npu", tables=table_configs) def forward(self, batch: Batch): return loss, result
- 定义稀疏表的优化器。
test_model = TestModel(...) # Optimizer embedding_optimizer = torch.optim.Adagrad optimizer_kwargs = {"lr": 0.001, "eps": 0.1} apply_optimizer_in_backward( embedding_optimizer, test_model.ebc.parameters(), optimizer_kwargs=optimizer_kwargs, )
- 对稀疏表做分表。创建sharder,并使用EmbeddingShardingPlanner创建分表计划,将分表计划和sharder传入DistributedModelParallel中获得分布式模型。
hybrid_sharder = get_default_hybrid_sharders(host_env=host_env) constraints = {......} planner = EmbeddingShardingPlanner(......) plan = planner.collective_plan(test_model, hybrid_sharder, dist.GroupMember.WORLD) logging.info(plan) ddp_model = DistributedModelParallel( test_model, device=torch.device("npu"), plan=plan, sharders=hybrid_sharder )
- 整合优化器。分离dense和sparse的参数,并组合成一个新的优化器。
# Optimizer filter dense_optimizer = KeyedOptimizerWrapper( dict(in_backward_optimizer_filter(ddp_model.named_parameters())), lambda params: torch.optim.Adagrad(params, lr=0.1), ) optimizer = CombinedOptimizer([ddp_model.fused_optimizer, dense_optimizer])
- 创建pipeline。
pipeline = HybridTrainPipelineSparseDist( ddp_model, optimizer, device, execute_all_batches=True )
- 使用pipeline进行训练。
batched_iterator = iter(data_loader) for i in range(...): pipeline.progress(batched_iterator)
父主题: 快速入门