forward
函数功能
类初始化函数。
函数原型
def forward(self, query, key, value, seq_len_list):
参数说明
参数名 |
输入/输出 |
类型 |
说明 |
---|---|---|---|
query |
输入 |
float16/bfloat16 |
query的激活值,layout支持TND。 |
key |
输入 |
float16/bfloat16 |
key的激活值,layout支持TND。 |
value |
输入 |
float16/bfloat16 |
value的激活值,layout支持TND。 |
seq_len_list |
输入 |
list[int] |
seq_len_list为各batch上seq_len之和,shape为[batch size]。 |
父主题: class QuantFA