初始化
功能描述
创建流水查表。
函数原型
1 2 | class HybridTrainPipelineSparseDist: def __init__(**kwargs): |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
|---|---|---|---|
model |
torch.nn.Module |
必选 |
包含EmbeddingBagCollection的nn.Module类型。 |
optimizer |
torch.optim.Optimizer |
必选 |
优化器。优化器的创建方式请参考7中定义的优化器。 |
device |
torch.device |
必选 |
设备。取值范围:torch.device("npu"):npu设备。 |
return_loss |
bool |
可选 |
是否返回loss。默认值为False。 |
pipe_n_batch |
int |
可选 |
预取n个batch做并行。取值范围:[1, 12]。默认值为6。 |
execute_all_batches |
bool |
可选 |
默认值为True。仅支持默认值,不支持用户自定义。 |
apply_jit |
bool |
可选 |
默认值为False。仅支持默认值,不支持用户自定义。 |
返回值说明
- 成功:返回pipeline。
- 失败:抛出异常。
使用示例
1 2 | from hybrid_torchrec.distributed.hybrid_train_pipeline import HybridTrainPipelineSparseDist pipeline = HybridTrainPipelineSparseDist(model, optimizer, device) |