EmbCacheEmbeddingBagCollection

Function

Creates a single-server table object with pooling, hash mapping, and multi-level cache.

Prototype

class EmbCacheEmbeddingBagCollection:
    def __init__(self, **kwargs):
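Because the prototype accepts only **kwargs, every argument has to be passed by keyword. The sketch below illustrates that calling convention with a stand-in class. KeywordOnlyStandIn is hypothetical and is not part of the library; it only mirrors the shape of the signature.

```python
# Stand-in class, NOT the real EmbCacheEmbeddingBagCollection: it only
# mirrors the keyword-only shape of the prototype.
class KeywordOnlyStandIn:
    def __init__(self, **kwargs):
        # Keep the keyword arguments so they can be inspected later.
        self.options = dict(kwargs)


# Keyword arguments are accepted.
obj = KeywordOnlyStandIn(world_size=1, batch_size=32)
print(obj.options["batch_size"])

# Positional arguments are rejected by Python itself with a TypeError.
positional_error = None
try:
    KeywordOnlyStandIn(1, 32)
except TypeError as exc:
    positional_error = str(exc)
    print("TypeError:", positional_error)
```

The real class is expected to behave the same way for positional arguments, since that follows from the signature rather than from anything specific to the library.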

Parameters

| Parameter | Data Type | Required/Optional | Description |
|---|---|---|---|
| tables | List[EmbCacheEmbeddingBagConfig\|EmbeddingConfig] | Required | List of sparse table configurations. The list length must be in the range [1, 10000]. |
| world_size | int | Required | Number of processes (world size) used for distributed training. The value range is [1, 10000]. |
| batch_size | int | Required | Batch size. The value range is [1, 102400]. |
| multi_hot_sizes | List[int] | Required | List of multi-hot encoding sizes of each feature. The list must have the same length as the tables list, so its length is in the range [1, 10000]. Each multi-hot encoding size in the list must be in the range [1, 102400]. |
| is_weighted | bool | Optional | Only the default value False is supported. |
| need_accumulate_offset | bool | Optional | Whether to accumulate offsets. The default value is True. |
| device | torch.device | Optional | Compute device. The default value is CPU. The compute device is the same as that of HashEmbeddingBagCollection. |
| embedding_optimizer_cls | Type[torch.optim.Optimizer] | Optional | Embedding optimizer type. The default value is torch.optim.Adagrad. Supported options: torch.optim.Adagrad (Adagrad optimizer), torch.optim.Adam (Adam optimizer), torch.optim.SGD (SGD optimizer). |
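The documented ranges can be checked before the object is constructed, which tends to give clearer error messages than a failure deep inside initialization. The following is a minimal sketch under that assumption. check_embcache_args is a hypothetical helper, not a library function, and the placeholder strings stand in for real table configuration objects.

```python
from typing import List, Sequence


def check_embcache_args(
    tables: Sequence[object],
    world_size: int,
    batch_size: int,
    multi_hot_sizes: List[int],
) -> None:
    """Hypothetical helper: raise ValueError if an argument is outside its documented range."""
    if not 1 <= len(tables) <= 10000:
        raise ValueError(f"len(tables)={len(tables)} is outside [1, 10000]")
    if not 1 <= world_size <= 10000:
        raise ValueError(f"world_size={world_size} is outside [1, 10000]")
    if not 1 <= batch_size <= 102400:
        raise ValueError(f"batch_size={batch_size} is outside [1, 102400]")
    # multi_hot_sizes must line up one-to-one with tables.
    if len(multi_hot_sizes) != len(tables):
        raise ValueError(
            f"len(multi_hot_sizes)={len(multi_hot_sizes)} does not match "
            f"len(tables)={len(tables)}"
        )
    for index, size in enumerate(multi_hot_sizes):
        if not 1 <= size <= 102400:
            raise ValueError(f"multi_hot_sizes[{index}]={size} is outside [1, 102400]")


# Placeholder strings stand in for real table configuration objects.
placeholder_tables = ["table_a", "table_b"]
check_embcache_args(
    placeholder_tables, world_size=8, batch_size=1024, multi_hot_sizes=[1, 20]
)
print("arguments are within the documented ranges")
```

The helper checks only the four required parameters. The optional parameters have no numeric ranges in the table, so they are left out.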

Return Value

  • Success: The EmbCacheEmbeddingBagCollection object is returned.
  • Failure: An exception is thrown.
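Since failure is reported through an exception, callers may want to wrap construction in a try/except block. The sketch below shows one possible pattern. build_collection and both factories are hypothetical names for illustration, and the exception type is an assumption: this page does not say which exception class is raised, so the sketch catches Exception.

```python
import logging
from typing import Any, Callable, Optional

logger = logging.getLogger("embcache_example")


def build_collection(factory: Callable[..., Any], **kwargs: Any) -> Optional[Any]:
    """Hypothetical wrapper: call factory(**kwargs), log and return None on failure."""
    try:
        return factory(**kwargs)
    except Exception:
        # The concrete exception type is not documented, so catch broadly
        # and record the traceback for diagnosis.
        logger.exception("collection construction failed")
        return None


def failing_factory(**kwargs: Any) -> Any:
    # Stand-in for a constructor call that rejects its arguments.
    raise ValueError("batch_size out of range")


def working_factory(**kwargs: Any) -> Any:
    # Stand-in for a constructor call that succeeds.
    return dict(kwargs)


print(build_collection(working_factory, world_size=1, batch_size=32))
print(build_collection(failing_factory, world_size=1, batch_size=0))
```

In a training script it is often better to let the exception propagate after logging, so that a misconfigured job stops instead of continuing without the embedding collection.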