昇腾社区首页
中文
注册

forward

函数功能

类初始化函数。

函数原型

def forward(self, query, key, value, seq_len_list):

参数说明

参数名

输入/输出

类型

说明

query

输入

float16/bfloat16

query的激活值,layout支持TND。

key

输入

float16/bfloat16

key的激活值,layout支持TND。

value

输入

float16/bfloat16

value的激活值,layout支持TND。

seq_len_list

输入

list[int]

seq_len_list为各batch上seq_len之和,shape为[batch size]。