在核函数中判断本次执行时的tiling_key是否等于host侧运行时设置的某个key,从而标识tiling_key==key的一条kernel分支。
1 | TILING_KEY_IS(key) |
参数 |
输入/输出 |
说明 |
---|---|---|
key |
输入 |
此参数是正数,表示某个核函数的分支。 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | extern "C" __global__ __aicore__ void add_custom(__gm__ uint8_t *x, __gm__ uint8_t *y, __gm__ uint8_t *z, __gm__ uint8_t *workspace, __gm__ uint8_t *tiling) { GET_TILING_DATA(tilingData, tiling); if (workspace == nullptr) { return; } KernelAdd op; op.Init(x, y, z, tilingData.blockDim, tilingData.totalLength, tilingData.tileNum); // 当TilingKey为1时,执行Process1;为2时,执行Process2;为3时,执行Process3 if (TILING_KEY_IS(1)) { op.Process1(); } else if (TILING_KEY_IS(2)) { op.Process2(); } else if (TILING_KEY_IS(3)) { op.Process3(); } // 其他代码逻辑 ... // 此处示例当TilingKey为3时,会执行ProcessOther if (TILING_KEY_IS(3)) { op.ProcessOther(); } } |
配套的host侧tiling函数示例(伪代码):
1 2 3 4 5 6 7 8 9 10 11 12 | ge::graphStatus TilingFunc(gert::TilingContext* context) { // 其他代码逻辑 ... if (context->GetInputShape(0) > 10) { context->SetTilingKey(1); } else if (some condition) { context->SetTilingKey(2); } else if (some condition) { context->SetTilingKey(3); } } |