昇腾社区首页
中文
注册
开发者
下载

aclnnRopeWithSinCosCache

产品支持情况

[object Object]undefined

功能说明

  • 接口功能:推理网络为了提升性能,将sin和cos输入通过cache传入,执行旋转位置编码计算。

  • 计算公式:

    1、mrope模式:positions的shape输入是[3, numTokens]:

    cosSin[i]=cosSinCache[positions[i]]cosSin[i] = cosSinCache[positions[i]] cos,sin=cosSin.chunk(2,dim=1)cos, sin = cosSin.chunk(2, dim=-1) cos0=cos[0,:,:mropeSection[0]]cos0 = cos[0, :, :mropeSection[0]] cos1=cos[1,:,mropeSection[0]:(mropeSection[0]+mropeSection[1])]cos1 = cos[1, :, mropeSection[0]:(mropeSection[0] + mropeSection[1])] cos2=cos[2,:,(mropeSection[0]+mropeSection[1]):(mropeSection[0]+mropeSection[1]+mropeSection[2])]cos2 = cos[2, :, (mropeSection[0] + mropeSection[1]):(mropeSection[0] + mropeSection[1] + mropeSection[2])] cos=torch.cat((cos0,cos1,cos2),dim=1)cos = torch.cat((cos0, cos1, cos2), dim=-1) sin0=sin[0,:,:mropeSection[0]]sin0 = sin[0, :, :mropeSection[0]] sin1=sin[1,:,mropeSection[0]:(mropeSection[0]+mropeSection[1])]sin1 = sin[1, :, mropeSection[0]:(mropeSection[0] + mropeSection[1])] sin2=sin[2,:,(mropeSection[0]+mropeSection[1]):(mropeSection[0]+mropeSection[1]+mropeSection[2])]sin2 = sin[2, :, (mropeSection[0] + mropeSection[1]):(mropeSection[0] + mropeSection[1] + mropeSection[2])] sin=torch.cat((sin0,sin1,sin2),dim=1)sin= torch.cat((sin0, sin1, sin2), dim=-1) queryRot=query[...,:rotaryDim]queryRot = query[..., :rotaryDim] queryPass=query[...,rotaryDim:]queryPass = query[..., rotaryDim:]

    (1)rotate_half(GPT-NeoX style)计算模式:

    x1,x2=torch.chunk(queryRot,2,dim=1)x1, x2 = torch.chunk(queryRot, 2, dim=-1) o1[i]=x1[i]cos[i]x2[i]sin[i]o1[i] = x1[i] * cos[i] - x2[i] * sin[i] o2[i]=x2[i]cos[i]+x1[i]sin[i]o2[i] = x2[i] * cos[i] + x1[i] * sin[i] queryRot=torch.cat((o1,o2),dim=1)queryRot = torch.cat((o1, o2), dim=-1) query=torch.cat((queryRot,queryPass),dim=1)query = torch.cat((queryRot, queryPass), dim=-1)

    (2)rotate_interleaved(GPT-J style)计算模式:

    x1=queryRot[...,::2]x1 = queryRot[..., ::2] x2=queryRot[...,1::2]x2 = queryRot[..., 1::2] queryRot=torch.stack((o1,o2),dim=1)queryRot = torch.stack((o1, o2), dim=-1) query=torch.cat((queryRot,queryPass),dim=1)query = torch.cat((queryRot, queryPass), dim=-1)

    2、rope模式:positions的shape输入是[numTokens]:

    cosSin[i]=cosSinCache[positions[i]]cosSin[i] = cosSinCache[positions[i]] cos,sin=cosSin.chunk(2,dim=1)cos, sin = cosSin.chunk(2, dim=-1) queryRot=query[...,:rotaryDim]queryRot = query[..., :rotaryDim] queryPass=query[...,rotaryDim:]queryPass = query[..., rotaryDim:]

    (1)rotate_half(GPT-NeoX style)计算模式:

    x1,x2=torch.chunk(queryRot,2,dim=1)x1, x2 = torch.chunk(queryRot, 2, dim=-1) o1[i]=x1[i]cos[i]x2[i]sin[i]o1[i] = x1[i] * cos[i] - x2[i] * sin[i] o2[i]=x2[i]cos[i]+x1[i]sin[i]o2[i] = x2[i] * cos[i] + x1[i] * sin[i] queryRot=torch.cat((o1,o2),dim=1)queryRot = torch.cat((o1, o2), dim=-1) query=torch.cat((queryRot,queryPass),dim=1)query = torch.cat((queryRot, queryPass), dim=-1)

    (2)rotate_interleaved(GPT-J style)计算模式:

    x1=query_rot[...,::2]x1 = query\_rot[..., ::2] x2=query_rot[...,1::2]x2 = query\_rot[..., 1::2] queryRot=torch.stack((o1,o2),dim=1)queryRot = torch.stack((o1, o2), dim=-1) query=torch.cat((queryRot,queryPass),dim=1)query = torch.cat((queryRot, queryPass), dim=-1)

函数原型

每个算子分为,必须先调用“aclnnRopeWithSinCosCacheGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnRopeWithSinCosCache”接口执行计算。

[object Object]
[object Object]

aclnnRopeWithSinCosCacheGetWorkspaceSize

  • 参数说明

    [object Object]
  • 返回值:

    aclnnStatus:返回状态码,具体参见

    第一段接口完成入参校验,出现以下场景时报错:

    [object Object]

aclnnRopeWithSinCosCache

  • 参数说明:

    [object Object]
  • 返回值:

    aclnnStatus:返回状态码,具体参见

约束说明

  • 确定性计算:

    • aclnnNormRopeConcatBackward默认确定性实现。
  • queryIn、keyIn、cosSinCache只支持2维shape输入。

  • queryIn、keyIn、cosSinCache输入的数据类型需要保持一致。

  • headSize:数据类型为BFLOAT16或FLOAT16时为32的倍数,数据类型为FLOAT32时为16的倍数。

  • rotaryDim:始终小于等于headSize;数据类型为BFLOAT16或FLOAT16时为32的倍数,数据类型为FLOAT32时为16的倍数;mrope模式下应满足rotaryDim = mropeSection[0] + mropeSection[1] + mropeSection[2]。

  • 输入tensor positions的取值应小于cosSinCache的0维maxSeqLen。

  • mrope模式下,mropeSection:取值限制为[16, 24, 24],rotaryDim的取值为128。

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]