KeyedOptimizerWrapper(TorchRec)

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。
功能描述
用于封装过滤掉表的参数的优化器。
函数原型
1 2 | class KeyedOptimizerWrapper: def __init__(**kwargs) |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
params |
Mapping[str, Union[torch.Tensor, ShardedTensor]] |
必选 |
参数列表。 用户需自行保证通过in_backward_optimizer_filter(TorchRec)获得。 |
optim_factory |
Callable |
必选 |
传入创建优化器的函数。用户需自行保证该函数满足输入为一个参数,输出一个torch.optim.Optimizer对象。 |
返回值说明
- 成功:参数过滤后的优化器。
- 失败:抛出异常。
使用示例
1 2 3 4 5 | from torchrec.optim.keyed import KeyedOptimizerWrapper dense_optimizer = KeyedOptimizerWrapper( dict(in_backward_optimizer_filter(model.named_parameters())), lambda params: torch.optim.Adam(params, lr=1e-1), ) |
参考资源
接口调用流程及示例可参见迁移与训练。
父主题: 优化器接口