提供Generator初始化,包括Model和Sampler初始化,以及权重加载、KV Cache分配等功能。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | class GeneratorTorch(GeneratorBackend): """The interface class for using `torch` backend. The interface class exposed to third-party service frameworks in scenarios where the `torch` backend is used. It mainly provides forward inference and sampling functions. Its sampling function is implemented by the `sample` method of the base class `GeneratorBackend`. Attributes: cache_pool: A pool used for storing the kv cache. Args: model_config: A dictionary containing the model configuration as detailed in `mindie_llm.text_generator.utils.config.ModelConfig`. """ cache_pool: CachePool = None def __init__(self, model_config: Dict[str, Any]) -> None: check_model_config(model_config) super().__init__(model_config) self.tokenizer = self.model_wrapper.tokenizer self.device = self.model_wrapper.device self.rank = self.model_wrapper.rank self.adapter_manager = self.model_wrapper.adapter_manager |
model_config为模型配置,支持字典格式输入,主要配置信息参考如下:
参数名称 |
是否必选 |
配置名称 |
类型 |
默认值 |
描述 |
---|---|---|---|---|---|
model_config |
必选 |
backend_type |
string |
None |
模型后端类型,支持'atb'和'ms'。 |
world_size |
int |
None |
TP并行数。 |
||
rank |
int |
None |
当前进程在TP并行进程中的序号,从0开始计数,此值应小于world_size。 |
||
npu_device_id |
int |
None |
当前进程所使用的可见npu设备序号,此值应小于环境变量ASCEND_RT_VISIBLE_DEVICES设置的可见设备数量。(未设置ASCEND_RT_VISIBLE_DEVICES时,默认所有设备可见。) |
||
num_threads |
int |
8 |
后处理线程数。 |
||
local_rank |
int |
None |
当前进程在TP并行进程中的当前机器上的序号,从0开始计数,此值应小于当前机器的卡数。 |
||
trust_remote_code |
bool |
False |
是否信任模型权重路径下的自定义代码文件。默认不执行。若此参数设置为True,则transformers会执行用户权重路径下的自定义代码文件,这些代码文件的功能的安全性需由用户保证,请提前做好安全性检查。 |
||
可选 |
model_name |
string |
None |
模型名称。 由大写字母、小写字母、数字、中划线、点和下划线组成,且不以中划线、点和下划线作为开头和结尾,字符串长度小于或等于256。 |
|
inference_mode |
string |
None |
根据使能的插件情况获取的inference_mode,作为组图的输入信息。 |
||
max_position_embeddings |
int |
None |
模型可接受的最大上下文长度。 |
||
load_tokenizer |
bool |
True |
是否加载tokenizer,设为False则不加载默认的tokenizer。 |
||
tokenizer_path |
string |
None |
自定义加载tokenizer的路径,为None时默认使用模型权重路径。 |
调用API出现异常情况时,会直接抛出异常信息。例如: