定义 (Definition)

/**
 * Parameter block for a self-attention operator.
 *
 * NOTE(review): identifiers match Ascend ATB's SelfAttentionParam; the
 * semantic notes below are inferred from names and should be confirmed
 * against that library's documentation.
 *
 * Requires declarations from an enclosing header that are not visible here:
 * aclDataType / ACL_DT_UNDEFINED and InputLayout / TYPE_BSND.
 */
struct SelfAttentionParam {
    /// Calculation mode. Values: UNDEFINED=0, ENCODER=1, DECODER=2,
    /// PA_ENCODER=3, PREFIX_ENCODER=4.
    enum CalcType : int {
        UNDEFINED = 0,
        ENCODER,
        DECODER,
        PA_ENCODER,
        PREFIX_ENCODER,
    };
    /// Kernel variant. Values: KERNELTYPE_DEFAULT=0, KERNELTYPE_HIGH_PRECISION=1.
    enum KernelType : int {
        KERNELTYPE_DEFAULT = 0,
        KERNELTYPE_HIGH_PRECISION
    };
    /// Clamping mode. Values: CLAMP_TYPE_UNDEFINED=0, CLAMP_TYPE_MIN_MAX=1.
    enum ClampType : int {
        CLAMP_TYPE_UNDEFINED = 0,
        CLAMP_TYPE_MIN_MAX
    };
    /// Attention-mask kind. Values are consecutive from
    /// MASK_TYPE_UNDEFINED=0 through MASK_TYPE_SLIDING_WINDOW_COMPRESS=8.
    enum MaskType : int {
        MASK_TYPE_UNDEFINED = 0,
        MASK_TYPE_NORM,
        MASK_TYPE_ALIBI,
        MASK_TYPE_NORM_COMPRESS,
        MASK_TYPE_ALIBI_COMPRESS,
        MASK_TYPE_ALIBI_COMPRESS_SQRT,
        MASK_TYPE_ALIBI_COMPRESS_LEFT_ALIGN,
        MASK_TYPE_SLIDING_WINDOW_NORM,
        MASK_TYPE_SLIDING_WINDOW_COMPRESS
    };
    /// KV-cache configuration. Values: K_CACHE_V_CACHE=0, K_BYPASS_V_BYPASS=1.
    enum KvCacheCfg : int {
        K_CACHE_V_CACHE = 0,
        K_BYPASS_V_BYPASS,
    };
    /// Scale mode. Values: SCALE_TYPE_TOR=0, SCALE_TYPE_LOGN=1, SCALE_TYPE_MAX=2
    /// (SCALE_TYPE_MAX presumably a sentinel/count — confirm).
    enum ScaleType : int {
        SCALE_TYPE_TOR = 0,
        SCALE_TYPE_LOGN,
        SCALE_TYPE_MAX
    };
    /// Quantization mode. TYPE_QUANT_UNDEFINED and TYPE_QUANT_UNQUANT are both 0
    /// (aliases), so the remaining enumerators count on from UNQUANT:
    /// TYPE_DEQUANT_FUSION=1, TYPE_QUANT_QKV_OFFLINE=2, TYPE_QUANT_QKV_ONLINE=3.
    enum QuantType : int {
        TYPE_QUANT_UNDEFINED = 0,
        TYPE_QUANT_UNQUANT = 0,
        TYPE_DEQUANT_FUSION,
        TYPE_QUANT_QKV_OFFLINE,
        TYPE_QUANT_QKV_ONLINE
    };
    /// Cache kind, stored in a single signed byte. Values: CACHE_TYPE_NORM=0,
    /// CACHE_TYPE_SWA=1 (SWA presumably = sliding-window attention — confirm).
    enum CacheType : int8_t {
        CACHE_TYPE_NORM = 0,
        CACHE_TYPE_SWA = 1
    };

    QuantType quantType = TYPE_QUANT_UNQUANT;   // default 0 (same value as TYPE_QUANT_UNDEFINED)
    aclDataType outDataType = ACL_DT_UNDEFINED; // presumably the output dtype — confirm
    int32_t headNum = 0;                         // presumably number of query heads — confirm
    int32_t kvHeadNum = 0;                       // presumably number of key/value heads — confirm
    float qScale = 1;                            // scale factor, default 1
    float qkScale = 1;                           // scale factor, default 1
    bool batchRunStatusEnable = false;
    uint32_t isTriuMask = 0;                     // flag stored as an integer; default 0
    CalcType calcType = UNDEFINED;
    KernelType kernelType = KERNELTYPE_DEFAULT;
    ClampType clampType = CLAMP_TYPE_UNDEFINED;
    float clampMin = 0;                          // presumably used only with CLAMP_TYPE_MIN_MAX — confirm
    float clampMax = 0;                          // presumably used only with CLAMP_TYPE_MIN_MAX — confirm
    MaskType maskType = MASK_TYPE_UNDEFINED;
    KvCacheCfg kvcacheCfg = K_CACHE_V_CACHE;
    ScaleType scaleType = SCALE_TYPE_TOR;
    InputLayout inputLayout = TYPE_BSND;         // InputLayout is declared outside this struct
    uint32_t mlaVHeadSize = 0;
    CacheType cacheType = CACHE_TYPE_NORM;
    uint32_t windowSize = 0;                     // presumably paired with sliding-window mask/cache types — confirm
    uint8_t rsv[64] = {0};                       // reserved bytes, zero-initialized; presumably kept for layout stability — confirm
};