EmbCacheEmbeddingBagCollection
功能描述
创建带pooling、哈希映射和多级缓存的单机表对象。
函数原型
1 2 | class EmbCacheEmbeddingBagCollection: def __init__(**kwargs): |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
|---|---|---|---|
tables |
List[EmbCacheEmbeddingBagConfig|EmbeddingConfig] |
必选 |
稀疏表配置文件列表 |
world_size |
int |
必选 |
分布式训练world_size大小 |
batch_size |
int |
必选 |
批次大小 |
multi_hot_sizes |
List[int] |
必选 |
每个特征的多热编码大小列表 |
is_weighted |
bool |
可选 |
仅支持默认值False |
need_accumulate_offset |
bool |
可选 |
是否需要累积偏移量,默认值True |
device |
torch.device |
可选 |
默认为CPU,计算设备和HashEmbeddingBagCollection一致。 |
embedding_optimizer_cls |
Type[torch.optim.Optimizer] |
可选 |
嵌入优化器类型,默认值torch.optim.Adagrad |
返回值说明
- 成功:返回EmbCacheEmbeddingBagCollection对象。
- 失败:抛出异常。
父主题: 创表接口