昇腾社区首页
中文
注册

torch_npu.optim.NpuFusedAdam

接口原型

torch_npu.optim.NpuFusedAdam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False)

功能描述

通过张量融合实现的Adam算法。

参数说明

  • params:模型参数或模型参数组。
  • lr:学习率(默认值:1e-3)。
  • betas:用于计算梯度及其平方的运行平均值的系数(默认值:(0.9,0.999))。
  • eps:防止除0,提高数值稳定性 (默认值:1e-8)。
  • weight_decay:权重衰减(默认值:0)。
  • amsgrad:是否使用AMSGrad(默认值:False)。

调用示例

opt = torch_npu.optim.NpuFusedAdam(model.parameters(), lr=0.1, weight_decay=0.1)