昇腾社区首页
中文
注册

torch_npu.optim.NpuFusedLamb

API接口

torch_npu.optim.NpuFusedLamb(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0, adam=False, use_global_grad_norm=False)

功能描述

通过张量融合实现的 FusedLamb 算法。

参数说明

  • params:模型参数或模型参数组
  • lr:学习率。(默认值:1e-3)
  • betas: 用于计算梯度及其平方的运行平均值的系数。 (默认值:(0.9,0.999))
  • eps:分母防除0项,提高数值稳定性(默认值:1e-8)
  • weight_decay:权重衰减(默认值:0)
  • adam:将strust_ratio设置为1,退化为Adam(默认值:False)
  • use_global_grad_norm:使用全局梯度正则(默认值:False)

示例

opt = torch_npu.optim.NpuFusedLamb(model.parameters(), lr=0.001, weight_decay=0.01)