EmbCacheEmbeddingCollection
功能描述
创建带哈希映射和多级缓存的单机表对象。
函数原型
1 2 | class EmbCacheEmbeddingCollection: def __init__(**kwargs): |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
|---|---|---|---|
tables |
List[EmbCacheEmbeddingConfig|EmbeddingConfig] |
必选 |
稀疏表配置列表。列表长度的取值范围为[1,10000]。 |
world_size |
int |
必选 |
分布式训练world_size大小,取值范围为[1,10000]。 |
batch_size |
int |
必选 |
批次大小,取值范围为[1,102400]。 |
multi_hot_sizes |
List[int] |
必选 |
每个特征的多热编码大小列表。该参数列表的长度必须与tables的列表长度相同,取值范围为[1,10000];列表中多热编码大小的取值范围为[1,102400]。 |
need_indices |
bool |
可选 |
是否需要索引,默认值False。 |
need_accumulate_offset |
bool |
可选 |
是否需要累积偏移量,默认值True。 |
device |
torch.device |
可选 |
默认为CPU,计算设备和EmbCacheEmbeddingBagCollection一致。 |
embedding_optimizer_cls |
Type[torch.optim.Optimizer] |
可选 |
嵌入优化器类型,默认值torch.optim.Adagrad。取值范围为:
|
返回值说明
- 成功:返回EmbCacheEmbeddingCollection对象。
- 失败:抛出异常。
父主题: 创表接口