EmbCacheEmbeddingBagCollectionSharder
Function
Creates an EmbCacheEmbeddingBagCollectionSharder, which shards an EmbCacheEmbeddingBagCollection across different devices.
Prototype
class EmbCacheEmbeddingBagCollectionSharder(EmbeddingBagCollectionSharder):
    def __init__(**kwargs):
Parameters
| Parameter | Data Type | Mandatory/Optional | Description |
|---|---|---|---|
| cpu_device | torch.device | Mandatory | CPU device. |
| cpu_env | ShardingEnv | Mandatory | CPU environment configuration. |
| npu_device | torch.device | Mandatory | NPU device. |
| npu_env | ShardingEnv | Mandatory | NPU environment configuration. |
| fused_params | Dict[str, Any] | Optional | Fusion parameters. The default value is None, the same as for EmbeddingBagCollectionSharder in TorchRec. |
| qcomm_codecs_registry | Dict[str, QuantizedCommCodecs] | Optional | Quantized communication codec registry. The default value is None, the same as for EmbeddingBagCollectionSharder in TorchRec. |
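
Because the constructor accepts only keyword arguments, it can help to assemble and check them before the call. The following is a minimal sketch, not part of the documented API: `build_sharder_kwargs` is a hypothetical helper, and the only names taken from this page are the six parameter names. Whether the real constructor performs equivalent checks is not stated here.

```python
from typing import Any, Dict

# Parameter names as listed in the table above.
MANDATORY = ("cpu_device", "cpu_env", "npu_device", "npu_env")
OPTIONAL = ("fused_params", "qcomm_codecs_registry")


def build_sharder_kwargs(**kwargs: Any) -> Dict[str, Any]:
    """Hypothetical helper: check keyword arguments against the table."""
    # Reject names that the table does not list.
    unknown = set(kwargs) - set(MANDATORY) - set(OPTIONAL)
    if unknown:
        raise TypeError(f"unexpected arguments: {sorted(unknown)}")
    # Every mandatory argument must be present and not None.
    missing = [name for name in MANDATORY if kwargs.get(name) is None]
    if missing:
        raise ValueError(f"missing mandatory arguments: {missing}")
    # Optional arguments fall back to None, the documented default.
    return {**{name: None for name in OPTIONAL}, **kwargs}
```

Assuming the class has been imported from the SDK (the import path is not given on this page), the result would be passed as `EmbCacheEmbeddingBagCollectionSharder(**build_sharder_kwargs(cpu_device=..., cpu_env=..., npu_device=..., npu_env=...))`, with real `torch.device` and `ShardingEnv` objects in place of the ellipses.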
Return Value
- Success: The EmbCacheEmbeddingBagCollectionSharder object is returned.
- Failure: An exception is raised.