昇腾社区首页
中文
注册
开发者
下载

EmbCacheEmbeddingBagCollection

功能描述

创建带pooling、哈希映射和多级缓存的单机表对象。

函数原型

1
2
class EmbCacheEmbeddingBagCollection:
     def __init__(**kwargs):

参数说明

参数名

类型

可选/必选

说明

tables

List[EmbCacheEmbeddingBagConfig|EmbeddingConfig]

必选

稀疏表配置文件列表

world_size

int

必选

分布式训练world_size大小

batch_size

int

必选

批次大小

multi_hot_sizes

List[int]

必选

每个特征的多热编码大小列表

is_weighted

bool

可选

仅支持默认值False

need_accumulate_offset

bool

可选

是否需要累积偏移量,默认值True

device

torch.device

可选

默认为CPU,计算设备和HashEmbeddingBagCollection一致。

embedding_optimizer_cls

Type[torch.optim.Optimizer]

可选

嵌入优化器类型,默认值torch.optim.Adagrad

返回值说明

  • 成功:返回EmbCacheEmbeddingBagCollection对象。
  • 失败:抛出异常。