KeyedOptimizerWrapper (TorchRec)
This is an open-source TorchRec API, not an external API of the Rec SDK Torch. This section describes the parameter ranges supported for this TorchRec API when it is called through the Rec SDK Torch.
Function
Wraps the optimizer for the parameters that remain after table parameters are filtered out.
Prototype
    class KeyedOptimizerWrapper:
        def __init__(self, **kwargs)
Parameters
| Parameter | Data Type | Mandatory/Optional | Description |
|---|---|---|---|
| params | Mapping[str, Union[torch.Tensor, ShardedTensor]] | Mandatory | Parameters to optimize, keyed by name. Ensure that the value is obtained through in_backward_optimizer_filter (TorchRec). |
| optim_factory | Callable | Mandatory | Function for creating an optimizer. It must take exactly one input parameter and return one torch.optim.Optimizer object. |
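The optim_factory requirement (one input, one torch.optim.Optimizer output) can be checked with plain PyTorch before the factory is handed to the wrapper. The following is a minimal sketch, not the TorchRec implementation: `make_adam` and the small `nn.Linear` model are hypothetical names used only for illustration, and it assumes the factory receives the parameter values as a list.

```python
import torch
from torch import nn


def make_adam(params):
    """Takes the parameters to optimize and returns a torch.optim.Optimizer."""
    return torch.optim.Adam(params, lr=1e-1)


# Hypothetical dense model used only for this illustration.
model = nn.Linear(4, 2)

# Name-keyed parameters, matching the Mapping[str, torch.Tensor] shape of `params`.
named = dict(model.named_parameters())

# Assumption: the factory is called with the parameter values as a list.
optimizer = make_adam(list(named.values()))

print(isinstance(optimizer, torch.optim.Optimizer))
print(sorted(named))
```

A named function and a lambda are interchangeable here, provided the callable accepts a single argument and returns the optimizer object.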
Return Value
- Success: The optimizer after parameter filtering is returned.
- Failure: An exception is thrown.
Example
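    import torch
    from torchrec.optim.keyed import KeyedOptimizerWrapper
    from torchrec.optim.optimizers import in_backward_optimizer_filter

    # `model` is the model being trained, defined elsewhere.
    dense_optimizer = KeyedOptimizerWrapper(
        dict(in_backward_optimizer_filter(model.named_parameters())),
        lambda params: torch.optim.Adam(params, lr=1e-1),
    )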
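To illustrate what parameter filtering means for the resulting optimizer, the sketch below uses plain PyTorch only, so it runs without TorchRec. Everything in it is an assumption made for illustration: `keep_dense` is a hypothetical stand-in for in_backward_optimizer_filter with a simple name-based rule, `TinyModel` is a toy model, and a plain Adam optimizer stands in for the wrapper.

```python
import torch
from torch import nn


def keep_dense(named_parameters):
    """Hypothetical stand-in filter: drop parameters whose name starts with 'table'."""
    for name, param in named_parameters:
        if not name.startswith("table"):
            yield name, param


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.table = nn.Embedding(10, 4)  # stands in for an embedding table
        self.dense = nn.Linear(4, 1)

    def forward(self, ids):
        return self.dense(self.table(ids)).sum()


model = TinyModel()

# Only the parameters that pass the filter are given to the optimizer.
dense_params = dict(keep_dense(model.named_parameters()))
optimizer = torch.optim.Adam(list(dense_params.values()), lr=1e-1)

table_before = model.table.weight.detach().clone()
bias_before = model.dense.bias.detach().clone()

# One training step on three hypothetical ids.
optimizer.zero_grad()
loss = model(torch.tensor([1, 2, 3]))
loss.backward()
optimizer.step()

# The filtered-out table is untouched; the dense bias has been updated.
table_unchanged = torch.equal(model.table.weight, table_before)
dense_changed = not torch.equal(model.dense.bias, bias_before)
print(sorted(dense_params), table_unchanged, dense_changed)
```

In this sketch the optimizer never sees the table weight, so stepping it leaves that weight unchanged and updates only the parameters that passed the filter.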
See Also
For details about the API call sequence and an example, see Porting and Training.