昇腾社区首页
中文
注册

HashEmbeddingBagConfig

功能描述

HashEmbeddingBagCollection的入参,用于配置表的大小、dim、数据类型等。

函数原型

1
2
3
@dataclass
class HashEmbeddingBagConfig:
 def __init__(**kwargs):

参数说明

参数名

类型

可选/必选

说明

num_embeddings

int

必选

稀疏表的行数。取值范围:[1, 10亿]。

embedding_dim

int

必选

稀疏表的列数。取值范围:[8, 8192]。取值需要为8的倍数。

name

str

必选

稀疏表的名称。只能包含数字、字母和下划线。

data_type

torchrec.types.DataType

必选

稀疏表的数据类型。支持DataType.FP32。

feature_names

List[str]

必选

稀疏表查询的特征名称。只能包含数字、字母和下划线。

weight_init_max

float

可选

仅支持默认值为None,不支持用户自定义。

weight_init_min

float

可选

仅支持默认值为None,不支持用户自定义。

num_embeddings_post_pruning

int

可选

仅支持默认值为None,不支持用户自定义。

init_fn

Callable

可选

支持传入nn.Parameter类型的函数。用户需自行保证该函数的正确性。默认值为None。

need_pos

bool

可选

仅支持默认值为False,不支持用户自定义。

pooling

torchrec.modules.embedding_configs.PoolType

可选

pool操作的类型。

取值范围:

  • SUM:求和。
  • MEAN:取平均。

默认为SUM。

使用示例

1
2
3
4
5
6
7
from hybrid_torchrec import HashEmbeddingBagConfig
config = HashEmbeddingBagConfig(
 name="table",
 embedding_dim=128,
 num_embeddings=100000,
 feature_names=["feat_name1"]
)

参考资源

接口调用流程及示例,请参见迁移与训练