RmsNormWithStrideOperation(代码开放)
产品支持情况
硬件型号  | 
是否支持  | 
|---|---|
√  | 
|
√  | 
|
x  | 
|
x  | 
|
x  | 
功能
为了使RmsNorm支持非连续Tensor,将x tensor通过(shape、strides、offset)表示。strides描述某个维度上相邻两个元素的间隔,offset表示这个Tensor的首元素相对addr的偏移。
示例:
例如现有一个shape=(6, 5)、strides=(10, 1)、offset=22的Tensor,其内存排布如下:

即该Tensor排布如上图的深色位置。这个完整的Tensor在内存排布上是不连续的,strides描述Tensor维度上相邻两个元素的间隔,如果在维度1上的stride为1, 该维度是连续的;如果在维度0上的stride为10,那么相邻的元素间隔10个元素,即非连续。offset表示这个Tensor的首元素相对addr的偏移。
计算公式

其中gamma为可训练参数,epsilon是一个取值范围为(0, 1]的浮点值,用于避免除零。
定义
struct RmsNormWithStrideParam {
    enum RmsNormType : int {
        RMS_NORM_UNDEFINED = 0, 
        RMS_NORM_NORM,  
        RMS_NORM_PRENORM,   
        RMS_NORM_POSTNORM,      
    };
    enum PrecisionMode : int {
        HIGH_PRECISION_MODE = 0, 
        HIGH_PERFORMANCE_MODE,  
    };
    enum ModelType : int {
        LLAMA_MODEL = 0, 
        GEMMA_MODEL,    
    };
    struct NormParam {
        QuantType quantType = QUANT_UNQUANT;
        float epsilon = 1e-5;
        double layerNormEps = 1e-5;
        bool rstd = false;
        PrecisionMode precisionMode = HIGH_PRECISION_MODE;
        ModelType modelType = LLAMA_MODEL;
        DynamicQuantType dynamicQuantType = DYNAMIC_QUANT_UNDEFINED;
        uint8_t rsv[32] = {0};
    };
    struct PreNormParam {
        QuantType quantType = QUANT_UNQUANT;
        float epsilon = 1e-5;
        bool hasBias = false;
        uint8_t rsv[23] = {0};
    };
    struct PostNormParam {
        QuantType quantType = QUANT_UNQUANT;
        float epsilon = 1e-5;
        bool hasBias = false;
        uint8_t rsv[23] = {0};
    };
    RmsNormType layerType = RMS_NORM_UNDEFINED;
    NormParam normParam;
    PreNormParam preNormParam;
    PostNormParam postNormParam;
    uint8_t rsv[8] = {0};
};
参数列表
成员名称  | 
描述  | 
|---|---|
layerType  | 
归一化类型。 
  | 
normParam  | 
NORM参数。具体请参见表1。  | 
preNormParam  | 
PRENORM参数。具体请参见表2。  | 
postNormParam  | 
POSTNORM参数。具体请参见表3。  | 
rsv[8]  | 
预留参数。  | 
成员名称  | 
类型  | 
默认值  | 
取值范围  | 
是否必选  | 
描述  | 
|---|---|---|---|---|---|
quantType  | 
QuantType  | 
QUANT_UNQUANT  | 
QUANT_UNQUANT  | 
是  | 
在算子的计算完成后进行量化。 当前仅支持以下类型:QUANT_UNQUANT。  | 
epsilon  | 
float  | 
1e-5  | 
1e-5  | 
是  | 
Epsilon,归一化时加在分母上防止除0。  | 
layerNormEps  | 
double  | 
1e-5  | 
1e-5  | 
是  | 
Epsilon,默认为1e-5,暂不使用。  | 
rstd  | 
bool  | 
false  | 
true/false  | 
是  | 
默认为False,不支持和“precisionMode”,“modelType”同时设置。量化场景下不支持使用“rstd”。  | 
precisionMode  | 
PrecisionMode  | 
HIGH_PRECISION_MODE  | 
HIGH_PRECISION_MODE HIGH_PERFORMANCE_MODE  | 
是  | 
 不支持和“rstd”,“modelType”同时设置。输入类型只支持float16。 量化场景下不支持使用“precisionMode”,该场景下配置该参数将返回报错ERROR_INVALID_PARAM。  | 
modelType  | 
ModelType  | 
LLAMA_MODEL  | 
LLAMA_MODEL GEMMA_MODEL  | 
是  | 
设置为GEMMA_MODEL时使用gemma模型的rmsnorm计算公式。 
 不支持和“rstd”,“precisionMode”同时启用。 量化场景下不支持使用“modelType”,该场景下配置该参数将返回报错ERROR_INVALID_PARAM。  | 
dynamicQuantType  | 
DynamicQuantType  | 
DYNAMIC_QUANT_UNDEFINED  | 
DYNAMIC_QUANT_UNDEFINED  | 
否  | 
动态量化类型。默认为DYNAMIC_QUANT_UNDEFINED非动态量化。当前版本暂不支持非对称动态量化。  | 
rsv[32]  | 
uint8_t  | 
{0}  | 
[0]  | 
否  | 
预留参数。  | 
成员名称  | 
类型  | 
默认值  | 
取值范围  | 
是否必选  | 
描述  | 
|---|---|---|---|---|---|
quantType  | 
QuantType  | 
QUANT_UNQUANT  | 
QUANT_UNQUANT QUANT_INT8  | 
是  | 
量化类型。  | 
epsilon  | 
float  | 
1e-5  | 
1e-5  | 
是  | 
Epsilon,归一化时加在分母上防止除0。  | 
hasBias  | 
bool  | 
false  | 
true/false  | 
否  | 
是否叠加偏置。当需要输入beta时设置为True。量化场景下不支持使用“hasBias”,该场景下配置该参数将返回报错ERROR_INVALID_PARAM。  | 
rsv[23]  | 
uint8_t  | 
{0}  | 
[0]  | 
否  | 
预留参数。  | 
输入
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
x  | 
[dim_0, ..., dim_n-1, dim_n]  | 
float16/bf16  | 
ND  | 
最后一位的大小需要为16的倍数。  | 
gamma  | 
[1, dim_n]/[dim_i, ..., dim_n](0<=i<=n)  | 
float16/bf16  | 
ND  | 
维度为[1, dim_n],最后一维和x保持一致。  | 
xStrides  | 
[x_shape]  | 
int64  | 
ND  | 
x的stride步长,x_shape表示x的维度数,大于0,传入方式为host tensor,最后一个数需为1。  | 
xOffset  | 
[1]  | 
int64  | 
ND  | 
x的offset,需大于0,默认值为0,传入方式为host tensor。  | 
输出
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
output  | 
[dim_0, ..., dim_n-1, dim_n]  | 
float16/bf16  | 
ND  | 
维度和x保持一致。  | 
规格约束
- 在NORM场景下不支持量化和rstd。
 - 暂不支持preNormParam和postNormParam成员。