(beta)torch_npu.npu_bert_apply_adam
接口原型
torch_npu.npu_bert_apply_adam(lr, beta1, beta2, epsilon, grad, max_grad_norm, global_grad_norm, weight_decay, step_size=None, adam_mode=0, *, out=(var,m,v))
功能描述
adam结果计数。
参数说明
- 参数:
- var (Tensor) - float16或float32类型张量。
- m (Tensor) - 数据类型和shape与exp_avg相同。
- v (Tensor) - 数据类型和shape与exp_avg相同。
- lr (Scalar) - 数据类型与exp_avg相同。
- beta1 (Scalar) - 数据类型与exp_avg相同。
- beta2 (Scalar) - 数据类型与exp_avg相同。
- epsilon (Scalar) - 数据类型与exp_avg相同。
- grad (Tensor) - 数据类型和shape与exp_avg相同。
- max_grad_norm (Scalar) - 数据类型与exp_avg相同。
- global_grad_norm (Scalar) - 数据类型与exp_avg相同。
- weight_decay (Scalar) - 数据类型与exp_avg相同。
- step_size (Tensor,可选,默认值为None) - shape为(1, ),数据类型与exp_avg一致。
- adam_mode (Int,默认值为0) - 选择adam模式。0表示“adam”,1表示“mbert_adam”。
- 关键字参数:
- out (Tensor,可选) - 输出张量。
调用示例
>>> var_in = torch.rand(321538).uniform_(-32., 21.).npu()
>>> m_in = torch.zeros(321538).npu()
>>> v_in = torch.zeros(321538).npu()
>>> grad = torch.rand(321538).uniform_(-0.05, 0.03).npu()
>>> max_grad_norm = -1.
>>> beta1 = 0.9
>>> beta2 = 0.99
>>> weight_decay = 0.
>>> lr = 0.
>>> epsilon = 1e-06
>>> global_grad_norm = 0.
>>> var_out, m_out, v_out = torch_npu.npu_bert_apply_adam(lr, beta1, beta2, epsilon, grad, max_grad_norm, global_grad_norm, weight_decay, out=(var_in, m_in, v_in))
>>> var_out
tensor([ 14.7733, -30.1218, -1.3647, ..., -16.6840, 7.1518, 8.4872],
device='npu:0')
父主题: torch_npu