ShapeInfo用来存放LocalTensor或GlobalTensor的shape信息。
struct ShapeInfo { public: __aicore__ inline ShapeInfo() {} __aicore__ inline ShapeInfo(const uint8_t inputShapeDim, const uint32_t inputShape[], const uint8_t inputOriginalShapeDim, const uint32_t inputOriginalShape[], const DataFormat inputFormat) : shapeDim(inputShapeDim), originalShapeDim(inputOriginalShapeDim), dataFormat(inputFormat) { ASCENDC_ASSERT((inputShapeDim <= K_MAX_DIM && inputOriginalShapeDim <= K_MAX_DIM), { KERNEL_LOG(KERNEL_ERROR, "inputShapeDim is %d, inputOriginalShapeDim is %d, which should be less than %d both", inputShapeDim, inputOriginalShapeDim, K_MAX_DIM); }); for (int index = 0; index < shapeDim; ++index) { shape[index] = inputShape[index]; } for (int index = 0; index < originalShapeDim; ++index) { originalShape[index] = inputOriginalShape[index]; } } __aicore__ inline ShapeInfo(const uint8_t inputShapeDim, const uint32_t inputShape[], const DataFormat inputFormat) : shapeDim(inputShapeDim), originalShapeDim(inputShapeDim), dataFormat(inputFormat) { ASCENDC_ASSERT((inputShapeDim <= K_MAX_DIM), { KERNEL_LOG(KERNEL_ERROR, "inputShapeDim is %u, which should be less than %d", inputShapeDim, K_MAX_DIM); }); for (int index = 0; index < shapeDim; ++index) { shape[index] = inputShape[index]; originalShape[index] = inputShape[index]; } } __aicore__ inline ShapeInfo(const uint8_t inputShapeDim, const uint32_t inputShape[]) : shapeDim(inputShapeDim), originalShapeDim(inputShapeDim), dataFormat(DataFormat::ND) { ASCENDC_ASSERT((inputShapeDim <= K_MAX_DIM), { KERNEL_LOG(KERNEL_ERROR, "inputShapeDim is %d, which should be less than %d", inputShapeDim, K_MAX_DIM); }); for (int index = 0; index < shapeDim; ++index) { shape[index] = inputShape[index]; originalShape[index] = inputShape[index]; } } uint8_t shapeDim; uint8_t originalShapeDim; uint32_t shape[K_MAX_DIM]; uint32_t originalShape[K_MAX_DIM]; DataFormat dataFormat; };
| 参数名称 | 类型 | 描述 |
|---|---|---|
| shapeDim | uint8_t | 现有的shape维度 |
| shape | uint32_t[K_MAX_DIM] | 现有的shape |
| originalShapeDim | uint8_t | 原始的shape维度 |
| originalShape | uint32_t[K_MAX_DIM] | 原始的shape |
| dataFormat | DataFormat | 数据排布格式（DataFormat枚举），例如：`NCHW = 0`，数据按NCHW排布；`NHWC = 1`，数据按NHWC排布。未传入数据格式的构造函数默认使用`DataFormat::ND`。 |
| 函数名称 | 入参说明 | 含义 |
|---|---|---|
| shapeInfo | Tensor的shape信息 | 用来存放LocalTensor或GlobalTensor的shape信息。 |