文档
注册

max_pooling3d_grad_grad

功能说明

计算maxpooling3d的二阶梯度。

如下所示:

input_d = 4, input_h =4, input_w = 4

stride_d = 2, stride_h = 2, stride_w =2

kernel_d = 2, kernel_h = 2, kernel_w = 2

计算如下:

  • input_d:orig_in中的D
  • input_h:orig_in的H。
  • input_w:orig_in中的W
  • kernel_d:ksize的D
  • kernel_w:ksize的W
  • kernel_h:ksize的H
  • stride_d:strides中的D
  • stride_h:strides中的H
  • stride_w:strides中的W
  • pad_top:orig_in的D方向顶上的pad的行数,图示中此值为0
  • pad_bottom:orig_in的D方向底部的pad的行数,图示中此值为0
  • pad_front:orig_in的H方向前面的pad的行数,图示中此值为0
  • pad_back:orig_in的H方向后面的pad的行数,图示中此值为0
  • pad_left:orig_in的W方向上左边的pad列数,图示中此值为0
  • pad_right:orig_in的W方向行右边的pad列数,图示中此值为0

函数原型

max_pooling3d_grad_grad(orig_input, orig_output, grad_grad, assist_tensor, ksize, strides, pads=(0, 0, 0, 0, 0, 0), data_format="NDHWC", padding="SAME")

参数说明

  • orig_input:输入的feature map,tvm.tensor类型。符合6D-NDC1HWC0格式排布的tensor。
  • orig_output:maxpooling3d的输出结果,tvm.tensor类型。符合6D-NDC1HWC0格式排布的tensor。
  • grad_grad: 二阶梯度值,tvm.tensor类型。符合6D-NDC1HWC0格式排布的tensor。
  • assist_tensor: 融合规则自动构造的辅助矩阵,用于消除最大值的重复场景。大小依赖于ksize的大小,例如在ksize为2x2x2时,则其取值为[8,7,6...2,1]。
  • ksize:输入的滑块大小信息,list、tuple类型。ksize[0]表示输入window的depth,ksize[1]表示输入window的width,ksize[2]输入window的height。
  • strides:输入的滑块移动步长信息,list、tuple类型。stride[0]表示window在feature map的D方向上移动的步长,stride[1]表示window在feature map的W方向上移动的步长,stride[2]表示window在feature map的H方向上移动的步长。
  • pads:补pad的数目,list、tuple类型。可选参数,用于兼容Caffe的pooling。pads[0], pads[1], pads[2], pads[3] ,pads[4], pads[5] 分别代表用户输入的在top, bottom, front, back, left, right方向补的pad,默认值为(0,0,0,0,0,0)。当pads中有非0值时,参考SAME模式计算补padding的数据;当pads中的数据为全0时,参考VALID模式。
  • data_format:数据格式。
  • padding:padding模式,支持“VALID”、“SAME”,分别代表 不补pad、补pad。

返回值

res_tensor:输出tensor,tvm.tensor类型,为符合6D-NDC1HWC0格式排布的tensor。

将tensor_in 的shape信息记为[N, D, C1, H, W, C0=16],window 的shape信息记为 [F, F],stride 信息记为 [S, S],则:

MAX and AVG 的VALID模式与SAME模式下输出tensor的shape信息计算方式分别如下所示:

  • VALID模式下
    • N 和C 维度保持不变。
    • Dout, Hout 和 Wout 维度大小为:

      new_depth=new_height=new_width = CEIL(W-F+1/S)

  • SAME模式下
    • N 和C 维度保持不变。
    • Dout, Hout 和Wout 维度大小为:

      new_depth=new_height=new_width = CEIL(W/S)

      其中,W为输入size,F为filter的size,S为步长,[]为向上取整符号。

约束说明

此接口暂不支持与其他TBE DSL计算接口混合使用。

该接口不支持出口量化功能。

支持的型号

Atlas 200/300/500 推理产品

Atlas 训练系列产品

Atlas 推理系列产品(Ascend 310P处理器)

Atlas 200/500 A2推理产品

Atlas A2训练系列产品/Atlas 800I A2推理产品

调用示例

from tbe import tvm
from tbe import dsl
shape_in = (1, 416, 2, 416, 416, 16) 
shape_out = (1, 208, 2, 208, 208, 16) 
shape_ksize = (3, 3, 3)
input_dtype = "float16"
orig_in = tvm.placeholder(shape_in, name="orig_in", dtype=input_dtype) 
orig_out = tvm.placeholder(shape_out, name="orig_out", dtype=input_dtype)
grad_grad = tvm.placeholder(shape_in, name="grad_grad", dtype=input_dtype)  
assist_tensor = tvm.placeholder(shape_in, name="assist_tensor", dtype=input_dtype)
res = dsl.max_pooling3d_grad_grad(orig_in, orig_out, grad_grad, assist_tensor, (3, 3, 3), (2, 2, 2), (0, 0, 0, 0, 0, 0), "NDHWC")
# res.shape = (1, 208, 2, 208, 208, 16)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词