Initialization

Function

Creates a pipeline for table lookup.

Prototype

class HybridTrainPipelineSparseDist:
    def __init__(self, **kwargs):

Parameters

| Parameter | Data Type | Mandatory/Optional | Description |
|---|---|---|---|
| model | torch.nn.Module | Mandatory | nn.Module that contains EmbeddingBagCollection. |
| optimizer | torch.optim.Optimizer | Mandatory | Optimizer. For details about how to create an optimizer, see the optimizer defined in 7. |
| device | torch.device | Mandatory | Device. Value range: torch.device("npu"), which indicates the NPU device. |
| return_loss | bool | Optional | Whether to return the loss. The default value is False. |
| pipe_n_batch | int | Optional | Prefetch N batches for parallel processing. Value range: [1, 12]. The default value is 6. |
| execute_all_batches | bool | Optional | The default value is True. Only the default value is supported. |
| apply_jit | bool | Optional | The default value is False. Only the default value is supported. |
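The constraints listed above can be checked before the constructor is called. The following is a minimal sketch, not part of the library: the helper name build_pipeline_kwargs is hypothetical, and it only encodes the value ranges and defaults documented in this section. It treats model, optimizer, and device as opaque objects and does not import torch.

```python
def build_pipeline_kwargs(model, optimizer, device,
                          return_loss=False, pipe_n_batch=6):
    """Hypothetical helper: validate documented constraints and
    return keyword arguments for HybridTrainPipelineSparseDist."""
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(pipe_n_batch, bool) or not isinstance(pipe_n_batch, int):
        raise TypeError("pipe_n_batch must be an int")
    # Documented value range is [1, 12].
    if not 1 <= pipe_n_batch <= 12:
        raise ValueError("pipe_n_batch must be in the range [1, 12]")
    if not isinstance(return_loss, bool):
        raise TypeError("return_loss must be a bool")
    return {
        "model": model,
        "optimizer": optimizer,
        "device": device,
        "return_loss": return_loss,
        "pipe_n_batch": pipe_n_batch,
        # Only the default values are supported for these two parameters.
        "execute_all_batches": True,
        "apply_jit": False,
    }
```

The returned dictionary can then be unpacked into the constructor, for example HybridTrainPipelineSparseDist(**build_pipeline_kwargs(model, optimizer, device)).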

Return Value

  • Success: The pipeline is returned.
  • Failure: An exception is thrown.
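
Because failure is reported through an exception, callers may want to wrap construction. The sketch below is an illustration under assumptions: the function create_pipeline is hypothetical, and since this section does not name the exception type, it catches the generic Exception, logs it, and re-raises.

```python
import logging

logger = logging.getLogger(__name__)


def create_pipeline(factory, **kwargs):
    """Hypothetical wrapper: call factory(**kwargs), log any failure,
    and re-raise so the caller still sees the original exception."""
    try:
        return factory(**kwargs)
    except Exception:
        # The exception type is not documented, so catch broadly here.
        logger.exception(
            "Pipeline creation failed; arguments passed: %s", sorted(kwargs)
        )
        raise
```

In practice, factory would be the HybridTrainPipelineSparseDist class, for example create_pipeline(HybridTrainPipelineSparseDist, model=model, optimizer=optimizer, device=device).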

Sample

from hybrid_torchrec.distributed.hybrid_train_pipeline import HybridTrainPipelineSparseDist
pipeline = HybridTrainPipelineSparseDist(model=model, optimizer=optimizer, device=device)
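
Once the pipeline is created, it is typically driven by a training loop. The sketch below assumes that the pipeline follows the TorchRec TrainPipelineSparseDist convention: a progress(dataloader_iter) method that raises StopIteration when the iterator is exhausted. This section does not confirm that interface, and the function run_epoch is hypothetical.

```python
def run_epoch(pipeline, dataloader):
    """Hypothetical loop: drive the pipeline over one pass of dataloader.

    Assumes pipeline.progress(iterator) processes one batch per call and
    raises StopIteration when no batches remain (TorchRec convention).
    """
    data_iter = iter(dataloader)
    outputs = []
    while True:
        try:
            outputs.append(pipeline.progress(data_iter))
        except StopIteration:
            break
    return outputs
```

What each call returns is not specified in this section; it presumably depends on return_loss, so treat the collected outputs as opaque until verified against the library.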