RopeQConcatOperation(代码开放)
产品支持情况
| 
          硬件型号  | 
        
          是否支持  | 
       
|---|---|
| 
          | 
        
          x  | 
       
| 
          | 
        
          √  | 
       
| 
          | 
        
          x  | 
       
| 
          | 
        
          x  | 
       
| 
          | 
        
          x  | 
       
功能
旋转位置编码后进行concat操作。
计算公式
rope计算公式:

计算图
kernel计算流程如下图所示:

定义
          1 2 3  | 
         
          struct RopeQConcatParam { uint8_t rsv[16] = {0}; };  | 
        
参数列表
| 
          成员名称  | 
        
          类型  | 
        
          默认值  | 
        
          取值范围  | 
        
          是否必选  | 
        
          描述  | 
       
|---|---|---|---|---|---|
| 
          rsv  | 
        
          uint8_t[]  | 
        
          {0}  | 
        
          [0]  | 
        
          否  | 
        
          预留字段。  | 
       
输入
| 
          参数  | 
        
          维度  | 
        
          数据类型  | 
        
          格式  | 
        
          描述  | 
       
|---|---|---|---|---|
| 
          q  | 
        
          [ntokens, hidden_size_q]  | 
        
          float16/bf16  | 
        
          ND  | 
        
          2维。  | 
       
| 
          cos  | 
        
          [ntokens, head_dim]  | 
        
          与q一致  | 
        
          ND  | 
        
          2维。  | 
       
| 
          sin  | 
        
          [ntokens, head_dim]  | 
        
          与q一致  | 
        
          ND  | 
        
          2维。  | 
       
| 
          ConcatInput  | 
        
          [ntokens, head_num, concat_size]  | 
        
          与q一致  | 
        
          ND  | 
        
          3维。  | 
       
输出
| 
          参数  | 
        
          维度  | 
        
          数据类型  | 
        
          格式  | 
        
          描述  | 
       
|---|---|---|---|---|
| 
          ropeQConcat  | 
        
          [ntokens, head_num, head_dim+concat_size]  | 
        
          与q一致  | 
        
          ND  | 
        
          3维。  | 
       
规格约束
- hidden_size_q = head_dim * head_num。
 - head_dim*sizeof(dtype) 需要32Byte对齐,即head_dim需要是16的整数倍且需小于等于64。
 - concat_size*sizeof(dtype) 需要32Byte对齐,即concat_size需要是16的整数倍。
 - 保证headDim * 26 + concatSize * 2<maxUbSize(196352) ,对head_num无限制。