昇腾社区首页
中文
注册

Attention量化

此量化方式将q,k,v量化为8bit,通过减少KV Cache的显存占用,优化Decode阶段Attention算子的速度,提升吞吐。

  • Atlas 800I A2 推理服务器支持Attention量化。
  • 需要配合W8A8量化方式一起使用。
  • 仅支持LLaMA3.1-70B、Qwen2.5-72B模型。
  • 仅支持vLLM v0 engine。

和W8A8量化权重相比,使用了Attention量化的权重将在config.json的quantization_config字段中新增“fa_quant_type”描述字段、self_attn字段及下面包含的权重信息,其部分内容如下所示:

"quantization_config": {
    ...
    "fa_quant_type": "FAQuant",
    "model.layers.0.self_attn.fa_q.scale": "FAQuant",
    "model.layers.0.self_attn.fa_k.scale": "FAQuant",
    "model.layers.0.self_attn.fa_v.scale": "FAQuant",
    "model.layers.0.self_attn.fa_q.offset": "FAQuant",
    "model.layers.0.self_attn.fa_k.offset": "FAQuant",
    "model.layers.0.self_attn.fa_v.offset": "FAQuant",
    ...
}
图1 量化权重推理时流程

表1 float16权重量化后dtype及shape信息(假设原始权重的shape为[n, k])

Tensor信息

dtype

shape

q_scale

float16

[q_head_num, head_dim]

q_offset

float16

[q_head_num, head_dim]

k_scale

float16

[kv_head_num, head_dim]

k_offset

float16

[kv_head_num, head_dim]

v_scale

float16

[kv_head_num, head_dim]

v_offset

float16

[kv_head_num, head_dim]

表2 bfloat16权重量化后dtype及shape信息(假设原始权重的shape为[n, k])

Tensor信息

dtype

shape

q_scale

bfloat16

[q_head_num, head_dim]

q_offset

bfloat16

[q_head_num, head_dim]

k_scale

bfloat16

[kv_head_num, head_dim]

k_offset

bfloat16

[kv_head_num, head_dim]

v_scale

bfloat16

[kv_head_num, head_dim]

v_offset

bfloat16

[kv_head_num, head_dim]

生成权重

您需要修改modeling文件及权重路径下的config.json文件,具体修改方法参见FA量化使用说明

修改后的config文件为:
{
    "architectures": ["LlamaForCausalLM"],     
    // 新增配置     
    // --------------------------------------------------     
    "auto_map": {
    "AutoModelForCausalLM": "modeling_llama_fa3.LlamaForCausalLM"   
},     
    // --------------------------------------------------
    ...
    // 其他未修改的代码部分
    ... 
}

您可以参考LLaMA3.1-70B W8A8+Attention量化方法,使用以下指令生成W8A8 +Attention量化权重。

cd {msmodelslim安装路径}/example/Llama/
python3 quant_llama.py --model_path {浮点权重路径} --save_directory {W8A8量化权重路径} --calib_file ../common/boolq.jsonl --w_bit 8 --a_bit 8 --device_type npu --disable_level L5 --anti_method m4 --act_method 3 --use_fa_quant True