EmbCacheEmbeddingBagCollection
Function
Creates a single-server table object that supports pooling, hash mapping, and a multi-level cache.
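The sketch below shows two of the features named above in plain Python: hash mapping and pooling. It is only a conceptual illustration, not the EmbCacheEmbeddingBagCollection implementation. The class `ToyHashEmbeddingBag` is hypothetical, and so are its choices: first-come row assignment, deterministic placeholder row values, and sum pooling (as in `torch.nn.EmbeddingBag` with `mode="sum"`). The multi-level cache is not modelled.

```python
from typing import Dict, List


class ToyHashEmbeddingBag:
    """Hypothetical illustration: hash-map raw IDs to compact rows, then sum-pool each bag."""

    def __init__(self, dim: int, capacity: int) -> None:
        self.dim = dim
        self.capacity = capacity
        # Hash mapping: arbitrary (possibly huge) raw IDs -> dense row indices.
        self.id_to_row: Dict[int, int] = {}
        self.rows: List[List[float]] = []

    def _row_for(self, raw_id: int) -> int:
        row = self.id_to_row.get(raw_id)
        if row is None:
            # Assign the next free row the first time an ID is seen.
            if len(self.rows) >= self.capacity:
                raise RuntimeError("table is full")
            row = len(self.rows)
            self.id_to_row[raw_id] = row
            # Deterministic placeholder values instead of trained weights.
            self.rows.append([float(raw_id % 7)] * self.dim)
        return row

    def __call__(self, bags: List[List[int]]) -> List[List[float]]:
        out: List[List[float]] = []
        for bag in bags:
            # Sum pooling: every bag of IDs is reduced to one vector of length dim.
            pooled = [0.0] * self.dim
            for raw_id in bag:
                vec = self.rows[self._row_for(raw_id)]
                for j in range(self.dim):
                    pooled[j] += vec[j]
            out.append(pooled)
        return out


bag = ToyHashEmbeddingBag(dim=2, capacity=10)
print(bag([[10**12, 3], [3]]))  # two bags -> two pooled vectors
```

Each raw ID is mapped to a compact row index only once, however large the ID is. Later lookups of the same ID reuse that row.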
Prototype
```python
class EmbCacheEmbeddingBagCollection:
    def __init__(self, **kwargs): ...
```
Parameters
| Parameter | Data Type | Required/Optional | Description |
|---|---|---|---|
| tables | List[EmbCacheEmbeddingBagConfig \| EmbeddingConfig] | Mandatory | List of sparse table configurations. The list length range is [1, 10000]. |
| world_size | int | Mandatory | World size (total number of processes) for distributed training. The value range is [1, 10000]. |
| batch_size | int | Mandatory | Batch size. The value range is [1, 102400]. |
| multi_hot_sizes | List[int] | Mandatory | Multi-hot encoding size of each feature. The list must have the same length as tables, so its length range is [1, 10000]. Each element must be in the range [1, 102400]. |
| is_weighted | bool | Optional | Only the default value False is supported. |
| need_accumulate_offset | bool | Optional | Whether to accumulate offsets. The default value is True. |
| device | torch.device | Optional | Compute device. It is the same as that of HashEmbeddingBagCollection. The default value is CPU. |
| embedding_optimizer_cls | Type[torch.optim.Optimizer] | Optional | Embedding optimizer type. The default value is torch.optim.Adagrad. The options are as follows: |
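Every argument is passed through `**kwargs`, so it can help to check the documented constraints before calling the constructor. The sketch below assumes a hypothetical helper, `validate_embcache_kwargs`, which is not part of the library. It treats `tables` entries as opaque config objects and checks only how many there are. It does not check `device`, `need_accumulate_offset`, or `embedding_optimizer_cls`.

```python
from typing import Any, Dict, Tuple

# Documented inclusive ranges from the parameter table.
_INT_RANGES: Dict[str, Tuple[int, int]] = {
    "world_size": (1, 10000),
    "batch_size": (1, 102400),
}
_TABLES_LEN_RANGE = (1, 10000)
_MULTI_HOT_RANGE = (1, 102400)
_MANDATORY = ("tables", "world_size", "batch_size", "multi_hot_sizes")


def validate_embcache_kwargs(kwargs: Dict[str, Any]) -> None:
    """Hypothetical pre-check: raise ValueError on the first violated constraint."""
    for name in _MANDATORY:
        if name not in kwargs:
            raise ValueError(f"missing mandatory parameter: {name}")

    # tables entries are opaque here; only the list length is checked.
    tables = kwargs["tables"]
    lo, hi = _TABLES_LEN_RANGE
    if not lo <= len(tables) <= hi:
        raise ValueError(f"len(tables) must be in [{lo}, {hi}]")

    for name, (lo, hi) in _INT_RANGES.items():
        value = kwargs[name]
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(value, bool) or not isinstance(value, int) or not lo <= value <= hi:
            raise ValueError(f"{name} must be an int in [{lo}, {hi}]")

    sizes = kwargs["multi_hot_sizes"]
    if len(sizes) != len(tables):
        raise ValueError("multi_hot_sizes must have the same length as tables")
    lo, hi = _MULTI_HOT_RANGE
    for size in sizes:
        if not lo <= size <= hi:
            raise ValueError(f"each multi-hot size must be in [{lo}, {hi}]")

    # Only the default value is_weighted=False is supported.
    if kwargs.get("is_weighted", False) is not False:
        raise ValueError("only is_weighted=False is supported")


validate_embcache_kwargs(
    {"tables": ["t0", "t1"], "world_size": 8, "batch_size": 4096, "multi_hot_sizes": [1, 20]}
)
```

For example, passing two tables with `multi_hot_sizes=[1]` raises `ValueError` because the list lengths differ.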
Return Value
- Success: The EmbCacheEmbeddingBagCollection object is returned.
- Failure: An exception is thrown.