SetTensorB

Applicability

Product | Supported
Atlas A3 training products/Atlas A3 inference products | √
Atlas A2 training products/Atlas A2 inference products | √
Atlas 200I/500 A2 inference products | √
Atlas inference product's AI Core | √
Atlas inference product's Vector Core | x
Atlas training products | x

Function

Sets the right matrix B for matrix multiplication.

Prototype

__aicore__ inline void SetTensorB(const GlobalTensor<SrcBT>& gm, bool isTransposeB = false)
__aicore__ inline void SetTensorB(const LocalTensor<SrcBT>& rightMatrix, bool isTransposeB = false)
__aicore__ inline void SetTensorB(SrcBT bScalar)

The SetTensorB(SrcBT bScalar) prototype is not supported on the Atlas inference product's AI Core or on the Atlas 200I/500 A2 inference products.

Parameters

Table 1 Template parameters

Parameter

Description

SrcBT

Data type of matrix B.

Table 2 Parameters

Parameter

Input/Output

Description

gm

Input

Matrix B. The type is GlobalTensor. SrcBT indicates the data type of matrix B.

For the Atlas A3 training products/Atlas A3 inference products, the supported data types are half, float, bfloat16_t, int8_t, and int4b_t.

For the Atlas inference product's AI Core, the supported data types are half, float, and int8_t.

For the Atlas A2 training products/Atlas A2 inference products, the supported data types are half, float, bfloat16_t, int8_t, and int4b_t.

For the Atlas 200I/500 A2 inference products, the supported data types are half, float, bfloat16_t, and int8_t.

rightMatrix

Input

Matrix B. The type is LocalTensor, and TPosition can be TSCM or VECOUT. SrcBT indicates the data type of matrix B.

For the Atlas A3 training products/Atlas A3 inference products, the supported data types are half, float, bfloat16_t, int8_t, and int4b_t.

For the Atlas inference product's AI Core, the supported data types are half, float, and int8_t.

For the Atlas A2 training products/Atlas A2 inference products, the supported data types are half, float, bfloat16_t, int8_t, and int4b_t.

For the Atlas 200I/500 A2 inference products, the supported data types are half, float, bfloat16_t, and int8_t.

If a TSCM start address is passed, matrix B is treated as fully loaded by default. In this case, the Iterate API does not need to move the data from GM to B1.

bScalar

Input

Scalar value used as matrix B. The scalar is expanded into a tensor of shape [1, K] whose elements all equal bScalar, and this tensor takes part in the matrix multiplication. For example, setting bScalar to 1 computes a reduce sum of matrix A along the K direction. SrcBT indicates the data type of matrix B.

For the Atlas A3 training products/Atlas A3 inference products, the supported data types are half and float.

For the Atlas A2 training products/Atlas A2 inference products, the supported data types are half and float.

This parameter is not supported by the Atlas inference product's AI Core.

This parameter is not supported by the Atlas 200I/500 A2 inference products.

isTransposeB

Input

Whether matrix B is transposed. The default value is false.

Notes:

  • If ISTRANS in matrix B's MatmulType is set to true, this parameter can be set to true or false at runtime, so transposed and non-transposed B matrices can be used alternately.
  • If ISTRANS in matrix B's MatmulType is set to false, this parameter must be false. If it is set to true anyway, the computed results will be incorrect.
  • For input types other than half and bfloat16_t, this parameter, ISTRANS in matrix B's MatmulType on the kernel side, and isTrans of the SetBType() API on the tiling side must all have the same value (all true or all false). This keeps the L1 Buffer size computed on the tiling side consistent with the kernel side and guarantees correct results.

Returns

None

Restrictions

The address space of the input matrix B must be large enough to hold at least singleK x singleN elements.

Example

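REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm, &tiling);    // Register and initialize the Matmul object mm.
mm.SetTensorA(gm_a);        // Set the left matrix A.
mm.SetTensorB(gm_b);        // Set the right matrix B (isTransposeB defaults to false).
mm.SetBias(gm_bias);        // Set the bias.
mm.IterateAll(gm_c);        // Compute the full result and write it to gm_c.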