EmbCacheEmbeddingCollection

Function

Creates a single-server embedding table collection object that uses hash mapping and a multi-level cache.
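This page does not describe how the collection implements hash mapping or its cache levels. The sketch below only illustrates the general technique in plain Python, under stated assumptions. A dict acts as the hash map from raw sparse IDs to embedding rows. A small LRU cache stands in for fast (level-1) memory, and a larger dict stands in for slower (level-2) memory. The class `TwoLevelEmbeddingCache` and all of its members are hypothetical and are not part of this API.

```python
import random
from collections import OrderedDict


class TwoLevelEmbeddingCache:
    """Hypothetical illustration of hash mapping plus a two-level cache.

    Level 1 is a small LRU cache standing in for fast (e.g. device) memory;
    level 2 is an unbounded dict standing in for slower (e.g. host) memory.
    Not the EmbCacheEmbeddingCollection implementation.
    """

    def __init__(self, dim: int, cache_capacity: int, seed: int = 0):
        if cache_capacity < 1:
            raise ValueError("cache_capacity must be at least 1")
        self.dim = dim
        self.capacity = cache_capacity
        self.cache = OrderedDict()  # level 1: raw ID -> row, least recently used first
        self.store = {}             # level 2: raw ID -> row for evicted entries
        self.rng = random.Random(seed)

    def _new_row(self):
        # Rows for unseen IDs are created lazily, so arbitrarily large raw IDs
        # need no pre-sized table: the hash map does the ID-to-row mapping.
        return [self.rng.uniform(-0.01, 0.01) for _ in range(self.dim)]

    def lookup(self, raw_id: int):
        if raw_id in self.cache:                # level-1 hit: mark as most recently used
            self.cache.move_to_end(raw_id)
            return self.cache[raw_id]
        row = self.store.pop(raw_id, None)      # level-2 hit, or None on a full miss
        if row is None:
            row = self._new_row()
        if len(self.cache) >= self.capacity:    # make room by demoting the LRU row
            old_id, old_row = self.cache.popitem(last=False)
            self.store[old_id] = old_row
        self.cache[raw_id] = row
        return row


cache = TwoLevelEmbeddingCache(dim=4, cache_capacity=2)
for feature_id in [10**12, 7, 10**12, 42]:
    cache.lookup(feature_id)
print(list(cache.cache))  # [1000000000000, 42]
print(list(cache.store))  # [7]
```

ID 7 is demoted to level 2 because ID 10**12 was touched more recently. A later lookup of 7 promotes it back to level 1 and evicts the new least recently used ID. This is one common design. Real systems differ in eviction policy, in where each level lives, and in how rows are initialized.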

Prototype

class EmbCacheEmbeddingCollection:
    def __init__(self, **kwargs):

Parameters

Parameter

Data Type

Mandatory/Optional

Description

tables

List[EmbCacheEmbeddingConfig|EmbeddingConfig]

Mandatory

List of sparse table configurations. The list length must be in the range [1, 10000].

world_size

int

Mandatory

World size for distributed training, that is, the total number of participating processes. The value range is [1, 10000].

batch_size

int

Mandatory

Batch size. The value range is [1, 102400].

multi_hot_sizes

List[int]

Mandatory

List of the multi-hot encoding sizes of the features. The list must have the same length as tables, so its length is also in the range [1, 10000]. Each multi-hot encoding size in the list must be in the range [1, 102400].

need_indices

bool

Optional

Whether an index is required. The default value is False.

need_accumulate_offset

bool

Optional

Whether to accumulate offsets. The default value is True.

device

torch.device

Optional

Compute device, the same as that of EmbCacheEmbeddingBagCollection. The default value is CPU.

embedding_optimizer_cls

Type[torch.optim.Optimizer]

Optional

Embedding optimizer type. The default value is torch.optim.Adagrad. The options are as follows:
  • torch.optim.Adagrad: Adagrad optimizer.
  • torch.optim.Adam: Adam optimizer.
  • torch.optim.SGD: SGD optimizer.
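The numeric and length constraints in the table above can be checked before the constructor is called. The function below is a minimal sketch that mirrors only those documented ranges. The name `validate_embcache_args` is hypothetical and is not part of this API, and the table entries are treated as opaque objects, so it needs neither PyTorch nor the library. It does not check need_indices, need_accumulate_offset, device, or embedding_optimizer_cls.

```python
from typing import Sequence


def validate_embcache_args(tables: Sequence[object], world_size: int,
                           batch_size: int, multi_hot_sizes: Sequence[int]) -> None:
    """Hypothetical pre-check mirroring the documented parameter ranges.

    Raises ValueError on the first violated constraint.
    """

    def check_range(name: str, value: int, low: int, high: int) -> None:
        if not low <= value <= high:
            raise ValueError(f"{name}={value} is outside [{low}, {high}]")

    check_range("len(tables)", len(tables), 1, 10000)
    check_range("world_size", world_size, 1, 10000)
    check_range("batch_size", batch_size, 1, 102400)
    # multi_hot_sizes must match tables in length, which also bounds it to [1, 10000].
    if len(multi_hot_sizes) != len(tables):
        raise ValueError(
            f"len(multi_hot_sizes)={len(multi_hot_sizes)} "
            f"does not match len(tables)={len(tables)}")
    for i, size in enumerate(multi_hot_sizes):
        check_range(f"multi_hot_sizes[{i}]", size, 1, 102400)


# Placeholder strings stand in for EmbCacheEmbeddingConfig/EmbeddingConfig objects.
validate_embcache_args(["t0", "t1"], world_size=8, batch_size=4096, multi_hot_sizes=[1, 20])
try:
    validate_embcache_args(["t0", "t1"], world_size=8, batch_size=4096, multi_hot_sizes=[1])
except ValueError as err:
    print(err)  # len(multi_hot_sizes)=1 does not match len(tables)=2
```

The constructor itself throws an exception on failure, as stated under Return Value. A pre-check like this only gives an earlier, more specific message for the ranges listed on this page.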

Return Value

  • Success: The EmbCacheEmbeddingCollection object is returned.
  • Failure: An exception is thrown.