torch_npu.optim.NpuFusedRMSpropTF
接口原型
torch_npu.optim.NpuFusedRMSpropTF(params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, decoupled_decay=False, lr_in_momentum=True)
功能描述
通过张量融合实现的 RMSpropTF 算法。
参数说明
- params:模型参数或模型参数组。
- lr:学习率(默认值:1e-3)。
- alpha:平滑常量(默认值:0.9)。
- eps:分母防除0项,提高数值稳定性(默认值:1e-10)。
- weight_decay:权重衰减(默认值:0)。
- momentum:动量因子(默认值:0)。
- centered: 计算中心RMSProp(默认值:False)。
- decoupled_decay:权重衰减仅作用于参数(默认值:False)。
- lr_in_momentum:计算动量buffer时使用lr(默认值:True)。
调用示例
opt = torch_npu.optim.NpuFusedRMSpropTF(model.parameters(), lr=0.001, weight_decay=0.01, momentum=0.9)
父主题: torch_npu.optim