昇腾社区首页
中文
注册

def __init__

函数功能

类初始化函数。

函数原型

def __init__(self, ori_head_num, ori_inner_dim, prefix, quant_weights=None, dtype=torch.bfloat16):

参数说明

参数名

输入/输出

类型

说明

ori_head_num

输入

int

原始的attention heads。

ori_inner_dim

输入

int

原始的inner dimension。

prefix

输入

str

该fa对应层的前缀名称。

quant_weights

输入

torch.Tensor

量化权重。

dtype

输入

float

可选输入,需要为torch.float16/torch.bfloat16类型,默认torch.bfloat16。