老框架样例模型后处理代码
后处理头文件(ResNet50PostProcessor.h)
/* * Copyright (c) 2023. Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MXPLUGINS_RESNET50POSTPROCESSOR_H #define MXPLUGINS_RESNET50POSTPROCESSOR_H #include "MxPlugins/ModelPostProcessors/ModelPostProcessorBase/MxpiModelPostProcessorBase.h" class ResNet50PostProcessor : public MxPlugins::MxpiModelPostProcessorBase { public: APP_ERROR Init(const std::string& configPath, const std::string& labelPath, MxBase::ModelDesc modelDesc) override; APP_ERROR DeInit(); APP_ERROR Process(std::shared_ptr<void>& metaDataPtr, MxBase::PostProcessorImageInfo postProcessorImageInfo, std::vector<MxTools::MxpiMetaHeader>& headerVec, std::vector<std::vector<MxBase::BaseTensor>>& tensors); private: APP_ERROR CheckModelCompatibility(); int classNum_; }; extern "C" { std::shared_ptr<MxPlugins::MxpiModelPostProcessorBase> GetInstance(); } #endif // RESNET50POSTPROCESSOR_H
后处理源文件(ResNet50PostProcessor.cpp)
/*
 * Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ResNet50PostProcessor.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

using namespace MxBase;
using namespace MxTools;

namespace {
// Maximum number of model output tensors accepted by CheckModelCompatibility().
const int MAX_POOLNUM = 2;
// CheckModelCompatibility() reads dims [0] (batch) and [1] (classes) of the first output shape.
const size_t MIN_OUTPUT_DIMS = 2;
}

/**
 * @brief Initializes the post-processor.
 *
 * Loads the config/label files, reads CLASS_NUM, records the model tensor shapes and, when
 * checkModelFlag_ is set, runs CheckModelCompatibility().
 *
 * @param configPath Path of the post-processing configuration file.
 * @param labelPath  Path of the label file.
 * @param modelDesc  Description of the loaded model.
 * @return APP_ERR_OK on success; the failing step's error code otherwise, or
 *         APP_ERR_COMM_INVALID_PARAM when CLASS_NUM is not positive.
 */
APP_ERROR ResNet50PostProcessor::Init(const std::string& configPath, const std::string& labelPath,
    MxBase::ModelDesc modelDesc)
{
    LogInfo << "Begin to initialize ResNet50PostProcessor.";
    APP_ERROR ret = LoadConfigDataAndLabelMap(configPath, labelPath);
    if (ret != APP_ERR_OK) {
        LogError << GetError(ret) << "Fail to superInit in ResNet50PostProcessor.";
        return ret;
    }
    // CLASS_NUM bounds every read of the output buffer in Process(), so a missing or
    // non-positive value must stop initialization instead of being used unchecked.
    ret = configData_.GetFileValue<int>("CLASS_NUM", classNum_);
    if (ret != APP_ERR_OK) {
        LogError << GetError(ret) << "Fail to read CLASS_NUM in ResNet50PostProcessor.";
        return ret;
    }
    if (classNum_ <= 0) {
        LogError << "CLASS_NUM(" << classNum_ << ") must be positive in ResNet50PostProcessor.";
        return APP_ERR_COMM_INVALID_PARAM;
    }
    GetModelTensorsShape(modelDesc);
    if (checkModelFlag_) {
        ret = CheckModelCompatibility();
        if (ret != APP_ERR_OK) {
            LogError << GetError(ret) << "Fail to CheckModelCompatibility in ResNet50PostProcessor. "
                     << "Please check the compatibility between model and postprocessor";
            return ret;
        }
    } else {
        LogWarn << "Compatibility check for model is skipped as CHECK_MODEL is set as false, please ensure your model "
                << "is correct before running.";
    }
    LogInfo << "End to initialize ResNet50PostProcessor.";
    return APP_ERR_OK;
}

/**
 * @brief De-initializes the post-processor. Nothing is released here; it only logs.
 * @return APP_ERR_OK.
 */
APP_ERROR ResNet50PostProcessor::DeInit()
{
    LogInfo << "Begin to deinitialize ResNet50PostProcessor.";
    LogInfo << "End to deinitialize ResNet50PostProcessor.";
    return APP_ERR_OK;
}

