LogSoftMax

Applicability

Product: support status

Atlas A3 training products / Atlas A3 inference products: supported

Atlas A2 training products / Atlas A2 inference products: supported

Atlas 200I/500 A2 inference products: not supported

Atlas inference products' AI Core: supported

Atlas inference products' Vector Core: not supported

Atlas training products: not supported

Function

Performs LogSoftmax computation on the input tensor along the last axis. For each row x, the result is:

$$y_i = \log_{10}\left(\frac{e^{x_i - \max_j x_j}}{\sum_j e^{x_j - \max_k x_k}}\right)$$

For ease of understanding, the formula is also expressed as the following Python script, where src is the source operand (input), and dst, sum, and max are the destination operands (outputs).

import numpy as np

def log_softmax(src):
    # rowmax: take the maximum of each row along the last axis.
    max = np.max(src, axis=-1, keepdims=True)
    sub = src - max
    exp = np.exp(sub)
    # rowsum: take the sum of each row along the last axis.
    sum = np.sum(exp, axis=-1, keepdims=True)
    dst = exp / sum
    dst = np.log10(dst)
    return dst, max, sum

Principles

The following figure shows the internal algorithm of the LogSoftMax high-level API, using a float input tensor in ND format with shape [m, k] as an example.

Figure 1 Diagram of the LogSoftMax algorithm

The computation consists of the following steps, all of which run on the Vector unit:

  1. reducemax: Compute the maximum of each row of input x to obtain [m, 1]. The result is saved to the temporary space temp.
  2. broadcast: Expand each value of the [m, 1] data in temp to fill one data block (32 bytes). For float, [m, 1] becomes [m, 8], which is output as max.
  3. sub: Subtract the row's max from every element of input x, row by row.
  4. exp: Compute exp of every element of the sub result.
  5. reducesum: Sum each row of the exp result to obtain [m, 1]. The result is saved to temp.
  6. broadcast: Expand each value of the [m, 1] data in temp to fill one data block. For float, [m, 1] becomes [m, 8], which is output as sum.
  7. div: Divide every element of the exp result by the sum of its row.
  8. log: Compute log10 of every element of the div result, row by row, and output y.
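The eight steps above can be traced in plain Python. This is a minimal host-side sketch, not the kernel implementation: the function name log_softmax_steps and the dtype_bytes parameter are hypothetical, and it assumes sizeof(float) = 4 and sizeof(half) = 2, so one 32-byte data block holds 8 float or 16 half elements. It mirrors the documented use of log10 in step 8.

```python
import math

BLOCK_BYTES = 32  # one data block, the fixed last-axis length of max and sum


def log_softmax_steps(x, dtype_bytes=4):
    """Row-wise LogSoftMax over a list of rows, following steps 1-8.

    Hypothetical helper. Returns (y, max_out, sum_out); max_out and sum_out
    repeat each row's value across one data block (8 floats or 16 halves).
    """
    block_elems = BLOCK_BYTES // dtype_bytes
    y, max_out, sum_out = [], [], []
    for row in x:
        row_max = max(row)                          # 1. reducemax -> [m, 1]
        max_out.append([row_max] * block_elems)     # 2. broadcast -> [m, 8] for float
        sub = [v - row_max for v in row]            # 3. sub
        exp = [math.exp(v) for v in sub]            # 4. exp
        row_sum = sum(exp)                          # 5. reducesum -> [m, 1]
        sum_out.append([row_sum] * block_elems)     # 6. broadcast
        div = [v / row_sum for v in exp]            # 7. div
        y.append([math.log10(v) for v in div])      # 8. log (log10, as documented)
    return y, max_out, sum_out


y, mx, sm = log_softmax_steps([[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0]])
```

For the all-zero second row, each exp is 1, the row sum is 4, and every output element is log10(0.25). The broadcast in steps 2 and 6 is why max and sum have a last axis of exactly one data block instead of a single element.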

Prototype

template <typename T, bool isReuseSource = false, bool isDataFormatNZ = false>
__aicore__ inline void LogSoftMax(const LocalTensor<T>& dst, const LocalTensor<T>& sum, const LocalTensor<T>& max, const LocalTensor<T>& src, const LocalTensor<uint8_t>& sharedTmpBuffer, const LogSoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})

Because the internal implementation of this API involves complex mathematical computation, extra temporary space is required to store intermediate variables. Developers pass this space through the sharedTmpBuffer parameter. To obtain the size of the temporary space (BufferSize) to reserve, use the API provided in LogSoftMax Tiling.

Parameters

Table 1 Template parameters

Parameter

Description

T

Data type of the operand.

For Atlas A3 training products / Atlas A3 inference products, the supported data types are half and float.

For Atlas A2 training products / Atlas A2 inference products, the supported data types are half and float.

For Atlas inference products' AI Core, the supported data types are half and float.

isReuseSource

Whether the source operand can be modified. This parameter is reserved. Pass the default value false.

isDataFormatNZ

Whether the source operand is in NZ format. The default value is false.

Table 2 API parameters

Parameter

Input/Output

Description

dst

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The length of the last axis must be 32-byte aligned.

sum

Output

reduceSum operand, which holds the sum of each row.

The reduceSum operand must have the same data type as the destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

  • The length of the last axis of sum is fixed at 32 bytes, that is, one data block. Every element in the block holds the same value. For example, for half, all 16 elements of the block hold the same reducesum value (for float, all 8 elements).
  • The length of the non-last axis is the same as that of the destination operand.

max

Output

reduceMax operand, which holds the maximum of each row.

The reduceMax operand must have the same data type as the destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

  • The length of the last axis of max is fixed at 32 bytes, that is, one data block. Every element in the block holds the same value. For example, for half, all 16 elements of the block hold the same reducemax value (for float, all 8 elements).
  • The length of the non-last axis is the same as that of the destination operand.

src

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The source operand must have the same data type as the destination operand.

sharedTmpBuffer

Input

Temporary buffer. For details about how to obtain the temporary space size (BufferSize), see LogSoftMax Tiling.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

tiling

Input

Tiling information required for LogSoftMax computation. For details about how to obtain the tiling information, see LogSoftMax Tiling.

softmaxShapeInfo

Input

Shape of src, which is of the SoftMaxShapeInfo type. The definition is as follows:

struct SoftMaxShapeInfo {
    uint32_t srcM;    // Product of the lengths of the non-last axes
    uint32_t srcK;    // Length of the last axis, which must be 32-byte aligned
    uint32_t oriSrcM; // Product of the lengths of the original non-last axes
    uint32_t oriSrcK; // Length of the original last axis
};

Note that when the input and output data format is NZ (FRACTAL_NZ), the length of the last axis (srcK) is the length of the reduce axis, which is W0 × W1 in Figure 2, and the length of the non-last axis (srcM) is H0 × H1.
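The ND-format alignment rules in the parameter table lend themselves to a small sizing calculation. The sketch below is a hypothetical host-side helper, not part of the Ascend C API. It rounds oriSrcK up to a 32-byte-aligned srcK and derives the element counts of dst (same shape as src) and of sum/max (one 32-byte block per row). It assumes a row-major ND layout, no padding along M, and sizeof(half) = 2, sizeof(float) = 4. It does not compute the sharedTmpBuffer size, which comes from LogSoftMax Tiling.

```python
BLOCK_BYTES = 32
DTYPE_BYTES = {"half": 2, "float": 4}


def aligned_len(ori_len, dtype):
    """Round a last-axis length up to a whole number of 32-byte data blocks."""
    per_block = BLOCK_BYTES // DTYPE_BYTES[dtype]
    return (ori_len + per_block - 1) // per_block * per_block


def logsoftmax_sizes(ori_src_m, ori_src_k, dtype):
    """Hypothetical helper: SoftMaxShapeInfo values and element counts."""
    per_block = BLOCK_BYTES // DTYPE_BYTES[dtype]
    src_m = ori_src_m                      # assumption: no padding along M
    src_k = aligned_len(ori_src_k, dtype)  # srcK must be 32-byte aligned
    return {
        # (srcM, srcK, oriSrcM, oriSrcK), in SoftMaxShapeInfo field order
        "shape_info": (src_m, src_k, ori_src_m, ori_src_k),
        "dst_elems": src_m * src_k,          # dst has the same shape as src
        "sum_max_elems": src_m * per_block,  # last axis fixed at one data block
    }


print(logsoftmax_sizes(4, 30, "float"))
```

For a float input of shape (4, 30), srcK is rounded up to 32, so dst holds 128 elements and sum and max each hold 4 × 8 = 32 elements. With half, sum and max would each hold 4 × 16 = 64 elements.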

Returns

None

Restrictions

  • The value range of the input source data must be [-2147483647.0, 2147483647.0]. If the input is not within the range, the output is invalid.
  • The source operand address must not overlap the destination operand address.
  • The address of sharedTmpBuffer must not overlap that of the source or destination operand.
  • For details about the operand address alignment requirements, see General Address Alignment Restrictions.
  • When srcM != oriSrcM or srcK != oriSrcK in softmaxShapeInfo, the original input of shape (oriSrcM, oriSrcK) in GM must be padded to (srcM, srcK) along the M or K direction. The padded data takes part in some of the computation. If the input and output are reused, the API's results overwrite the padded data in srcTensor. If they are not reused, the API's results overwrite the elements of dstTensor at the positions corresponding to the padding in srcTensor.
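The padding restriction implies two host-side chores: padding the original (oriSrcM, oriSrcK) data to (srcM, srcK), and reading back only the original region of the result, since padded positions may be overwritten. The sketch below shows both on flat row-major lists. The helper names pad_rows and original_view are hypothetical, and pad_value is a placeholder. This document does not state which value the padded positions should hold.

```python
def pad_rows(data, ori_m, ori_k, src_m, src_k, pad_value=0.0):
    """Pad a row-major ori_m x ori_k buffer to src_m x src_k (hypothetical helper).

    pad_value is a placeholder; the required padding value is not specified here.
    """
    assert len(data) == ori_m * ori_k and src_m >= ori_m and src_k >= ori_k
    out = []
    for r in range(src_m):
        for c in range(src_k):
            if r < ori_m and c < ori_k:
                out.append(data[r * ori_k + c])  # original element
            else:
                out.append(pad_value)            # padded position
    return out


def original_view(padded, ori_m, ori_k, src_k):
    """Extract only the original ori_m x ori_k region from a padded result."""
    return [padded[r * src_k + c] for r in range(ori_m) for c in range(ori_k)]


p = pad_rows([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], ori_m=2, ori_k=3, src_m=2, src_k=8)
```

Here a 2 × 3 float input is padded to srcK = 8 (one 32-byte block per row). After the kernel runs, original_view(result, 2, 3, 8) keeps only the six meaningful outputs and discards the overwritten padded positions.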

Example

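// DTYPE_X, DTYPE_A, DTYPE_B, and DTYPE_C are the data types of the source operand (srcLocal), destination operand (dstLocal),
// maxLocal, and sumLocal, respectively. LogSoftMax requires all four to be the same type.
// Assumed context (not defined in this excerpt): totalLength = outter * inner elements for src/dst;
// outsize = outter * 32 / sizeof(DTYPE_B) elements for max/sum; tmpsize and softmaxTiling come from LogSoftMax Tiling.
pipe.InitBuffer(inQueueX, BUFFER_NUM, totalLength * sizeof(DTYPE_X));
pipe.InitBuffer(outQueueA, BUFFER_NUM, totalLength * sizeof(DTYPE_A));
pipe.InitBuffer(outQueueB, BUFFER_NUM, outsize * sizeof(DTYPE_B));
pipe.InitBuffer(outQueueC, BUFFER_NUM, outsize * sizeof(DTYPE_C));
pipe.InitBuffer(tmpQueue, BUFFER_NUM, tmpsize);
AscendC::LocalTensor<DTYPE_X> srcLocal = inQueueX.DeQue<DTYPE_X>();
AscendC::LocalTensor<DTYPE_A> dstLocal = outQueueA.AllocTensor<DTYPE_A>();
AscendC::LocalTensor<DTYPE_B> maxLocal = outQueueB.AllocTensor<DTYPE_B>();
AscendC::LocalTensor<DTYPE_C> sumLocal = outQueueC.AllocTensor<DTYPE_C>();
AscendC::SoftMaxShapeInfo softmaxInfo = {outter, inner, outter, inner}; // srcM, srcK, oriSrcM, oriSrcK
AscendC::LocalTensor<uint8_t> tmpLocal = tmpQueue.AllocTensor<uint8_t>();
AscendC::LogSoftMax<DTYPE_X, false>(dstLocal, sumLocal, maxLocal, srcLocal, tmpLocal, softmaxTiling, softmaxInfo);
// Pass the results to the copy-out stage and release the input and temporary buffers.
outQueueA.EnQue<DTYPE_A>(dstLocal);
outQueueB.EnQue<DTYPE_B>(maxLocal);
outQueueC.EnQue<DTYPE_C>(sumLocal);
inQueueX.FreeTensor(srcLocal);
tmpQueue.FreeTensor(tmpLocal);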
Result example:
Input data (srcLocal): [0.80541134 0.08385705 0.49426016 ...  0.30962205 0.28947052]
Output data (dstLocal): [-0.6344272 -1.4868407 -1.0538127  ...  -1.2560008 -1.2771227]