ShapeInfo用来存放LocalTensor或GlobalTensor的shape信息。
1 2 3 4 5 6 7 8 9 10 11 12 13 | struct ShapeInfo { public: __aicore__ inline ShapeInfo(); __aicore__ inline ShapeInfo(const uint8_t inputShapeDim, const uint32_t inputShape[], const uint8_t inputOriginalShapeDim, const uint32_t inputOriginalShape[], const DataFormat inputFormat); __aicore__ inline ShapeInfo(const uint8_t inputShapeDim, const uint32_t inputShape[], const DataFormat inputFormat); __aicore__ inline ShapeInfo(const uint8_t inputShapeDim, const uint32_t inputShape[]); uint8_t shapeDim; uint8_t originalShapeDim; uint32_t shape[K_MAX_DIM]; uint32_t originalShape[K_MAX_DIM]; DataFormat dataFormat; }; |
1 | __aicore__ inline int GetShapeSize(const ShapeInfo& shapeInfo) |
参数名称 |
描述 |
---|---|
shapeDim |
现有的shape维度。 |
shape |
现有的shape。 |
originalShapeDim |
原始的shape维度。 |
originalShape |
原始的shape。 |
dataFormat |
数据排布格式,DataFormat类型,定义如下: enum class DataFormat : uint8_t { ND = 0, NZ, NCHW, NC1HWC0, NHWC, }; |
函数名称 |
入参说明 |
含义 |
---|---|---|
shapeInfo |
Tensor的shape信息 |
用来存放LocalTensor或GlobalTensor的shape信息。 |