def __init__
函数功能
类初始化函数。
函数原型
def __init__(self, ori_head_num, ori_inner_dim, prefix, quant_weights=None, dtype=torch.bfloat16):
参数说明
参数名 |
输入/输出 |
类型 |
说明 |
---|---|---|---|
ori_head_num |
输入 |
int |
原始的attention heads。 |
ori_inner_dim |
输入 |
int |
原始的inner dimension。 |
prefix |
输入 |
str |
该fa对应层的前缀名称。 |
quant_weights |
输入 |
torch.Tensor |
量化权重。 |
dtype |
输入 |
float |
可选输入,需要为torch.float16/torch.bfloat16类型,默认torch.bfloat16。 |
父主题: class QuantFA