HybridTrainPipelineSparseDist
功能描述
创建流水查表。
函数原型
1 2 | class HybridTrainPipelineSparseDist: def __init__(**kwargs): |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
model |
torch.nn.Module |
必选 |
包含EmbeddingBagCollection的nn.Module类型。 |
optimizer |
torch.optim.Optimizer |
必选 |
优化器。优化器的创建方式请参考7中定义的优化器。 |
device |
torch.device |
必选 |
设备。 取值范围: torch.device("npu"):npu设备。 |
return_loss |
bool |
可选 |
是否返回loss。 |
pipe_n_batch |
int |
可选 |
预取n个batch做并行。取值范围:[1, 12]。 |
execute_all_batches |
True |
可选 |
仅支持默认值为True,不支持用户自定义。 |
apply_jit |
False |
可选 |
仅支持默认值为False,不支持用户自定义。 |
返回值说明
- 成功:返回pipeline。
- 失败:抛出异常。
使用示例
1 2 | from hybrid_torchrec.distributed.hybrid_train_pipeline import HybridTrainPipelineSparseDist pipeline = HybridTrainPipelineSparseDist(model, optimizer, device) |
父主题: pipeline接口