Training Scenarios
Setting Up a Network Based on Rec SDK Torch
You can set up a model and train it by following the steps in Quick Start.
Porting Based on Open-Source TorchRec
If your network is already built on TorchRec, replace each TorchRec API with its Rec SDK Torch counterpart according to the API mapping in Table 1.
Table 1 API mapping

| TorchRec API | Rec SDK Torch API | Function Description |
|---|---|---|
| EmbeddingBagConfig | HashEmbeddingBagConfig | Sparse table configuration. |
| EmbeddingBagCollection | HashEmbeddingBagCollection | Create a sparse table. |
| get_default_sharders | get_default_hybrid_sharders | Obtain a table splitter. |
| TrainPipelineSparseDist | HybridTrainPipelineSparseDist | Query a sparse table. |
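When porting a larger script, it can help to list every place where a TorchRec name from Table 1 still appears. The sketch below is only an illustration built on the Python standard library: `API_MAP` copies the four rows of the table, while `find_torchrec_apis` and the sample string are hypothetical and are not part of TorchRec or Rec SDK Torch. It reports hits and does not rewrite anything, because import paths also change and have to be updated by hand.

```python
import re
from typing import List, Tuple

# TorchRec name -> Rec SDK Torch name, copied from Table 1.
API_MAP = {
    "EmbeddingBagConfig": "HashEmbeddingBagConfig",
    "EmbeddingBagCollection": "HashEmbeddingBagCollection",
    "get_default_sharders": "get_default_hybrid_sharders",
    "TrainPipelineSparseDist": "HybridTrainPipelineSparseDist",
}

# The \b word boundaries keep EmbeddingBagCollection from matching inside
# HashEmbeddingBagCollection or EmbeddingBagCollectionSharder.
_PATTERN = re.compile(r"\b(" + "|".join(map(re.escape, API_MAP)) + r")\b")


def find_torchrec_apis(source: str) -> List[Tuple[int, str, str]]:
    """Return (line number, TorchRec name, Rec SDK Torch name) for each hit."""
    hits = []
    for line_no, line in enumerate(source.splitlines(), start=1):
        for match in _PATTERN.finditer(line):
            name = match.group(1)
            hits.append((line_no, name, API_MAP[name]))
    return hits


# Hypothetical snippet of a script that has not been ported yet.
sample = (
    "from torchrec.distributed.model_parallel import get_default_sharders\n"
    "self.ebc = EmbeddingBagCollection()\n"
    "sharder = EmbeddingBagCollectionSharder()\n"
)
for hit in find_torchrec_apis(sample):
    print(hit)
```

The third sample line is deliberately not reported: `EmbeddingBagCollectionSharder` has no entry in Table 1, so the scan leaves it alone.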
API example:
- TorchRec example:

      # Outline only: constructor arguments are omitted, and get_distribute_env,
      # RandomRecDataset, Batch, batched_iterator, and the upper-case constants
      # are defined in the parts of the script that are not shown.
      import torch
      import torch.distributed as dist
      from torch.utils.data import DataLoader
      from torchrec import EmbeddingBagCollection
      from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist
      from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
      from torchrec.distributed.model_parallel import get_default_sharders


      class TestModel(torch.nn.Module):
          def __init__(self, *args):
              super().__init__()
              # Rec SDK Torch uses HashEmbeddingBagCollection here.
              self.ebc = EmbeddingBagCollection()

          def forward(self, batch: Batch):
              pass


      def invoke_main():
          rank, world_size = get_distribute_env()
          device = torch.device("npu")
          dist.init_process_group(backend="hccl")

          dataset = RandomRecDataset(BATCH_SIZE, BATCH_NUM, FEAT_NAMES, ID_RANGES)
          data_loader = DataLoader(
              dataset,
          )
          test_model = TestModel(TABLE_NAMES, FEAT_NAMES, EMBED_DIMS, NUM_EMBEDS)

          # ... Create a sharder. ...
          # Rec SDK Torch uses get_default_hybrid_sharders here.
          hybrid_sharder = get_default_sharders()

          # ... Create an optimizer. ...
          # Rec SDK Torch uses HybridTrainPipelineSparseDist here.
          pipeline = TrainPipelineSparseDist()
          for i in range(20):
              pipeline.progress(batched_iterator)

- Rec SDK Torch example:

      # Outline only: constructor arguments are omitted, and get_distribute_env,
      # RandomRecDataset, Batch, batched_iterator, and the upper-case constants
      # are defined in the parts of the script that are not shown.
      import torch
      import torch.distributed as dist
      from torch.utils.data import DataLoader
      from torchrec.distributed.types import ShardingEnv
      from hybrid_torchrec import HashEmbeddingBagCollection
      from hybrid_torchrec.distributed.hybrid_train_pipeline import HybridTrainPipelineSparseDist
      from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders
      # ...


      class TestModel(torch.nn.Module):
          def __init__(self, *args):
              super().__init__()
              # Native TorchRec uses EmbeddingBagCollection here.
              self.ebc = HashEmbeddingBagCollection()

          def forward(self, batch: Batch):
              pass


      def invoke_main():
          rank, world_size = get_distribute_env()
          device = torch.device("npu")
          dist.init_process_group(backend="hccl")
          # Create a host-side (gloo) process group and sharding environment for Rec SDK Torch.
          host_gp = dist.new_group(backend="gloo")
          host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp)

          dataset = RandomRecDataset(BATCH_SIZE, BATCH_NUM, FEAT_NAMES, ID_RANGES)
          data_loader = DataLoader(
              dataset,
          )
          test_model = TestModel(TABLE_NAMES, FEAT_NAMES, EMBED_DIMS, NUM_EMBEDS)

          # ... Create a sharder. ...
          # Native TorchRec uses get_default_sharders here.
          hybrid_sharder = get_default_hybrid_sharders(host_env=host_env)

          # ... Create an optimizer. ...
          # Native TorchRec uses TrainPipelineSparseDist here.
          pipeline = HybridTrainPipelineSparseDist()
          for i in range(20):
              pipeline.progress(batched_iterator)

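Both examples drive training with a fixed `for i in range(20)` loop, which only works if the iterator holds at least 20 batches. TorchRec's `TrainPipelineSparseDist.progress` raises `StopIteration` once the iterator is exhausted; the sketch below assumes `HybridTrainPipelineSparseDist` behaves the same way, which this page does not confirm. `StandInPipeline` and `run_steps` are hypothetical names used only so the loop can run without either library installed.

```python
from typing import Any, Iterator


class StandInPipeline:
    """Hypothetical stand-in for a train pipeline; not part of either library."""

    def progress(self, batches: Iterator[Any]) -> Any:
        # next() raises StopIteration when no batches remain.
        batch = next(batches)
        return batch  # a real pipeline would return the model output


def run_steps(pipeline: Any, batches: Iterator[Any], max_steps: int) -> int:
    """Call pipeline.progress() up to max_steps times; stop early if data runs out."""
    completed = 0
    for _ in range(max_steps):
        try:
            pipeline.progress(batches)
        except StopIteration:
            break
        completed += 1
    return completed


# Five batches are available, so the loop ends after five steps instead of failing.
steps = run_steps(StandInPipeline(), iter(range(5)), max_steps=20)
print(steps)
```

In a real script, the same `try`/`except StopIteration` guard would wrap the `pipeline.progress(batched_iterator)` call, under the assumption stated above.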