Initialization
Function
Creates a pipeline for table lookup.
Prototype
```python
class HybridTrainPipelineSparseDist:
    def __init__(self, **kwargs):
        ...
```
Parameters
| Parameter | Data Type | Mandatory/Optional | Description |
|---|---|---|---|
| model | torch.nn.Module | Mandatory | nn.Module that contains EmbeddingBagCollection. |
| optimizer | torch.optim.Optimizer | Mandatory | Optimizer. For details about how to create an optimizer, see the optimizer defined in 7. |
| device | torch.device | Mandatory | Device. Value range: torch.device("npu"), which indicates the NPU device. |
| return_loss | bool | Optional | Whether to return the loss. The default value is False. |
| pipe_n_batch | int | Optional | Number of batches to prefetch for parallel processing. Value range: [1, 12]. The default value is 6. |
| execute_all_batches | bool | Optional | The default value is True. Only the default value is supported. |
| apply_jit | bool | Optional | The default value is False. Only the default value is supported. |
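
The constraints on the optional parameters listed above can be checked before the pipeline is constructed. The following is a minimal sketch, not part of hybrid_torchrec: the helper name build_pipeline_kwargs and the constant PIPE_N_BATCH_RANGE are hypothetical, and the checks only mirror the value ranges and defaults documented here.

```python
from typing import Any, Dict

# Hypothetical helper, not part of hybrid_torchrec. It only mirrors the
# documented constraints on the optional constructor parameters.
PIPE_N_BATCH_RANGE = (1, 12)  # documented value range for pipe_n_batch


def build_pipeline_kwargs(
    return_loss: bool = False,
    pipe_n_batch: int = 6,
    execute_all_batches: bool = True,
    apply_jit: bool = False,
) -> Dict[str, Any]:
    """Validate the optional arguments and return them as a keyword dict."""
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(pipe_n_batch, bool) or not isinstance(pipe_n_batch, int):
        raise TypeError("pipe_n_batch must be an int")
    low, high = PIPE_N_BATCH_RANGE
    if not low <= pipe_n_batch <= high:
        raise ValueError(
            f"pipe_n_batch must be in [{low}, {high}], got {pipe_n_batch}"
        )
    # Only the default values are documented as supported for these two flags.
    if execute_all_batches is not True:
        raise ValueError("execute_all_batches supports only the default value True")
    if apply_jit is not False:
        raise ValueError("apply_jit supports only the default value False")
    return {
        "return_loss": bool(return_loss),
        "pipe_n_batch": pipe_n_batch,
        "execute_all_batches": execute_all_batches,
        "apply_jit": apply_jit,
    }
```

The returned dictionary could then be unpacked into the constructor, for example HybridTrainPipelineSparseDist(model, optimizer, device, **build_pipeline_kwargs(pipe_n_batch=4)), assuming the constructor accepts these parameter names as keywords, as the **kwargs prototype suggests.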
Return Value
- Success: The pipeline is returned.
- Failure: An exception is thrown.
Sample
```python
from hybrid_torchrec.distributed.hybrid_train_pipeline import HybridTrainPipelineSparseDist

pipeline = HybridTrainPipelineSparseDist(model, optimizer, device)
```