昇腾社区首页
中文
注册

apply_optimizer_in_backward(TorchRec

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。

功能描述

指定表使用的优化器。

函数原型

1
2
3
4
5
def apply_optimizer_in_backward(
    optimizer_class: Type[torch.optim.Optimizer],
    params: Iterable[torch.nn.Parameter],
    optimizer_kwargs: Dict[str, Any],
) -> None:

参数说明

参数名

类型

可选/必选

说明

optimizer_class

Type[torch.optim.Optimizer]

必选

优化器类型。

取值范围:

  • torch.optim.Adagrad:Adagrad优化器。
  • torch.optim.Adam:Adam优化器。
  • torch.optim.SGD:SGD优化器。

params

Iterable[torch.nn.Parameter]

必选

设置优化器的torch.nn.Parameter对象。参考5传入HashEmbeddingBagCollection对象的参数。

须知:

基于性能考虑,params参数无法校验。用户需自行保证其类型正确性。

optimizer_kwargs

Dict[str, Any]

必选

根据optimizer_class的参数和范围进行配置,用户需自行保证该参数的范围符合对应优化器限制。

返回值说明

  • 成功:None
  • 失败:抛出异常。

使用示例

1
2
3
4
5
6
7
8
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
import torch
embedding_optimizer = torch.optim.Adagrad
optimizer_kwargs = {"lr": 0.001, "eps": 0.1}
apply_optimizer_in_backward(
	embedding_optimizer,
	ebc.parameters(),
	optimizer_kwargs=optimizer_kwargs)

参考资源

接口调用流程及示例可参见迁移与训练