昇腾社区首页
中文
注册

torch_npu.optim.NpuFusedRMSpropTF

API接口

torch_npu.optim.NpuFusedRMSpropTF(params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, decoupled_decay=False, lr_in_momentum=True)

功能描述

通过张量融合实现的 RMSpropTF 算法。

参数说明

  • params:模型参数或模型参数组
  • lr:学习率(默认值:1e-3)
  • alpha:平滑常量(默认值:0.9)
  • eps:分母防除0项,提高数值稳定性(默认值:1e-10)
  • weight_decay:权重衰减(默认值:0)
  • momentum:动量因子(默认值:0)
  • centered: 计算中心RMSProp(默认值:False)
  • decoupled_decay:权重衰减仅作用于参数(默认值:False)
  • lr_in_momentum:计算动量buffer时使用lr(默认值:True)

示例

opt = torch_npu.optim.NpuFusedRMSpropTF(model.parameters(), lr=0.001, weight_decay=0.01, momentum=0.9)