EmbCacheEmbeddingBagCollectionSharder

Function

Creates an EmbCacheEmbeddingBagCollectionSharder, which shards an EmbCacheEmbeddingBagCollection across different devices.

Prototype

class EmbCacheEmbeddingBagCollectionSharder(EmbeddingBagCollectionSharder):
    def __init__(self, **kwargs):

Parameters

| Parameter | Data Type | Mandatory/Optional | Description |
|---|---|---|---|
| cpu_device | torch.device | Mandatory | CPU device. |
| cpu_env | ShardingEnv | Mandatory | CPU environment configuration. |
| npu_device | torch.device | Mandatory | NPU device. |
| npu_env | ShardingEnv | Mandatory | NPU environment configuration. |
| fused_params | Dict[str, Any] | Optional | Fused parameters. The default value is None, the same as for EmbeddingBagCollectionSharder in TorchRec. |
| qcomm_codecs_registry | Dict[str, QuantizedCommCodecs] | Optional | Quantized communication codec registry. The default value is None, the same as for EmbeddingBagCollectionSharder in TorchRec. |
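Because the prototype accepts only keyword arguments, a misspelled or missing parameter may not surface until the constructor runs. The following is a minimal sketch of a pre-check against the parameter list above. The helper name build_sharder_kwargs and the two tuples are hypothetical and are not part of the SDK; the sketch uses only the Python standard library and accepts any placeholder objects as values.

```python
from typing import Any, Dict

# Names taken from the parameter list; the helper itself is hypothetical.
MANDATORY = ("cpu_device", "cpu_env", "npu_device", "npu_env")
OPTIONAL = ("fused_params", "qcomm_codecs_registry")


def build_sharder_kwargs(**kwargs: Any) -> Dict[str, Any]:
    """Check keyword arguments against the documented parameters."""
    # A mandatory parameter counts as missing if absent or explicitly None.
    missing = [name for name in MANDATORY if kwargs.get(name) is None]
    if missing:
        raise TypeError(f"missing mandatory parameter(s): {', '.join(missing)}")
    unknown = sorted(set(kwargs) - set(MANDATORY) - set(OPTIONAL))
    if unknown:
        raise TypeError(f"unexpected parameter(s): {', '.join(unknown)}")
    # Optional parameters fall back to None, matching the documented default.
    return {**{name: None for name in OPTIONAL}, **kwargs}
```

The returned dictionary could then be unpacked into the constructor, for example EmbCacheEmbeddingBagCollectionSharder(**checked), assuming the class has been imported from the SDK.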

Return Value

  • Success: An EmbCacheEmbeddingBagCollectionSharder object is returned.
  • Failure: An exception is raised.
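Since failure is reported through an exception rather than a return code, callers may want to wrap construction in a try/except. The sketch below shows one way to do that. The function try_create_sharder is hypothetical, and the class is passed in as an argument so that the example does not assume any particular import path; the source does not specify which exception types are raised, so the sketch catches Exception broadly.

```python
from typing import Any, Callable, Optional, Tuple


def try_create_sharder(
    sharder_cls: Callable[..., Any], **kwargs: Any
) -> Tuple[Optional[Any], Optional[Exception]]:
    """Return (sharder, None) on success or (None, error) on failure."""
    try:
        # Success: the constructed sharder object is returned.
        return sharder_cls(**kwargs), None
    except Exception as exc:
        # Failure: the constructor raised; hand the exception to the caller.
        return None, exc
```

In real use, sharder_cls would be EmbCacheEmbeddingBagCollectionSharder and the keyword arguments would be the devices and ShardingEnv objects described under Parameters.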