初始化
功能描述
创建多级缓存流水查表。
函数原型
1 2 | class EmbCacheTrainPipelineSparseDist: def __init__(**kwargs): |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
|---|---|---|---|
model |
torch.nn.Module |
必选 |
包含EmbCacheEmbeddingBagCollection和EmbCacheEmbeddingCollection的nn.Module对象 |
optimizer |
torch.optim.Optimizer |
必选 |
优化器 |
cpu_device |
torch.device |
必选 |
CPU设备 |
npu_device |
torch.device |
必选 |
NPU设备 |
return_loss |
bool |
可选 |
是否返回loss,默认值为False |
execute_all_batches |
bool |
可选 |
是否执行所有批次,默认值为True,不支持用户自定义 |
apply_jit |
bool |
可选 |
是否应用JIT编译,默认值为False,不支持用户自定义 |
context_type |
Type[EmbCacheTrainPipelineContext] |
可选 |
上下文类型,默认值为EmbCacheTrainPipelineContext |
pipeline_postproc |
bool |
可选 |
是否启用流水线后处理,默认值为False,与torchrec的TrainPipelineSparseDist一致 |
custom_model_fwd |
Callable |
可选 |
自定义模型前向函数,默认值为None,与torchrec的TrainPipelineSparseDist一致 |
custom_model_zero_grad |
Callable |
可选 |
zero_grad自定义函数,默认为None |
custom_model_bwd |
Callable |
可选 |
自定义模型反向函数,默认为None |
返回值说明
- 成功:返回EmbCacheTrainPipelineSparseDist对象。
- 失败:抛出异常。