昇腾社区首页
中文
注册

HybridTrainPipelineSparseDist

功能描述

创建流水查表。

函数原型

1
2
class HybridTrainPipelineSparseDist:
    def __init__(**kwargs):

参数说明

参数名

类型

可选/必选

说明

model

torch.nn.Module

必选

包含EmbeddingBagCollection的nn.Module类型。

optimizer

torch.optim.Optimizer

必选

优化器。优化器的创建方式请参考7中定义的优化器。

device

torch.device

必选

设备。

取值范围:

torch.device("npu"):npu设备。

return_loss

bool

可选

是否返回loss。

pipe_n_batch

int

可选

预取n个batch做并行。取值范围:[1, 12]。

execute_all_batches

True

可选

仅支持默认值为True,不支持用户自定义。

apply_jit

False

可选

仅支持默认值为False,不支持用户自定义。

返回值说明

  • 成功:返回pipeline
  • 失败:抛出异常。

使用示例

1
2
from hybrid_torchrec.distributed.hybrid_train_pipeline import HybridTrainPipelineSparseDist
pipeline = HybridTrainPipelineSparseDist(model, optimizer, device)