/**
 * @brief Converts the inference output of each tensor group into one top-1 classification.
 *
 * For every entry of @p tensors, the first output buffer is copied to host and the arg-max over
 * its first classNum_ float scores is taken. An MxpiClass (id, name, confidence) carrying the
 * matching header from @p headerVec is then appended to the MxpiClassList held by @p metaDataPtr.
 *
 * @param metaDataPtr            In/out MxpiClassList; created when null.
 * @param postProcessorImageInfo Not used by this post-processor.
 * @param headerVec              Meta headers; must have at least tensors.size() entries.
 * @param tensors                Inference output tensors, one group per input.
 * @return APP_ERR_OK on success; APP_ERR_COMM_INVALID_PARAM for inconsistent inputs or an
 *         uninitialized class count; APP_ERR_OUTPUT_NOT_MATCH when no host data was produced;
 *         otherwise the error returned by MemoryDataToHost().
 */
APP_ERROR ResNet50PostProcessor::Process(std::shared_ptr<void>& metaDataPtr,
    MxBase::PostProcessorImageInfo postProcessorImageInfo, std::vector<MxTools::MxpiMetaHeader>& headerVec,
    std::vector<std::vector<MxBase::BaseTensor>>& tensors)
{
    LogDebug << "Begin to process ResNet50PostProcessor.";
    // The arg-max below needs a non-empty score range, and one header is read per tensor group.
    if (classNum_ <= 0) {
        LogError << "Invalid class number(" << classNum_ << ") in ResNet50PostProcessor.";
        return APP_ERR_COMM_INVALID_PARAM;
    }
    if (headerVec.size() < tensors.size()) {
        LogError << "headerVec size(" << headerVec.size() << ") is smaller than tensors size("
                 << tensors.size() << ") in ResNet50PostProcessor.";
        return APP_ERR_COMM_INVALID_PARAM;
    }
    if (metaDataPtr == nullptr) {
        metaDataPtr = std::static_pointer_cast<void>(std::make_shared<MxTools::MxpiClassList>());
    }
    std::shared_ptr<MxTools::MxpiClassList> classList =
        std::static_pointer_cast<MxTools::MxpiClassList>(metaDataPtr);
    for (unsigned int i = 0; i < tensors.size(); i++) {
        std::vector<std::shared_ptr<void>> featLayerData;
        // Copy the inferred results data back to Host and do argmax for labeling.
        APP_ERROR ret = MemoryDataToHost(i, tensors, featLayerData);
        if (ret != APP_ERR_OK) {
            LogError << GetError(ret) << "Fail to copy device memory to host for ResNet50PostProcessor.";
            return ret;
        }
        if (featLayerData.empty() || featLayerData[0] == nullptr) {
            LogError << "No output data was copied to host for ResNet50PostProcessor.";
            return APP_ERR_OUTPUT_NOT_MATCH;
        }
        // NOTE(review): the buffer is assumed to hold at least classNum_ floats; this is only
        // checked (via outputTensorShapes_) by CheckModelCompatibility() when CHECK_MODEL is on.
        const float *scores = static_cast<const float *>(featLayerData[0].get());
        const float *maxScore = std::max_element(scores, scores + classNum_);
        size_t argmaxIndex = static_cast<size_t>(maxScore - scores);
        MxpiClass *mxpiClass = classList->add_classvec();
        mxpiClass->set_classid(argmaxIndex);
        mxpiClass->set_classname(configData_.GetClassName(argmaxIndex));
        mxpiClass->set_confidence(*maxScore);
        MxpiMetaHeader *mxpiMetaHeader = mxpiClass->add_headervec();
        mxpiMetaHeader->set_memberid(headerVec[i].memberid());
        mxpiMetaHeader->set_datasource(headerVec[i].datasource());
        LogDebug << "class Id and name of the most possible class: " << argmaxIndex << ", "
                 << configData_.GetClassName(argmaxIndex);
    }
    LogDebug << "End to process ResNet50PostProcessor.";
    return APP_ERR_OK;
}

/**
 * @brief Validates the model output shapes recorded by GetModelTensorsShape().
 *
 * Requires at most MAX_POOLNUM output tensors, a first output of at least two dimensions whose
 * dim 0 equals the last entry of modelDesc_.batchSizes and whose dim 1 is at least classNum_.
 *
 * @return APP_ERR_OK if compatible, otherwise APP_ERR_OUTPUT_NOT_MATCH.
 */
