开发流程
该开发流程以工程化算子开发为基础,除了需要提供工程化算子开发中的算子实现文件外,还需要额外交付算子入图的代码文件。本节仅提供算子入图代码文件的开发指导。
假设下图是我们需要使用的网络模型,您可能会想直接逐个算子调用,根据输入tensor得到输出tensor就可以完成网络的运行,但在图模式场景下,实际的网络模型生成过程中,会先进行tensor shape以及datatype的推导。这样可以让我们在图执行之前,就知道各tensor的数据类型和形状,提前校验其正确性;同时提前推理出算子的输出张量描述,包括张量的形状、数据类型及数据排布格式等信息,算子构图准备阶段就可以为所有的张量静态分配内存,避免动态内存分配带来的开销。
下面的网络模型经过shape和datatype推导之后,可以得到灰色底纹框中的推导信息:
除了tiling实现外,算子入图时需要额外提供的实现代码有以下几种:
- datatype推导:根据算子的输入datatype、算子逻辑及算子属性等信息,推理出算子的输出张量datatype。
- shape推导:根据算子的输入shape、算子逻辑及算子属性等信息,推理出算子的输出张量shape。
- 声明数据依赖:部分算子在InferShape时,需要依赖某个输入的具体值才可以进行,这类算子被称为“数据依赖算子”,对应的输入被称为“数据依赖输入”。该类算子在注册时,需要声明其数据依赖输入。
下表列出了不同类型的算子对上述实现代码的要求。
分类 |
对入图实现代码的要求 |
|---|---|
根据输入shape可以推导出输出shape。 |
|
依赖输入的value才能推导出输出shape,即数据依赖算子。 如Reshape算子,依赖shape输入的value才能推导出输出shape。 |
|
实际开发时通过固定的datatype和shape推导原型实现推导函数,然后再通过SetInferShape、SetInferDataType接口来关联对应的shape推导函数,样例如下。
namespace ge {
static graphStatus InferShape(gert::InferShapeContext *context)
{
...
return GRAPH_SUCCESS;
}
static graphStatus InferDataType(gert::InferDataTypeContext *context)
{
...
return ge::GRAPH_SUCCESS;
}
} // namespace ge
namespace ops {
class AddCustom : public OpDef {
public:
AddCustom(const char* name) : OpDef(name)
{
this->Input("x")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("y")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("z")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
// 根据用户的算子调用方式决定需不需要注册 图模式调用方式下需要
this->SetInferShape(ge::InferShape);
this->SetInferDataType(ge::InferDataType);
this->AICore()
.SetTiling(optiling::TilingFunc);
// 请替换为实际的昇腾AI处理器型号
this->AICore().AddConfig("ascendxxx");
}
};
OP_ADD(AddCustom);
} // namespace ops
datatype推导
以AddCustom算子为例,InferDataType的实现如下所示。该样例中输出tensor的数据类型与输入tensor的数据类型相同,所以直接将任意一个输入tensor的数据类型赋给输出tensor即可。
namespace ge {
static graphStatus InferDataType(gert::InferDataTypeContext* context)
{
const auto inputDataType = context->GetInputDataType(0);
context->SetOutputDataType(0, inputDataType);
return ge::GRAPH_SUCCESS;
}
} // namespace ge
如下示例则给出了更灵活的datatype推导样例,当输入的数据类型为DT_INT4时,其输出的数据类型为DT_INT32。
ge::graphStatus InferDataTypeForFoo(gert::InferDataTypeContext* context) {
if (context->GetInputDataType(0) == DT_INT4) {
context->SetOutputDataType(0, DT_INT32);
}
}
shape推导
InferShape函数的原型是确定的,其接受一个InferShapeContext作为输入,从此context上可以获取到输入、输出的shape指针等内容。InferShape成功后,返回ge::GRAPH_SUCCESS,其他返回值被认为推导失败。推导失败后,执行过程结束退出。
以AddCustom算子为例,InferShape的实现如下所示。该样例中输tensor的描述信息与输入tensor的描述信息相同,所以直接将任意一个输入tensor的描述赋给输出tensor即可。从如下代码可以看出,输入shape为const类型,因此InferShape时,输入shape是只读、不允许修改的。
namespace ge {
static graphStatus InferShape(gert::InferShapeContext* context)
{
const auto inputShape = context->GetInputShape(0);
auto outputShape = context->GetOutputShape(0);
*outputShape = *inputShape;
return GRAPH_SUCCESS;
}
} // namespace ge
InferShapeContextpublic继承自ExtendedKernelContext,因此ExtendedKernelContext中提供的方法如获取算子type、name、属性等接口均可以在InferShapeContext实例中调用。
为了效率考虑,调用InferShape函数时,框架不会为输出shape做初始化,因此,在InferShape函数中,可以认为输出是未初始化的状态。如果在InferShape时,希望通过Append方式操作输出shape,需要先将输出shape的DimNum清零,以防止出现未定义行为。
InferShape时获取属性、输入
在InferShape、Tiling时,可以通过context实例获取算子IR属性值,所谓IR属性,是指在IR注册时定义的属性,以TransData算子为例:
namespace ops {
class TransData : public OpDef {
public:
explicit TransData(const char *name) : OpDef(name)
{
this->Input("src")
...
this->Output("dst")
...
this->Attr("src_format")
.AttrType(REQUIRED)
.String();
this->Attr("dst_format")
.AttrType(REQUIRED)
.String();
this->Attr("group")
.AttrType(OPTIONAL)
.Int(1);
...
}
};
OP_ADD(TransData);
} // namespace ops
其原型定义中声明了src_format、dst_format、group三个属性,可以通过如下方式获取算子属性:
ge::graphStatus ExampleGetTransDataAttr(TilingContext *context) {
// 获取所有属性
const RuntimeAttrs *attrs = context->GetAttrs();
ASSERT_NOT_NULL(attrs);
// 按照在原型定义中的顺序,使用index获取属性,index从0开始计数
const char *src_format = attrs->GetAttrPointer<char>(0); // 获取src_format,src_format是第一个属性,因此index为0
const char *dst_format = attrs->GetAttrPointer<char>(1); // 获取dst_format,dst_format是第二个属性,因此index为1
const int64_t group = attrs->GetAttrPointer<int64_t>(2); // 获取group,group是第三个属性,因此index为2
return ge::GRAPH_SUCCESS;
}
通过index而不是字符串name来索引输入输出,对于带有OPTIONAL、DYNAMIC类型输入的算子,可能出现实例化后,单纯通过index无法索引到具体输入的问题,以DynamicRNNV3算子为例:
namespace ops {
class DynamicRNNV3 : public OpDef {
public:
explicit DynamicRNNV3(const char *name) : OpDef(name)
{
this->Input("x")
.ParamType(REQUIRED)
...
this->Input("w")
.ParamType(REQUIRED)
...
this->Input("b")
.ParamType(REQUIRED)
...
this->Input("seq_length")
.ParamType(OPTIONAL)
...
this->Input("init_h")
.ParamType(OPTIONAL)
...
this->Input("init_c")
.ParamType(OPTIONAL)
...
this->Input("wci")
.ParamType(OPTIONAL)
...
this->Input("wcf")
.ParamType(OPTIONAL)
...
this->Input("mask")
.ParamType(OPTIONAL)
...
this->Input("mask")
.ParamType(OPTIONAL)
...
this->Input("project")
.ParamType(OPTIONAL)
...
...
}
};
OP_ADD(DynamicRNNV3);
} // namespace ops
由于DynamicRNNV3算子有连续的多个optional输入,这导致init_h及其后面的输入的实例化后index都是不确定的,对于这种类型的算子,可以通过GetOptionalInputShape传入原型对应的index来获取对应的输入shape等数据,以InferShape为例:
ge::graphStatus InferShapeForDynamicRNNV3(InferShapeContext *context) {
// 对于前两个输入,不受到optional或dynamic的影响,可以按照常规方法获取输入shape
auto x_shape = context->GetInputShape(0);
auto w_shape = context->GetInputShape(1);
if (x_shape == nullptr || w_shape == nullptr) {
return ge::GRAPH_FAILED;
}
int64_t state_size = 0;
// 在原型定义上,project是第11个输入(从0开始计数)
constexpr int64_t kProjectInputIndex = 11;
// 受到前面optional输入影响的,project实例化后输入的index是不确定的,通过GetOptionalInputShape来获取对应的输入shape,
// GetOptionalInputShape的入参为原型上对应的index
auto project_shape = context->GetOptionalInputShape(kProjectInputIndex);
if (project_shape != nullptr) {
if (project_shape->GetDimNum() < 2) {
return ge::GRAPH_FAILED;
}
state_size = project_shape->GetDim(1);
}
// 更多的infershape逻辑...
return ge::GRAPH_SUCCESS;
}
对于dynamic类型的输入,实例化后的输入可能是一到多个,对于此类输入,获取方式为:
// ir_index:此输入在原型定义中的index,从0开始计数 // relative_index:该输入实例化后的相对index,从0开始计数,例如某个DYNAMIC_INPUT实例化了3个,要取第二个,那么relatvie_index = 1 auto shape = context->GetDynamicInputShape(ir_index, relative_index);
本节举例的获取optional、dynamic输入的方式,在InferShape、Tiling函数中均可以调用。
数据依赖
一般来说,具备输入shape后,算子可以通过InferShape推导出输出shape。然而部分算子在InferShape时,需要依赖某个输入的具体值才可以进行,这类算子被称为“数据依赖算子”,对应的输入被称为“数据依赖输入”。以Reshape算子为例,其依据shape输入的描述,对输入的shape做调整,因此Reshape算子依赖shape输入的值。这类算子需要在原型定义时通过ValueDepend接口声明对应的输入为数据依赖输入。
namespace ops {
class Reshape : public OpDef {
public:
explicit Reshape(const char *name) : OpDef(name)
{
...
this->Input("shape")
.ParamType(REQUIRED)
...
.ValueDepend(REQUIRED) // 声明 ReShape算子的shape输入为数据依赖输入
...
}
};
OP_ADD(Reshape);
} // namespace ops
根据第1个输入(shape输入)的值,Reshape算子将第0个输入(x输入)的shape做变换,并输出到其第0个输出(y输出)上。Reshape的InferShape实现为:
ge::graphStatus InferShapeForReshape(InferShapeContext *context) {
const gert::Shape *x_shape = context->GetInputShape(0); // 获取第0个输入的shape
const gert::Tensor *shape_tensor = context->GetInputTensor(1); // 获取第1个输入的tensor
gert::Shape *output_shape = context->GetOutputShape(0);
if (x_shape == nullptr || shape_tensor == nullptr || output_shape == nullptr) {
// 防御式编程,不应该出现的场景,打印错误并返回失败
return ge::GRAPH_FAILED;
}
auto reshape_size = static_cast<int32_t>(shape_tensor->GetShapeSize());
if (reshape_size < 1) {
// 防御式编程,不应该出现的场景,打印错误并返回失败
return ge::GRAPH_FAILED;
}
// 根据原型信息,Reshape的shape输入支持INT32与INT64两类,根据不同的类型进入对应的模板函数中做真正的shape变换操作
if (shape_tensor->GetDataType() == ge::DT_INT32) {
int32_t *reshape_data = shape_tensor->GetData<int32_t>();
return ReshapeInferShapeImpl<int32_t>(reshape_data, *x_shape, *output_shape, reshape_size);
} else {
int64_t *reshape_data = shape_tensor->GetData<int64_t>();
return ReshapeInferShapeImpl<int64_t>(reshape_data, *x_shape, *output_shape, reshape_size);
}
}
- 只有声明过数据依赖的输入,才可以在InferShape时调用GetInputTensor等获取tensor的接口获取其对应的tensor数据。若对一个未声明数据依赖的输入调用GetInputTensor等获取tensor的接口,只能在tensor中获取到正确的shape、format、datatype信息,,无法获取到真实的tensor数据地址(获取到的地址为nullptr)。
- 从tensor中获取tensor_data时(GetData<int32_t>或GetData<int64_t>),使用者需要保证获取的数据类型是正确的,否则行为是未定义的。