昇腾社区首页
中文
注册
开发者
下载

EmbCacheEmbeddingCollection

功能描述

创建带哈希映射和多级缓存的单机表对象。

函数原型

1
2
class EmbCacheEmbeddingCollection:
     def __init__(**kwargs):

参数说明

参数名

类型

可选/必选

说明

tables

List[EmbCacheEmbeddingConfig|EmbeddingConfig]

必选

稀疏表配置列表。列表长度的取值范围为[1,10000]。

world_size

int

必选

分布式训练world_size大小,取值范围为[1,10000]。

batch_size

int

必选

批次大小,取值范围为[1,102400]。

multi_hot_sizes

List[int]

必选

每个特征的多热编码大小列表。该参数列表的长度必须与tables的列表长度相同,取值范围为[1,10000];列表中多热编码大小的取值范围为[1,102400]。

need_indices

bool

可选

是否需要索引,默认值False。

need_accumulate_offset

bool

可选

是否需要累积偏移量,默认值True。

device

torch.device

可选

默认为CPU,计算设备和EmbCacheEmbeddingBagCollection一致。

embedding_optimizer_cls

Type[torch.optim.Optimizer]

可选

嵌入优化器类型,默认值torch.optim.Adagrad。取值范围为:
  • torch.optim.Adagrad:表示Adagrad优化器。
  • torch.optim.Adam:表示Adam优化器。
  • torch.optim.SGD:表示SGD优化器。

返回值说明

  • 成功:返回EmbCacheEmbeddingCollection对象。
  • 失败:抛出异常。