昇腾社区首页
中文
注册

RmsNormParam

属性

类型

默认值

描述

layer_type

torch_atb.RmsNormParam.RmsNormType

torch_atb.RmsNormParam.RmsNormType.RMS_NORM_UNDEFINED

此默认类型不可用,用户需配置此项参数。

norm_param

torch_atb.RmsNormParam.NormParam

-

-

pre_norm_param

torch_atb.RmsNormParam.PreNormParam

-

-

post_norm_param

torch_atb.RmsNormParam.PostNormParam

-

-

RmsNormParam.RmsNormType

枚举项:

  • RMS_NORM_UNDEFINED
  • RMS_NORM_PRENORM
  • RMS_NORM_NORM
  • RMS_NORM_POSTNORM

RmsNormParam.NormParam

属性

类型

默认值

描述

quant_type

torch_atb.QuantType

torch_atb.QuantType.QUANT_UNQUANT

表示不进行量化操作。

epsilon

float

1e-5

-

layer_norm_eps

double

1e-5

-

rstd

bool

False

-

precision_mode

torch_atb.RmsNormParam.PrecisionMode

torch_atb.RmsNormParam.PrecisionMode.HIGH_PRECISION_MODE

-

model_type

torch_atb.RmsNormParam.ModelType

torch_atb.RmsNormParam.ModelType.LLAMA_MODEL

-

dynamic_quant_type

torch_atb.DynamicQuantType

torch_atb.DynamicQuantType.DYNAMIC_QUANT_UNDEFINED

-

RmsNormParam.PrecisionMode

枚举项:

  • HIGH_PRECISION_MODE
  • HIGH_PERFORMANCE_MODE

RmsNormParam.ModelType

枚举项:

  • LLAMA_MODEL
  • GEMMA_MODEL

RmsNormParam.PreNormParam

属性

类型

默认值

描述

quant_type

torch_atb.QuantType

torch_atb.QuantType.QUANT_UNQUANT

表示不进行量化操作。

epsilon

float

1e-5

-

has_bias

bool

False

-

RmsNormParam.PostNormParam

属性

类型

默认值

描述

quant_type

torch_atb.QuantType

torch_atb.QuantType.QUANT_UNQUANT

表示不进行量化操作。

epsilon

float

1e-5

-

has_bias

bool

False

-

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch_atb  
import numpy as np

def rms_norm():
    rms_norm_param = torch_atb.RmsNormParam(layer_type = torch_atb.RmsNormParam.RmsNormType.RMS_NORM_NORM)
    rms_norm_param.norm_param.rstd = True
    rms_norm = torch_atb.Operation(rms_norm_param)
    shape=[8, 8, 8]
    shape_gamma=[8]
    x = torch.from_numpy(np.random.uniform(low=0, high=100, size=shape).astype(np.float32))
    gamma = torch.from_numpy(np.random.uniform(low=0, high=100, size=shape_gamma).astype(np.float32))
    in_tensors = [x.npu(), gamma.npu()]
    print("in_tensors: ", in_tensors)

    def rms_norm_run():
        rms_norm_outputs = rms_norm.forward(in_tensors)
        return rms_norm_outputs

    outputs = rms_norm_run()
    print("outputs: ", outputs)

if __name__ == "__main__":
    rms_norm()