torch_npu.optim.NpuFusedSGD
API接口
torch_npu.optim.NpuFusedSGD(params, lr=required, momentum=MOMENTUM_MIN, dampening=DAMPENING_DEFAULT, weight_decay=WEIGHT_DECAY_MIN, nesterov=False)
功能描述
通过张量融合实现的随机梯度下降算法。
参数说明
- params:模型参数或模型参数组 。
- lr:学习率(默认值:1e-3)。
- betas:用于计算梯度及其平方的运行平均值的系数(默认值:(0.9,0.999))。
- eps:防止除0,提高数值稳定性 (默认值:1e-8)。
- weight_decay:权重衰减(默认值:0)。
- amsgrad:是否使用AMSGrad(默认值:False)。
示例
opt = torch_npu.optim.NpuFusedSGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.1)
父主题: torch_npu.optim