昇腾社区首页
中文
注册
开发者
下载

初始化

功能描述

创建多级缓存流水查表。

函数原型

1
2
class EmbCacheTrainPipelineSparseDist:
     def __init__(**kwargs):

参数说明

参数名

类型

可选/必选

说明

model

torch.nn.Module

必选

包含EmbCacheEmbeddingBagCollection和EmbCacheEmbeddingCollection的nn.Module对象

optimizer

torch.optim.Optimizer

必选

优化器

cpu_device

torch.device

必选

CPU设备

npu_device

torch.device

必选

NPU设备

return_loss

bool

可选

是否返回loss,默认值为False

execute_all_batches

bool

可选

是否执行所有批次,默认值为True,不支持用户自定义

apply_jit

bool

可选

是否应用JIT编译,默认值为False,不支持用户自定义

context_type

Type[EmbCacheTrainPipelineContext]

可选

上下文类型,默认值为EmbCacheTrainPipelineContext

pipeline_postproc

bool

可选

是否启用流水线后处理,默认值为False,与torchrec的TrainPipelineSparseDist一致

custom_model_fwd

Callable

可选

自定义模型前向函数,默认值为None,与torchrec的TrainPipelineSparseDist一致

custom_model_zero_grad

Callable

可选

zero_grad自定义函数,默认为None

custom_model_bwd

Callable

可选

自定义模型反向函数,默认为None

返回值说明

  • 成功:返回EmbCacheTrainPipelineSparseDist对象。
  • 失败:抛出异常。