昇腾社区首页
中文
注册
开发者
下载

EmbCacheEmbeddingCollection

功能描述

创建带哈希映射和多级缓存的单机表对象。

函数原型

1
2
class EmbCacheEmbeddingCollection:
     def __init__(**kwargs):

参数说明

参数名

类型

可选/必选

说明

tables

List[EmbCacheEmbeddingConfig|EmbeddingConfig]

必选

稀疏表配置列表

world_size

int

必选

分布式训练world_size大小

batch_size

int

必选

批次大小

multi_hot_sizes

List[int]

必选

每个特征的多热编码大小列表

need_indices

bool

可选

是否需要索引,默认值False

need_accumulate_offset

bool

可选

是否需要累积偏移量,默认值True

device

torch.device

可选

默认为CPU,计算设备和EmbCacheEmbeddingBagCollection一致。

embedding_optimizer_cls

Type[torch.optim.Optimizer]

可选

嵌入优化器类型,默认值torch.optim.Adagrad

返回值说明

  • 成功:返回EmbCacheEmbeddingCollection对象。
  • 失败:抛出异常。