Saver
功能描述
多级缓存稀疏表保存加载功能类,提供多级缓存稀疏表数据(稀疏表Embedding,Embedding对应的优化器参数等)的保存、加载接口。
函数原型
1 2 3 4 5 6 | class Saver: def __init__(self, rank: int = None): ... def save(self, module: torch.nn.Module, path: str) -> None: ... def load(self, module: torch.nn.Module, path: str) -> None: |
使用约束
- 保存/加载接口仅支持多级缓存保存/加载稀疏表相关数据(稀疏表Embedding,Embedding对应的优化器参数等)。
- 不支持保存/加载Dense数据(需自行调用Torch原生接口)。
- 不支持纯显存模式下稀疏表保存/加载。
- 保存/加载接口不支持训练过程中调用/并发调用/异步调用,仅支持未执行训练/评估时调用接口。
- 保存/加载接口仅支持保存/加载本地文件系统。
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
|---|---|---|---|
rank |
int |
可选 |
当前进程在整个world_size中的rank。 当torch分布式环境已初始化时,该参数为可选,此时将使用torch.distributed.get_rank()获取rank;否则该参数为必选。 |
module |
torch.nn.Module |
必选 |
模型对象实例,模型(或子模型)需包含类型为EmbCacheShardedEmbeddingBagCollection/ EmbCacheShardedEmbeddingCollection的模型实例。 使用多级缓存支持的创表接口/分表接口进行模型创建和模型分片时即满足要求。 |
path |
string |
必选 |
保存/加载路径,长度取值范围:[1, 1024]。 |
返回值说明
- 成功:接口调用无报错,保存落盘/加载稀疏表数据。
- 失败:抛出异常。
使用示例
1 2 3 4 5 6 | from torchrec_embcache.saver import Saver ... saver = Saver(rank=rank) saver.save(model, "save_dir/sparse") # 保存 saver.load(model, "save_dir/sparse") # 加载 |
父主题: 多级缓存管理接口