昇腾社区首页
中文
注册

KeyedOptimizerWrapper(TorchRec

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。

功能描述

用于封装过滤掉表的参数的优化器。

函数原型

1
2
class KeyedOptimizerWrapper:
    def __init__(**kwargs)

参数说明

参数名

类型

可选/必选

说明

params

Mapping[str, Union[torch.Tensor, ShardedTensor]]

必选

参数列表。

用户需自行保证通过in_backward_optimizer_filter(TorchRec)获得。

optim_factory

Callable

必选

传入创建优化器的函数。用户需自行保证该函数满足输入为一个参数,输出一个torch.optim.Optimizer对象。

返回值说明

  • 成功:参数过滤后的优化器。
  • 失败:抛出异常。

使用示例

1
2
3
4
5
from torchrec.optim.keyed import KeyedOptimizerWrapper
dense_optimizer = KeyedOptimizerWrapper(
    dict(in_backward_optimizer_filter(model.named_parameters())),
    lambda params: torch.optim.Adam(params, lr=1e-1),
)

参考资源

接口调用流程及示例可参见迁移与训练