HashEmbeddingBagConfig
功能描述
HashEmbeddingBagCollection的入参,用于配置表的大小、dim、数据类型等。
函数原型
1 2 3 | @dataclass class HashEmbeddingBagConfig: def __init__(**kwargs): |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
num_embeddings |
int |
必选 |
稀疏表的行数。取值范围:[1, 10亿]。 |
embedding_dim |
int |
必选 |
稀疏表的列数。取值范围:[8, 8192]。取值需要为8的倍数。 |
name |
str |
必选 |
稀疏表的名称。只能包含数字、字母和下划线。 |
data_type |
torchrec.types.DataType |
必选 |
稀疏表的数据类型。支持DataType.FP32。 |
feature_names |
List[str] |
必选 |
稀疏表查询的特征名称。只能包含数字、字母和下划线。 |
weight_init_max |
float |
可选 |
仅支持默认值为None,不支持用户自定义。 |
weight_init_min |
float |
可选 |
仅支持默认值为None,不支持用户自定义。 |
num_embeddings_post_pruning |
int |
可选 |
仅支持默认值为None,不支持用户自定义。 |
init_fn |
Callable |
可选 |
支持传入nn.Parameter类型的函数。用户需自行保证该函数的正确性。默认值为None。 |
need_pos |
bool |
可选 |
仅支持默认值为False,不支持用户自定义。 |
pooling |
torchrec.modules.embedding_configs.PoolType |
可选 |
pool操作的类型。 取值范围:
默认为SUM。 |
使用示例
1 2 3 4 5 6 7 | from hybrid_torchrec import HashEmbeddingBagConfig config = HashEmbeddingBagConfig( name="table", embedding_dim=128, num_embeddings=100000, feature_names=["feat_name1"] ) |
参考资源
接口调用流程及示例,请参见迁移与训练。
父主题: 创表接口