Initialization

Function

Creates a training pipeline that performs embedding table lookups through a multi-level cache.
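The following is a minimal conceptual sketch of a multi-level cache lookup in plain Python. It assumes a small, fast first-level cache (loosely standing in for NPU memory) backed by a larger second-level table (loosely standing in for CPU host memory), with least-recently-used (LRU) eviction. The class `TwoLevelEmbeddingCache` and its methods are hypothetical. They are not part of this API and are not a description of how EmbCacheTrainPipelineSparseDist is implemented.

```python
from collections import OrderedDict
from typing import Dict, List


class TwoLevelEmbeddingCache:
    """Hypothetical two-level embedding cache (illustration only).

    Level 1 (`fast`) models a small cache; level 2 (`slow`) holds every row.
    Rows are copied into level 1 on access and written back to level 2
    when they are evicted.
    """

    def __init__(self, slow_table: Dict[int, List[float]], capacity: int) -> None:
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        self.slow = slow_table        # backing store containing all rows
        self.fast = OrderedDict()     # id -> row, oldest (LRU) entry first
        self.capacity = capacity
        self.hits = 0
        self.misses = 0

    def lookup(self, ids: List[int]) -> List[List[float]]:
        """Return the embedding row for each id, filling level 1 on misses."""
        rows = []
        for i in ids:
            if i in self.fast:
                self.hits += 1
                self.fast.move_to_end(i)  # mark as most recently used
            else:
                self.misses += 1
                if len(self.fast) >= self.capacity:
                    # Evict the least recently used row and write it back.
                    old_id, old_row = self.fast.popitem(last=False)
                    self.slow[old_id] = old_row
                self.fast[i] = self.slow[i]  # raises KeyError for unknown ids
            rows.append(self.fast[i])
        return rows
```

For example, with a capacity of 2, looking up ids `[1, 2, 1, 3]` gives one hit (the second `1`) and three misses, and id `2` is evicted when `3` is loaded. Frequently accessed rows stay in the fast level, which is the general motivation for caching embeddings across memory tiers.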

Prototype

class EmbCacheTrainPipelineSparseDist:
    def __init__(self, **kwargs):

Parameters

| Parameter | Data Type | Mandatory/Optional | Description |
|---|---|---|---|
| model | torch.nn.Module | Mandatory | nn.Module object that contains EmbCacheEmbeddingBagCollection and EmbCacheEmbeddingCollection modules. |
| optimizer | torch.optim.Optimizer | Mandatory | Optimizer used for training. |
| cpu_device | torch.device | Mandatory | CPU device, for example torch.device("cpu"). |
| npu_device | torch.device | Mandatory | NPU device. |
| return_loss | bool | Optional | Whether to return the loss. The default value is False. |
| execute_all_batches | bool | Optional | Whether to execute all batches. The default value is True. Only the default value is supported. |
| apply_jit | bool | Optional | Whether to apply JIT compilation. The default value is False. Only the default value is supported. |
| context_type | Type[EmbCacheTrainPipelineContext] | Optional | Pipeline context class. The default value is EmbCacheTrainPipelineContext. |
| pipeline_postproc | bool | Optional | Whether to enable pipeline post-processing. The default value is False, the same as in TorchRec's TrainPipelineSparseDist. |
| custom_model_fwd | Callable | Optional | Custom model forward function. The default value is None, the same as in TorchRec's TrainPipelineSparseDist. |
| custom_model_zero_grad | Callable | Optional | Custom zero_grad function. The default value is None. |
| custom_model_bwd | Callable | Optional | Custom model backward function. The default value is None. |
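The parameter table can be read as a set of keyword-argument rules. The sketch below is a hypothetical pure-Python helper, `resolve_pipeline_kwargs`, that is not part of the library. It fills in the documented defaults, checks that the four mandatory parameters are present, and rejects non-default values for `execute_all_batches` and `apply_jit`, which only support their defaults. It mirrors the table only; the real constructor's validation may differ.

```python
from typing import Any, Dict

# Mandatory parameters, per the parameter table.
MANDATORY = ("model", "optimizer", "cpu_device", "npu_device")

# Documented defaults for optional parameters. context_type is not listed
# here: when it is omitted, the library uses EmbCacheTrainPipelineContext.
DEFAULTS: Dict[str, Any] = {
    "return_loss": False,
    "execute_all_batches": True,
    "apply_jit": False,
    "pipeline_postproc": False,
    "custom_model_fwd": None,
    "custom_model_zero_grad": None,
    "custom_model_bwd": None,
}

# Parameters that only support their default value.
FIXED = ("execute_all_batches", "apply_jit")

KNOWN = set(MANDATORY) | set(DEFAULTS) | {"context_type"}


def resolve_pipeline_kwargs(**kwargs: Any) -> Dict[str, Any]:
    """Hypothetical helper: merge defaults and check the documented constraints."""
    unknown = set(kwargs) - KNOWN
    if unknown:
        raise TypeError(f"unknown parameters: {sorted(unknown)}")
    missing = [name for name in MANDATORY if name not in kwargs]
    if missing:
        raise TypeError(f"missing mandatory parameters: {missing}")
    for name in FIXED:
        if name in kwargs and kwargs[name] != DEFAULTS[name]:
            raise ValueError(f"{name} only supports its default value {DEFAULTS[name]!r}")
    # Explicit arguments override the defaults.
    return {**DEFAULTS, **kwargs}
```

Because the prototype accepts only keyword arguments, a resolved dictionary like this could in principle be unpacked into the constructor as `EmbCacheTrainPipelineSparseDist(**resolved)`. Keeping the rules in one place makes invalid combinations fail before any device work starts.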

Return Value

  • Success: An EmbCacheTrainPipelineSparseDist object is returned.
  • Failure: An exception is raised.