Initialization
Function
Creates a training pipeline that performs embedding table lookups through a multi-level cache.
Prototype
```python
class EmbCacheTrainPipelineSparseDist:
    def __init__(self, **kwargs):
        ...
```
Parameters
| Parameter | Data Type | Mandatory/Optional | Description |
|---|---|---|---|
| model | torch.nn.Module | Mandatory | The nn.Module that contains the EmbCacheEmbeddingBagCollection and EmbCacheEmbeddingCollection modules. |
| optimizer | torch.optim.Optimizer | Mandatory | Optimizer for the model. |
| cpu_device | torch.device | Mandatory | CPU device. |
| npu_device | torch.device | Mandatory | NPU device. |
| return_loss | bool | Optional | Whether to return the loss. The default value is False. |
| execute_all_batches | bool | Optional | Whether to execute all batches. The default value is True; other values are not supported. |
| apply_jit | bool | Optional | Whether to apply JIT compilation. The default value is False; other values are not supported. |
| context_type | Type[EmbCacheTrainPipelineContext] | Optional | Pipeline context type. The default value is EmbCacheTrainPipelineContext. |
| pipeline_postproc | bool | Optional | Whether to enable pipeline post-processing. The default value is False, the same as in TrainPipelineSparseDist in TorchRec. |
| custom_model_fwd | Callable | Optional | Custom model forward function. The default value is None, the same as in TrainPipelineSparseDist in TorchRec. |
| custom_model_zero_grad | Callable | Optional | Custom zero_grad function. The default value is None. |
| custom_model_bwd | Callable | Optional | Custom model backward function. The default value is None. |
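Because the prototype accepts only **kwargs, the parameter table is the entire calling contract. The following stand-alone sketch models that contract in plain Python: it requires the four mandatory arguments, fills in the documented defaults, and rejects non-default values for execute_all_batches and apply_jit. The helper name resolve_pipeline_kwargs and the choice to raise ValueError are hypothetical; this is an illustration of the table, not the library's implementation.

```python
from typing import Any, Dict

# Hypothetical model of the documented keyword contract; not the library's code.
MANDATORY = ("model", "optimizer", "cpu_device", "npu_device")

# Documented defaults. The real context_type default is the
# EmbCacheTrainPipelineContext class itself; a string stands in for it here.
OPTIONAL_DEFAULTS: Dict[str, Any] = {
    "return_loss": False,
    "execute_all_batches": True,
    "apply_jit": False,
    "context_type": "EmbCacheTrainPipelineContext",
    "pipeline_postproc": False,
    "custom_model_fwd": None,
    "custom_model_zero_grad": None,
    "custom_model_bwd": None,
}

# Parameters whose default values cannot be changed, per the parameter table.
FIXED = ("execute_all_batches", "apply_jit")


def resolve_pipeline_kwargs(**kwargs: Any) -> Dict[str, Any]:
    """Fill in documented defaults and reject unsupported values.

    Assumption: a missing mandatory argument or a non-default value for a
    fixed parameter raises ValueError; the real class may behave differently.
    """
    missing = [name for name in MANDATORY if name not in kwargs]
    if missing:
        raise ValueError(f"missing mandatory arguments: {', '.join(missing)}")
    for name in FIXED:
        if name in kwargs and kwargs[name] != OPTIONAL_DEFAULTS[name]:
            raise ValueError(
                f"{name} must keep its default value {OPTIONAL_DEFAULTS[name]!r}"
            )
    # Start from the defaults, then let caller-supplied values override them.
    resolved = dict(OPTIONAL_DEFAULTS)
    resolved.update(kwargs)
    return resolved
```

For example, resolve_pipeline_kwargs(model=m, optimizer=opt, cpu_device=cpu, npu_device=npu, return_loss=True) returns every parameter with return_loss set to True and the remaining optional parameters at their defaults.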
Return Value
- Success: The EmbCacheTrainPipelineSparseDist object is returned.
- Failure: An exception is thrown.
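The failure case names no specific exception type, so a caller that wants to recover from a failed construction has to catch broadly. The sketch below assumes you pass the class (or any factory with the same keyword interface) as factory; try_create_pipeline is a hypothetical helper, and logging the error and returning None is one design choice among several (re-raising with added context is equally reasonable).

```python
import logging
from typing import Any, Callable, Optional

logger = logging.getLogger(__name__)


def try_create_pipeline(factory: Callable[..., Any], **kwargs: Any) -> Optional[Any]:
    """Call the pipeline constructor; log and return None if it raises.

    The exception type on failure is not documented, so this hypothetical
    helper catches Exception broadly.
    """
    try:
        return factory(**kwargs)
    except Exception:
        # logger.exception records the message together with the traceback.
        logger.exception("Failed to create EmbCacheTrainPipelineSparseDist")
        return None
```

In real code, factory would be EmbCacheTrainPipelineSparseDist and the keyword arguments would be those listed in the parameter table.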