HashEmbeddingBagConfig

Function

Configuration object passed to HashEmbeddingBagCollection. It specifies the size, embedding dimension, and data type of a sparse table.

Prototype

@dataclass
class HashEmbeddingBagConfig:
    def __init__(self, **kwargs):
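Because the prototype accepts keyword arguments, each setting is passed by name. The sketch below only illustrates that calling style. It uses a hypothetical stand-in dataclass (TableConfigSketch) and a stand-in DataType enum, since the real class is not reproduced here; the field names are taken from the Parameters section, and everything else is an assumption.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List


class DataType(Enum):
    """Stand-in for torchrec.types.DataType; only the documented default is modelled."""
    FP32 = "FP32"


@dataclass
class TableConfigSketch:
    """Hypothetical stand-in mirroring a few documented fields; not the real class."""
    num_embeddings: int          # mandatory: number of rows
    embedding_dim: int           # mandatory: number of columns
    name: str                    # mandatory: table name
    feature_names: List[str]     # mandatory: features looked up in this table
    data_type: DataType = DataType.FP32  # optional, documented default


# Every argument is passed by keyword, matching the **kwargs prototype.
config = TableConfigSketch(
    num_embeddings=100_000,
    embedding_dim=64,
    name="user_table",
    feature_names=["user_id"],
)
print(config.data_type)
```

Omitting data_type leaves the documented default in place, so only the four mandatory fields need to be supplied.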

Parameters

| Parameter | Type | Mandatory/Optional | Description |
|---|---|---|---|
| num_embeddings | int | Mandatory | Number of rows in the sparse table. Value range: [1, 1000000000] (1 billion). |
| embedding_dim | int | Mandatory | Number of columns in the sparse table. Value range: [8, 4096]. The value must be a multiple of 8. |
| name | str | Mandatory | Name of the sparse table. The value can contain only digits, letters, and underscores (_). |
| data_type | torchrec.types.DataType | Optional | Data type of the sparse table. The default value is DataType.FP32. |
| feature_names | List[str] | Mandatory | Names of the features queried in the sparse table. Each name can contain only digits, letters, and underscores (_). |
| weight_init_max | float | Optional | The default value is None or 1.0. User-defined values are not supported. |
| weight_init_min | float | Optional | The default value is None or 0.0. User-defined values are not supported. |
| num_embeddings_post_pruning | int | Optional | The default value is None. User-defined values are not supported. |
| init_fn | Callable | Optional | Custom weight-initialization function. A callable that operates on the table's nn.Parameter can be passed; the caller is responsible for ensuring that the function is correct. The default value is None. |
| need_pos | bool | Optional | The default value is False. User-defined values are not supported. |
| pooling | torchrec.modules.embedding_configs.PoolType | Optional | Type of the pooling operation. Valid values: SUM (summation), MEAN (mean), NONE (no pooling is performed). The default value is SUM. |
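The documented limits for num_embeddings, embedding_dim, name, and feature_names can be checked before a configuration is built. The following is a minimal sketch of such a check, not part of the library: check_table_args is a hypothetical helper, and it assumes that "letters" means ASCII letters only.

```python
import re
from typing import List

# Limits copied from the parameter descriptions.
MAX_NUM_EMBEDDINGS = 1_000_000_000
MIN_EMBEDDING_DIM = 8
MAX_EMBEDDING_DIM = 4096

# Assumption: only ASCII digits, ASCII letters, and underscores are allowed.
_NAME_PATTERN = re.compile(r"[0-9A-Za-z_]+")


def check_table_args(num_embeddings: int, embedding_dim: int,
                     name: str, feature_names: List[str]) -> None:
    """Hypothetical helper: raise ValueError if a documented limit is violated."""
    if not 1 <= num_embeddings <= MAX_NUM_EMBEDDINGS:
        raise ValueError(f"num_embeddings out of range: {num_embeddings}")
    if not MIN_EMBEDDING_DIM <= embedding_dim <= MAX_EMBEDDING_DIM:
        raise ValueError(f"embedding_dim out of range: {embedding_dim}")
    if embedding_dim % 8 != 0:
        raise ValueError(f"embedding_dim is not a multiple of 8: {embedding_dim}")
    # The table name and every feature name share the same character rule.
    for label in [name, *feature_names]:
        if not _NAME_PATTERN.fullmatch(label):
            raise ValueError(f"invalid name: {label!r}")


# A configuration that satisfies every documented limit passes silently.
check_table_args(100_000, 64, "user_table", ["user_id", "item_id"])
```

Running the check up front surfaces a bad value, such as an embedding_dim of 12, as a plain ValueError before any table is allocated.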

See Also

For the API call sequence and a usage example, see Porting and Training.