apply_optimizer_in_backward(TorchRec)

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。
功能描述
指定表使用的优化器。
函数原型
1 2 3 4 5 | def apply_optimizer_in_backward( optimizer_class: Type[torch.optim.Optimizer], params: Iterable[torch.nn.Parameter], optimizer_kwargs: Dict[str, Any], ) -> None: |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
optimizer_class |
Type[torch.optim.Optimizer] |
必选 |
优化器类型。 取值范围:
|
params |
Iterable[torch.nn.Parameter] |
必选 |
设置优化器的torch.nn.Parameter对象。参考5传入HashEmbeddingBagCollection对象的参数。 须知:
基于性能考虑,params参数无法校验。用户需自行保证其类型正确性。 |
optimizer_kwargs |
Dict[str, Any] |
必选 |
根据optimizer_class的参数和范围进行配置,用户需自行保证该参数的范围符合对应优化器限制。 |
返回值说明
- 成功:None
- 失败:抛出异常。
使用示例
1 2 3 4 5 6 7 8 | from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward import torch embedding_optimizer = torch.optim.Adagrad optimizer_kwargs = {"lr": 0.001, "eps": 0.1} apply_optimizer_in_backward( embedding_optimizer, ebc.parameters(), optimizer_kwargs=optimizer_kwargs) |
参考资源
接口调用流程及示例可参见迁移与训练。
父主题: 优化器接口