Introduction to Training Scenarios
Building a Network with Rec SDK Torch
You can build and train a model by following the steps in the Quick Start guide.
Migrating from Open-Source TorchRec
If you have already built a network with TorchRec, replace the TorchRec interfaces with their Rec SDK Torch counterparts according to the mapping in Table 1.
Table 1 Interface mapping between TorchRec and Rec SDK Torch

| TorchRec interface | Rec SDK Torch interface | Function |
| --- | --- | --- |
| EmbeddingBagConfig | HashEmbeddingBagConfig | Sparse table configuration |
| EmbeddingBagCollection | HashEmbeddingBagCollection | Creates the sparse tables |
| get_default_sharders | get_default_hybrid_sharders | Obtains the sharders |
| TrainPipelineSparseDist | HybridTrainPipelineSparseDist | Training pipeline |
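At the configuration level the migration is a drop-in rename. The sketch below is illustrative only: the table and feature values are made up, and HashEmbeddingBagConfig is assumed to accept the same fields as TorchRec's EmbeddingBagConfig (name, embedding_dim, num_embeddings, feature_names); check the Rec SDK Torch API reference for its exact signature. Full pipeline examples follow below.

```python
# TorchRec: declare one sparse table with EmbeddingBagConfig
from torchrec import EmbeddingBagConfig

torchrec_table = EmbeddingBagConfig(
    name="t_user_id",            # hypothetical table name
    embedding_dim=16,
    num_embeddings=100_000,
    feature_names=["user_id"],   # hypothetical feature name
)

# Rec SDK Torch: per Table 1, HashEmbeddingBagConfig is the replacement;
# the fields below are assumed to mirror EmbeddingBagConfig.
from hybrid_torchrec import HashEmbeddingBagConfig  # import path assumed

hash_table = HashEmbeddingBagConfig(
    name="t_user_id",
    embedding_dim=16,
    num_embeddings=100_000,
    feature_names=["user_id"],
)
```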
Interface examples:
- TorchRec example:
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader

from torchrec import EmbeddingBagCollection
from torchrec import TrainPipelineSparseDist
from torchrec import EmbeddingBagCollectionSharder
from torchrec.datasets.random import RandomRecDataset
from torchrec.datasets.utils import Batch
from torchrec.distributed.model_parallel import get_default_sharders

# BATCH_SIZE, BATCH_NUM, FEAT_NAMES, ID_RANGES, TABLE_NAMES, EMBED_DIMS,
# NUM_EMBEDS and get_distribute_env() are user-defined and elided here.

class TestModel(torch.nn.Module):
    def __init__(self, table_names, feat_names, embed_dims, num_embeds):
        super().__init__()
        # The corresponding Rec SDK Torch interface is HashEmbeddingBagCollection
        self.ebc = EmbeddingBagCollection()  # EmbeddingBagConfig list elided

    def forward(self, batch: Batch):
        pass

def invoke_main():
    rank, world_size = get_distribute_env()
    device = torch.device("npu")
    dist.init_process_group(backend="hccl")

    dataset = RandomRecDataset(BATCH_SIZE, BATCH_NUM, FEAT_NAMES, ID_RANGES)
    data_loader = DataLoader(
        dataset,
    )
    test_model = TestModel(TABLE_NAMES, FEAT_NAMES, EMBED_DIMS, NUM_EMBEDS)

    # ... sharder creation ...
    # The corresponding Rec SDK Torch interface is get_default_hybrid_sharders
    sharders = get_default_sharders()

    # ... optimizer creation ...
    # The corresponding Rec SDK Torch interface is HybridTrainPipelineSparseDist
    pipeline = TrainPipelineSparseDist()  # sharded model, optimizer and device elided

    batched_iterator = iter(data_loader)
    for _ in range(20):
        pipeline.progress(batched_iterator)
- Rec SDK Torch example:
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader

from torchrec.datasets.random import RandomRecDataset
from torchrec.datasets.utils import Batch
from torchrec.distributed.types import ShardingEnv  # assumed: ShardingEnv is reused from TorchRec

from hybrid_torchrec import HashEmbeddingBagCollection  # import path assumed
from hybrid_torchrec import HybridTrainPipelineSparseDist
from hybrid_torchrec import HybridEmbeddingBagCollectionSharder
from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders

# BATCH_SIZE, BATCH_NUM, FEAT_NAMES, ID_RANGES, TABLE_NAMES, EMBED_DIMS,
# NUM_EMBEDS and get_distribute_env() are user-defined and elided here.

class TestModel(torch.nn.Module):
    def __init__(self, table_names, feat_names, embed_dims, num_embeds):
        super().__init__()
        # The corresponding TorchRec interface is EmbeddingBagCollection
        self.ebc = HashEmbeddingBagCollection()  # HashEmbeddingBagConfig list elided

    def forward(self, batch: Batch):
        pass

def invoke_main():
    rank, world_size = get_distribute_env()
    device = torch.device("npu")
    dist.init_process_group(backend="hccl")

    # Rec SDK Torch: create the host-side (gloo) process group and sharding environment
    host_gp = dist.new_group(backend="gloo")
    host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp)

    dataset = RandomRecDataset(BATCH_SIZE, BATCH_NUM, FEAT_NAMES, ID_RANGES)
    data_loader = DataLoader(
        dataset,
    )
    test_model = TestModel(TABLE_NAMES, FEAT_NAMES, EMBED_DIMS, NUM_EMBEDS)

    # ... sharder creation ...
    # The corresponding TorchRec interface is get_default_sharders
    hybrid_sharders = get_default_hybrid_sharders(host_env=host_env)

    # ... optimizer creation ...
    # The corresponding TorchRec interface is TrainPipelineSparseDist
    pipeline = HybridTrainPipelineSparseDist()  # sharded model, optimizer and device elided

    batched_iterator = iter(data_loader)
    for _ in range(20):
        pipeline.progress(batched_iterator)
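Both examples call `get_distribute_env()`, which is not part of TorchRec or Rec SDK Torch; it stands for user code that reads the process rank and world size. A minimal sketch, assuming the script is started with `torchrun` (which exports the `RANK` and `WORLD_SIZE` environment variables):

```python
import os

def get_distribute_env():
    """Hypothetical helper: read the rank/world size set by torchrun or an equivalent launcher."""
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return rank, world_size
```

With such a helper, the training processes can be launched with, for example, `torchrun --nproc_per_node=8 train.py` (the script name and process count are placeholders).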
Parent topic: Migration and Training