EmbCacheEmbeddingCollectionSharder
功能描述
初始化EmbCacheEmbeddingCollectionSharder分表器,用于将EmbCacheEmbeddingCollection分片到不同的设备上。
函数原型
1 2 | class EmbCacheEmbeddingCollectionSharder(EmbeddingCollectionSharder): def __init__(**kwargs): |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
|---|---|---|---|
cpu_device |
torch.device |
必选 |
CPU设备 |
cpu_env |
ShardingEnv |
必选 |
CPU环境配置 |
npu_device |
torch.device |
必选 |
NPU设备 |
npu_env |
ShardingEnv |
必选 |
NPU环境配置 |
fused_params |
Dict[str, Any] |
可选 |
融合参数,默认值None, 和torchrec的EmbeddingBagCollectionSharder一致 |
qcomm_codecs_registry |
Dict[str, QuantizedCommCodecs] |
可选 |
量化通信编解码器注册表,默认值None, 和torchrec的EmbeddingBagCollectionSharder一致 |
返回值说明
- 成功:返回EmbCacheEmbeddingCollectionSharder对象。
- 失败:抛出异常。
父主题: 分表接口