EmbCacheEmbeddingCollection
Function
Creates a single-server table object that supports hash mapping and a multi-level cache.
Prototype
```python
class EmbCacheEmbeddingCollection:
    def __init__(self, **kwargs): ...
```
Parameters
| Parameter | Data Type | Mandatory/Optional | Description |
|---|---|---|---|
| tables | List[EmbCacheEmbeddingConfig \| EmbeddingConfig] | Mandatory | List of sparse table configurations. The list length range is [1, 10000]. |
| world_size | int | Mandatory | World size (number of processes) for distributed training. The value range is [1, 10000]. |
| batch_size | int | Mandatory | Batch size. The value range is [1, 102400]. |
| multi_hot_sizes | List[int] | Mandatory | Multi-hot encoding size of each feature. The list length must equal the length of tables, so its range is [1, 10000]. Each multi-hot size in the list must be in the range [1, 102400]. |
| need_indices | bool | Optional | Whether indices are required. The default value is False. |
| need_accumulate_offset | bool | Optional | Whether to accumulate offsets. The default value is True. |
| device | torch.device | Optional | Compute device. The default value is CPU. The compute device is the same as that of EmbCacheEmbeddingBagCollection. |
| embedding_optimizer_cls | Type[torch.optim.Optimizer] | Optional | Embedding optimizer type. The default value is torch.optim.Adagrad. The options are as follows: |
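The value ranges in the table above can be checked before the constructor is called, so that an out-of-range argument is reported early and with a clear message. The sketch below is a hypothetical helper (`validate_emb_cache_args` is not part of the library) that encodes only the documented constraints on the mandatory parameters. It raises `ValueError` or `TypeError` by its own choice; the exception type the real constructor throws is not specified here.

```python
from typing import Sequence

# Inclusive bounds taken from the parameter table.
MAX_TABLES = 10000
MAX_WORLD_SIZE = 10000
MAX_BATCH_SIZE = 102400
MAX_MULTI_HOT = 102400


def _check_range(name: str, value: int, low: int, high: int) -> None:
    """Raise if value is not an int in [low, high] (bool is rejected explicitly)."""
    if not isinstance(value, int) or isinstance(value, bool):
        raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    if not low <= value <= high:
        raise ValueError(f"{name}={value} is outside [{low}, {high}]")


def validate_emb_cache_args(tables: Sequence[object], world_size: int,
                            batch_size: int, multi_hot_sizes: Sequence[int]) -> None:
    """Hypothetical pre-check of the mandatory arguments; raises on the first violation."""
    _check_range("len(tables)", len(tables), 1, MAX_TABLES)
    _check_range("world_size", world_size, 1, MAX_WORLD_SIZE)
    _check_range("batch_size", batch_size, 1, MAX_BATCH_SIZE)
    # multi_hot_sizes must have one entry per table, which also bounds its length.
    if len(multi_hot_sizes) != len(tables):
        raise ValueError(
            f"len(multi_hot_sizes)={len(multi_hot_sizes)} "
            f"must equal len(tables)={len(tables)}"
        )
    for i, size in enumerate(multi_hot_sizes):
        _check_range(f"multi_hot_sizes[{i}]", size, 1, MAX_MULTI_HOT)
```

For example, `validate_emb_cache_args(tables=[cfg_a, cfg_b], world_size=8, batch_size=4096, multi_hot_sizes=[1, 20])` returns silently, while passing `multi_hot_sizes=[1]` with two tables raises `ValueError`. Here `cfg_a` and `cfg_b` stand for any two table configuration objects.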
Return Value
- Success: The EmbCacheEmbeddingCollection object is returned.
- Failure: An exception is thrown.