SelfAttentionParam

| Attribute | Type | Default | Description |
| --- | --- | --- | --- |
| quant_type | torch_atb.SelfAttentionParam.QuantType | torch_atb.SelfAttentionParam.QuantType.TYPE_QUANT_UNQUANT | No quantization is performed. |
| out_data_type | torch_atb.AclDataType | torch_atb.AclDataType.ACL_DT_UNDEFINED | The output tensor data type is inferred automatically from the input tensors. |
| head_num | int | 0 | The default value is not usable; the user must set this parameter explicitly. |
| kv_head_num | int | 0 | - |
| q_scale | float | 1.0 | - |
| qk_scale | float | 1.0 | - |
| batch_run_status_enable | bool | False | - |
| is_triu_mask | int | 0 | - |
| calc_type | torch_atb.SelfAttentionParam.CalcType | torch_atb.SelfAttentionParam.CalcType.UNDEFINED | Encoder/decoder calculation type for FlashAttention. |
| kernel_type | torch_atb.SelfAttentionParam.KernelType | torch_atb.SelfAttentionParam.KernelType.KERNELTYPE_DEFAULT | - |
| clamp_type | torch_atb.SelfAttentionParam.ClampType | torch_atb.SelfAttentionParam.ClampType.CLAMP_TYPE_UNDEFINED | No clamping. |
| clamp_min | float | 0.0 | - |
| clamp_max | float | 0.0 | - |
| mask_type | torch_atb.SelfAttentionParam.MaskType | torch_atb.SelfAttentionParam.MaskType.MASK_TYPE_UNDEFINED | All-zero mask. |
| kvcache_cfg | torch_atb.SelfAttentionParam.KvCacheCfg | torch_atb.SelfAttentionParam.KvCacheCfg.K_CACHE_V_CACHE | - |
| scale_type | torch_atb.SelfAttentionParam.ScaleType | torch_atb.SelfAttentionParam.ScaleType.SCALE_TYPE_TOR | - |
| input_layout | torch_atb.InputLayout | torch_atb.InputLayout.TYPE_BSND | - |
| mla_v_head_size | int | 0 | - |
| cache_type | torch_atb.SelfAttentionParam.CacheType | torch_atb.SelfAttentionParam.CacheType.CACHE_TYPE_NORM | - |
| window_size | int | 0 | - |

SelfAttentionParam.QuantType

Enum values:

  • TYPE_QUANT_UNQUANT
  • TYPE_DEQUANT_FUSION
  • TYPE_QUANT_QKV_OFFLINE
  • TYPE_QUANT_QKV_ONLINE
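A hedged sketch of selecting a quantization mode. TYPE_DEQUANT_FUSION is chosen purely as an example, and ACL_FLOAT16 is assumed to be exposed on torch_atb.AclDataType (mirroring the underlying aclDataType enum); neither choice is confirmed by this page.

param = torch_atb.SelfAttentionParam(head_num=24, kv_head_num=24)
param.quant_type = torch_atb.SelfAttentionParam.QuantType.TYPE_DEQUANT_FUSION  # example choice
param.out_data_type = torch_atb.AclDataType.ACL_FLOAT16  # assumed enum member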

SelfAttentionParam.CalcType

Enum values:

  • UNDEFINED
  • ENCODER
  • DECODER
  • PA_ENCODER
  • PREFIX_ENCODER
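The call example at the end of this page uses PA_ENCODER; switching phases is just a matter of reassigning calc_type. Mapping ENCODER to the prefill phase and DECODER to incremental decoding follows the usual FlashAttention convention and is an assumption here, not stated by this page.

param = torch_atb.SelfAttentionParam(head_num=24, kv_head_num=24)
param.calc_type = torch_atb.SelfAttentionParam.CalcType.PA_ENCODER  # prefill, as in the call example
# For incremental decoding one would presumably use CalcType.DECODER instead.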

SelfAttentionParam.KernelType

Enum values:

  • KERNELTYPE_DEFAULT
  • KERNELTYPE_HIGH_PRECISION

SelfAttentionParam.ClampType

Enum values:

  • CLAMP_TYPE_UNDEFINED
  • CLAMP_TYPE_MIN_MAX
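A sketch of enabling clamping: clamp_type selects the behavior, while clamp_min and clamp_max supply the bounds. The bound values are illustrative.

param = torch_atb.SelfAttentionParam(head_num=24, kv_head_num=24)
param.clamp_type = torch_atb.SelfAttentionParam.ClampType.CLAMP_TYPE_MIN_MAX
param.clamp_min = -30.0  # illustrative lower bound
param.clamp_max = 30.0   # illustrative upper bound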

SelfAttentionParam.MaskType

Enum values:

  • MASK_TYPE_UNDEFINED
  • MASK_TYPE_NORM
  • MASK_TYPE_ALIBI
  • MASK_TYPE_NORM_COMPRESS
  • MASK_TYPE_ALIBI_COMPRESS
  • MASK_TYPE_ALIBI_COMPRESS_SQRT
  • MASK_TYPE_ALIBI_COMPRESS_LEFT_ALIGN
  • MASK_TYPE_SLIDING_WINDOW_NORM
  • MASK_TYPE_SLIDING_WINDOW_COMPRESS
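A sketch of selecting a mask type. With the default MASK_TYPE_UNDEFINED, the operation behaves as if an all-zero mask were applied (see the attribute table); any other value presumably requires the corresponding mask tensor among the operation's inputs, which should be verified against the operation's input list.

param = torch_atb.SelfAttentionParam(head_num=24, kv_head_num=24)
param.mask_type = torch_atb.SelfAttentionParam.MaskType.MASK_TYPE_NORM  # example choice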

SelfAttentionParam.KvCacheCfg

Enum values:

  • K_CACHE_V_CACHE
  • K_BYPASS_V_BYPASS

SelfAttentionParam.ScaleType

Enum values:

  • SCALE_TYPE_TOR
  • SCALE_TYPE_LOGN
  • SCALE_TYPE_MAX

SelfAttentionParam.CacheType

Enum values:

  • CACHE_TYPE_NORM
  • CACHE_TYPE_SWA
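Judging by the names alone (an assumption, not confirmed by this page), CACHE_TYPE_SWA pairs with window_size and the sliding-window mask types to configure sliding-window attention:

param = torch_atb.SelfAttentionParam(head_num=24, kv_head_num=24)
param.cache_type = torch_atb.SelfAttentionParam.CacheType.CACHE_TYPE_SWA
param.window_size = 1024  # illustrative window length
param.mask_type = torch_atb.SelfAttentionParam.MaskType.MASK_TYPE_SLIDING_WINDOW_NORM  # assumed pairing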

Call example

import torch
import torch_atb
import torch_npu  # assumed needed for the .npu() tensor methods

def self_attention():
    # head_num must be set explicitly; its default of 0 is not usable.
    self_attention_param = torch_atb.SelfAttentionParam(head_num=24, kv_head_num=24)
    self_attention_param.calc_type = torch_atb.SelfAttentionParam.CalcType.PA_ENCODER
    self_attention_op = torch_atb.Operation(self_attention_param)
    # Q/K/V in (ntokens, head_num, head_dim) layout; seqlen holds per-batch token counts.
    q = torch.ones(4096, 24, 64, dtype=torch.float16).npu()
    k = torch.ones(4096, 24, 64, dtype=torch.float16).npu()
    v = torch.ones(4096, 24, 64, dtype=torch.float16).npu()
    seqlen = torch.tensor([4096], dtype=torch.int32)
    intensors = [q, k, v, seqlen]
    print("intensors: ", intensors)

    def self_attention_run():
        # forward returns the list of output tensors.
        outputs = self_attention_op.forward(intensors)
        return outputs

    outputs = self_attention_run()
    print("outputs: ", outputs)

if __name__ == "__main__":
    self_attention()
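Assuming the conventions above, forward returns a list containing the attention output; for these inputs its shape is expected to match q, i.e. (4096, 24, 64). This is stated as an expectation, not verified output.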