APP_ERROR ResNet50PostProcessor::CheckModelCompatibility()
{
    if (outputTensorShapes_.empty()) {
        LogError << "outputTensorShapes_ is empty.";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    if (outputTensorShapes_.size() > static_cast<size_t>(MAX_POOLNUM)) {
        LogError << "outputTensorShapes_.size() > " << MAX_POOLNUM << ".";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    if (outputTensorShapes_[0].size() < MIN_OUTPUT_DIMS) {
        LogError << "outputTensorShapes_[0].size() < " << MIN_OUTPUT_DIMS << ".";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    if (modelDesc_.batchSizes.empty()) {
        LogError << "modelDesc_.batchSizes is empty.";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    if (outputTensorShapes_[0][0] != modelDesc_.batchSizes.back()) {
        LogError << "outputTensorShapes_[0][0] != modelDesc_.batchSizes.back().";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    if (outputTensorShapes_[0][1] < classNum_) {
        LogError << "outputTensorShapes_[0][1] < classNum_(" << classNum_ << ").";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    return APP_ERR_OK;
}

/**
 * @brief Plugin factory. Declared extern "C" in ResNet50PostProcessor.h, so this definition
 *        also has C linkage.
 * @return A newly created ResNet50PostProcessor.
 */
std::shared_ptr<MxPlugins::MxpiModelPostProcessorBase> GetInstance()
{
    LogInfo << "Begin to get ResNet50PostProcessor instance.";
    auto instance = std::make_shared<ResNet50PostProcessor>();
    LogInfo << "End to get ResNet50PostProcessor instance.";
    return instance;
}
CMakeLists
cmake_minimum_required(VERSION 3.5.2)
project(resnet50postsample)

# Use the pre-C++11 libstdc++ ABI (presumably to match how the SDK libraries below were built — confirm).
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
# Macro-renames the `google` namespace to `mindxsdk_private` in every translation unit
# (presumably to match the SDK's private protobuf build — confirm).
add_definitions(-Dgoogle=mindxsdk_private)

set(PLUGIN_NAME "resnet50postsample")
set(TARGET_LIBRARY ${PLUGIN_NAME})

# SDK headers plus the bundled GStreamer/GLib headers and libraries.
include_directories(${PROJECT_SOURCE_DIR}/../../include)
include_directories(${PROJECT_SOURCE_DIR}/../../opensource/include)
include_directories(${PROJECT_SOURCE_DIR}/../../opensource/include/gstreamer-1.0)
include_directories(${PROJECT_SOURCE_DIR}/../../opensource/include/glib-2.0)
include_directories(${PROJECT_SOURCE_DIR}/../../opensource/lib/glib-2.0/include)
link_directories(${PROJECT_SOURCE_DIR}/../../opensource/lib/)
link_directories(${PROJECT_SOURCE_DIR}/../../lib)

# -fPIC is what a shared library needs. -pie is deliberately not listed: it is a link-time
# option for position-independent executables, ignored during compilation and not
# applicable to the SHARED library built here.
add_compile_options(-std=c++11 -fPIC -fstack-protector-all -Wno-deprecated-declarations)
add_compile_options("-DPLUGIN_NAME=${PLUGIN_NAME}")
add_definitions(-DENABLE_DVPP_INTERFACE)

add_library(${TARGET_LIBRARY} SHARED ResNet50PostProcessor.cpp)

target_link_libraries(${TARGET_LIBRARY} glib-2.0 gstreamer-1.0 gobject-2.0 gstbase-1.0 gmodule-2.0)
target_link_libraries(${TARGET_LIBRARY} plugintoolkit mxpidatatype mxbase)
# Linker hardening: full RELRO (relro + now) and a non-executable stack; -s strips symbols.
target_link_libraries(${TARGET_LIBRARY} -Wl,-z,relro,-z,now,-z,noexecstack -s)

install(TARGETS ${TARGET_LIBRARY}
        PERMISSIONS OWNER_WRITE OWNER_READ GROUP_READ
        LIBRARY DESTINATION ${PROJECT_SOURCE_DIR}/../../lib)
父主题: 代码参考