昇腾社区首页
中文
注册
开发者
下载

Saver

功能描述

多级缓存稀疏表保存加载功能类,提供多级缓存稀疏表数据(稀疏表Embedding,Embedding对应的优化器参数等)的保存、加载接口。

函数原型

1
2
3
4
5
6
class Saver:
    def __init__(self, rank: int = None):
    ...
    def save(self, module: torch.nn.Module, path: str) -> None:
    ...
    def load(self, module: torch.nn.Module, path: str) -> None:

使用约束

  1. 保存/加载接口仅支持多级缓存保存/加载稀疏表相关数据(稀疏表Embedding,Embedding对应的优化器参数等)。
  2. 不支持保存/加载Dense数据(需自行调用Torch原生接口)。
  3. 不支持纯显存模式下稀疏表保存/加载。
  4. 保存/加载接口不支持训练过程中调用/并发调用/异步调用,仅支持未执行训练/评估时调用接口。
  5. 保存/加载接口仅支持保存/加载本地文件系统。

参数说明

参数名

类型

可选/必选

说明

rank

int

可选

当前进程在整个world_size中的rank。

当torch分布式环境已初始化时,该参数为可选,此时将使用torch.distributed.get_rank()获取rank;否则该参数为必选。

module

torch.nn.Module

必选

模型对象实例,模型(或子模型)需包含类型为EmbCacheShardedEmbeddingBagCollection/

EmbCacheShardedEmbeddingCollection的模型实例。

使用多级缓存支持的创表接口/分表接口进行模型创建和模型分片时即满足要求。

path

string

必选

保存/加载路径,长度取值范围[1, 1024]。

返回值说明

  • 成功:接口调用无报错,保存落盘/加载稀疏表数据。
  • 失败:抛出异常。

使用示例

1
2
3
4
5
6
from torchrec_embcache.saver import Saver
...
saver = Saver(rank=rank)
saver.save(model, "save_dir/sparse")  # 保存
saver.load(model, "save_dir/sparse")  # 加载