TensorFormat
TBE提供了TensorFormat结构体用于定义算子输入输出数据的数据排布格式。
定义如下所示:
class TensorFormat:
ND = "ND"
NCHW = "NCHW"
NHWC = "NHWC"
NDHWC = "NDHWC"
NCDHW = "NCDHW"
CHWN = "CHWN"
NC1HWC0 = "NC1HWC0"
NC1HWC0_C04 = "NC1HWC0_C04"
NDC1HWC0 = "NDC1HWC0"
FRACTAL_NZ = "FRACTAL_NZ"
HWCN = "HWCN"
DHWCN = "DHWCN"
FRACTAL_Z = "FRACTAL_Z"
FRACTAL_Z_C04 = "FRACTAL_Z_C04"
C1HWNCoC0 = "C1HWNCoC0"
FRACTAL_Z_3D = "FRACTAL_Z_3D"
FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM"
ND_RNN_BIAS = "ND_RNN_BIAS"
使用示例:
from tbe.common.utils.para_check import OpParamInfoKey, TensorFormat if x.get(OpParamInfoKey.FORMAT) == TensorFormat.NCHW # do something
父主题: 数据结构定义