昇腾社区首页
中文
注册

__init__接口

接口功能

提供generator初始化,包括model和sampler初始化,以及权重加载、kvcache分配等功能。

接口实现

class GeneratorTorch(GeneratorBackend):
    cache_pool: CachePool = None
    def __init__(self, model_config):
        super().__init__(model_config)
        self.tokenizer = self.model_wrapper.tokenizer
        self.device = self.model_wrapper.device
        self.rank = self.model_wrapper.rank
class GeneratorBackend:
    def __init__(self, model_config):
        backend_type = model_config.get('backend_type', None)
        num_threads = model_config.get('num_threads', 8)
        self.rank = model_config.get('rank', None)
        self.world_size = model_config.get('world_size', None)
        sampler_config = SamplerConfig(
            backend_type=backend_type,
            npu_id=model_config.get('npu_device_id', None),
            num_threads=num_threads,
            rank=self.rank
        )
        self.model_wrapper = get_model_wrapper(model_config, backend_type)
        self.sampler = Sampler(sampler_config)
        self.model_info = self.model_wrapper.model_info
        self.max_position_embeddings = self.model_wrapper.max_position_embeddings

参数说明

model_config为模型配置,支持字典格式输入,主要配置信息参考如下:

参数名称

是否必选

配置名称

类型

默认值

描述

model_config

必选

backend_type

string

None

模型后端类型,支持'atb'和'ms'

world_size

int

None

TP并行数

rank

int

None

当前进程在TP并行进程中的序号,从0开始计数,此值应小于world_size

npu_device_id

int

None

当前进程所使用的可见npu设备序号,此值应小于环境变量ASCEND_RT_VISIBLE_DEVICES设置的可见设备数量(未设置ASCEND_RT_VISIBLE_DEVICES时,默认所有设备可见)

num_threads

int

8

后处理线程数

安全说明

调用API 出现异常情况时,会直接抛出异常信息。例如:

  • 对于加载的权重文件较大的情况,会直接抛NPU Out Of Memory(OOM)异常, 异常信息为:RuntimeError:NPU out of memory. Tried to allocate XXX MiB。
  • 模型权重文件中未配置torch_dtype字段,会直接抛出异常,异常信息为:unsupported type: XXX, 此类型从权重文件config.json中的`torch_dtype`字段中获取;若config.json中不存在此字段,请新增;当前此字段仅接受`float16`和`bfloat16`两种类型,各模型具体支持的类型不同,请参考模型README文件。
  • 模型权重文件中配置的torch_dtype字段不支持时,会直接抛出异常,异常信息为:unsupported type: XXX, 当前仅支持`float16`类型,请修改权重文件config.json中的`torch_dtype`字段。