Saver
Function
Saves and loads sparse tables in multi-level cache mode. This covers the sparse table embeddings and the optimizer parameters that correspond to those embeddings.
Prototype
```python
import torch

class Saver:
    def __init__(self, rank: int = None):
        ...

    def save(self, module: torch.nn.Module, path: str) -> None:
        ...

    def load(self, module: torch.nn.Module, path: str) -> None:
        ...
```
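In the prototype, `rank` defaults to `None`. The rule given for the `rank` parameter (use `torch.distributed.get_rank()` when the distributed environment is initialized, otherwise require an explicit value) can be sketched as below. `resolve_rank` is a hypothetical helper, not part of `torchrec_embcache`. The distributed functions are passed in as callables so the sketch runs without torch. It also assumes that an explicitly passed rank takes precedence, which the documentation does not state.

```python
from typing import Callable, Optional


def resolve_rank(
    rank: Optional[int],
    is_initialized: Callable[[], bool],
    get_rank: Callable[[], int],
) -> int:
    """Hypothetical sketch of how a Saver rank could be resolved.

    Assumption: an explicit rank wins over the distributed environment.
    """
    if rank is not None:
        return rank
    if is_initialized():
        # Distributed environment is up: ask it for this process's rank.
        return get_rank()
    # No explicit rank and no initialized process group: rank is mandatory.
    raise ValueError("rank must be passed when torch.distributed is not initialized")
```

With PyTorch available, you would call it as `resolve_rank(None, torch.distributed.is_initialized, torch.distributed.get_rank)`.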
Restrictions
- The save/load APIs can save and load only sparse table data (sparse table embeddings and the optimizer parameters corresponding to them) in multi-level cache mode.
- Dense data cannot be saved or loaded with these APIs. Use the native PyTorch APIs (for example, torch.save and torch.load) instead.
- Saving or loading sparse tables in pure device memory mode is not supported.
- The save/load APIs cannot be called during training, concurrently, or asynchronously. Call them only when neither training nor evaluation is in progress.
- The save/load APIs support only the local file system.
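Two of the restrictions above, no concurrent calls and local file system only, can be partly enforced on the caller side. The class below is a hypothetical sketch and is not part of `torchrec_embcache`. It rejects overlapping calls with a non-blocking lock and treats any path containing `://` (for example `hdfs://` or `s3://`) as non-local. It cannot detect whether training or evaluation is running, so the caller is still responsible for that.

```python
import threading


class SaveLoadGuard:
    """Hypothetical caller-side guard for the save/load restrictions."""

    def __init__(self) -> None:
        # A plain (non-reentrant) lock: a second acquire in any thread fails.
        self._lock = threading.Lock()

    @staticmethod
    def _check_local(path: str) -> None:
        # Assumption: a "scheme://" prefix marks a remote, unsupported path.
        if "://" in path:
            raise ValueError(f"only local file system paths are supported: {path}")

    def run(self, fn, module, path: str):
        """Call fn(module, path), e.g. saver.save, unless another call is active."""
        self._check_local(path)
        # Non-blocking acquire: an overlapping call fails fast instead of waiting.
        if not self._lock.acquire(blocking=False):
            raise RuntimeError("save/load already in progress; concurrent calls are not supported")
        try:
            return fn(module, path)
        finally:
            self._lock.release()
```

Typical use is `guard.run(saver.save, model, "save_dir/sparse")`. Failing fast rather than blocking matches the restriction, which forbids overlapping calls instead of asking callers to queue them.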
Parameters
| Parameter | Data Type | Mandatory/Optional | Description |
|---|---|---|---|
| rank | int | Optional | Rank of the current process among all world_size processes. This parameter is optional if the torch distributed environment has been initialized, in which case the rank is obtained through torch.distributed.get_rank(). Otherwise, it is mandatory. |
| module | torch.nn.Module | Mandatory | Model instance. The model (or one of its submodels) must contain an instance of type EmbCacheShardedEmbeddingBagCollection or EmbCacheShardedEmbeddingCollection, and the model depth cannot exceed 500. These requirements are met when the model is created and sharded with the table creation and table sharding APIs supported by the multi-level cache. |
| path | str | Mandatory | Path to save to or load from. The length must be in the range [1, 1024]. Note: The path cannot contain symbolic links or sensitive information (such as keys, passwords, or private keys). Special paths (such as paths under /usr) cannot be used. The path permissions cannot be higher than 750. |
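To catch a bad `path` before calling `save` or `load`, you can pre-check it against the documented rules. The function below is a hypothetical sketch and does not reproduce the library's own validation. It assumes `/usr` is the only forbidden prefix (the documentation gives it only as an example) and reads "not higher than 750" as "no permission bits outside `rwxr-x---`". It checks permissions only on the target itself, and it does not attempt to detect sensitive information in the path.

```python
import os
import stat

MAX_PATH_LEN = 1024
# Assumption: /usr is the only special prefix named in the docs; the real list may differ.
FORBIDDEN_PREFIXES = ("/usr",)
# Assumption: "not higher than 750" means no permission bits outside rwxr-x---.
MAX_MODE = 0o750


def check_sparse_path(path: str) -> None:
    """Hypothetical pre-check: raise ValueError if path breaks a documented rule."""
    if not 1 <= len(path) <= MAX_PATH_LEN:
        raise ValueError(f"path length must be in [1, {MAX_PATH_LEN}], got {len(path)}")

    # abspath normalizes "..", so "/tmp/../usr/x" is caught as "/usr/x".
    abs_path = os.path.abspath(path)
    for prefix in FORBIDDEN_PREFIXES:
        if abs_path == prefix or abs_path.startswith(prefix + os.sep):
            raise ValueError(f"special path not allowed: {abs_path}")

    # Walk from the path up to the root, rejecting any component that is a symlink.
    current = abs_path
    while True:
        if os.path.islink(current):
            raise ValueError(f"path contains a symbolic link: {current}")
        parent = os.path.dirname(current)
        if parent == current:
            break
        current = parent

    # Only an existing path has permissions to check; a save target may not exist yet.
    if os.path.exists(abs_path):
        mode = stat.S_IMODE(os.stat(abs_path).st_mode)
        if mode & ~MAX_MODE:
            raise ValueError(f"permissions {oct(mode)} exceed {oct(MAX_MODE)}")
```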
Return Value
- Success: The API is successfully called, and the sparse table data is saved or loaded.
- Failure: An exception is thrown.
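Failure is reported only by an exception, not by a return value. A caller that wants a success flag can wrap the call, as in the sketch below. `try_save` and `SupportsSave` are hypothetical names. Because the documentation does not name the exception type, the sketch catches `Exception` broadly and logs the traceback.

```python
import logging
from typing import Any, Protocol

logger = logging.getLogger(__name__)


class SupportsSave(Protocol):
    """Anything with a save(module, path) method, such as a Saver instance."""

    def save(self, module: Any, path: str) -> None: ...


def try_save(saver: SupportsSave, module: Any, path: str) -> bool:
    """Hypothetical helper: return True on success, False if save() raised."""
    try:
        saver.save(module, path)
    except Exception:  # Exception type is undocumented, so catch broadly.
        logger.exception("sparse table save to %s failed", path)
        return False
    return True
```

The same pattern applies to `load`.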
Sample
```python
from torchrec_embcache.saver import Saver

...
saver = Saver(rank=rank)
saver.save(model, "save_dir/sparse")  # Saving.
saver.load(model, "save_dir/sparse")  # Loading.
```