老框架样例模型后处理代码
后处理头文件(ResNet50PostProcessor.h)
/*
* Copyright (c) 2023. Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MXPLUGINS_RESNET50POSTPROCESSOR_H
#define MXPLUGINS_RESNET50POSTPROCESSOR_H
#include "MxPlugins/ModelPostProcessors/ModelPostProcessorBase/MxpiModelPostProcessorBase.h"
/**
 * Classification post-processor for a ResNet50 model.
 *
 * For every entry of the output tensor list it selects the highest score among the first
 * CLASS_NUM values (argmax) and appends the result to an MxpiClassList.
 */
class ResNet50PostProcessor : public MxPlugins::MxpiModelPostProcessorBase {
public:
    /**
     * Load config and label files, read CLASS_NUM and, when CHECK_MODEL is enabled,
     * verify that the model output matches this post-processor.
     */
    APP_ERROR Init(const std::string& configPath, const std::string& labelPath, MxBase::ModelDesc modelDesc) override;

    /** Release the post-processor. */
    APP_ERROR DeInit();

    /**
     * Convert the model output in `tensors` into MxpiClass entries stored in the
     * MxpiClassList held by `metaDataPtr` (created when it is null).
     */
    APP_ERROR Process(std::shared_ptr<void>& metaDataPtr, MxBase::PostProcessorImageInfo postProcessorImageInfo,
        std::vector<MxTools::MxpiMetaHeader>& headerVec, std::vector<std::vector<MxBase::BaseTensor>>& tensors);

private:
    /** Verify the model output tensor shapes against the batch size and CLASS_NUM. */
    APP_ERROR CheckModelCompatibility();

    // Number of class scores to consider, read from CLASS_NUM in the config file.
    // Zero-initialized so a missing config key never leaves it holding an indeterminate value.
    int classNum_ = 0;
};
// Factory entry point declared with C linkage so its symbol name is not mangled
// (presumably so the plugin loader can resolve it by name — confirm against the loader).
extern "C" std::shared_ptr<MxPlugins::MxpiModelPostProcessorBase> GetInstance();
#endif // MXPLUGINS_RESNET50POSTPROCESSOR_H
后处理源文件(ResNet50PostProcessor.cpp)
/*
* Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ResNet50PostProcessor.h"
using namespace MxBase;
using namespace MxTools;
namespace {
// Maximum number of model output tensors accepted by CheckModelCompatibility().
constexpr int MAX_POOLNUM = 2;
} // namespace
/**
 * Initialize the post-processor.
 *
 * Loads the config and label files, reads CLASS_NUM, records the model tensor shapes and,
 * when CHECK_MODEL is enabled, runs CheckModelCompatibility().
 *
 * @param configPath path of the post-processing config file
 * @param labelPath  path of the label file
 * @param modelDesc  model description, passed to GetModelTensorsShape()
 * @return APP_ERR_OK on success; the error from loading the config/labels or from the
 *         compatibility check; APP_ERR_COMM_INVALID_PARAM if CLASS_NUM is missing or not positive
 */
APP_ERROR ResNet50PostProcessor::Init(const std::string& configPath, const std::string& labelPath,
    MxBase::ModelDesc modelDesc)
{
    LogInfo << "Begin to initialize ResNet50PostProcessor.";
    APP_ERROR ret = LoadConfigDataAndLabelMap(configPath, labelPath);
    if (ret != APP_ERR_OK) {
        LogError << GetError(ret) << "Fail to superInit in ResNet50PostProcessor.";
        return ret;
    }
    // Reset before reading: if CLASS_NUM is absent the getter may leave the target untouched,
    // and Process() must never run with an indeterminate or non-positive class count
    // (argmax over an empty range would dereference end()).
    classNum_ = 0;
    configData_.GetFileValue<int>("CLASS_NUM", classNum_);
    if (classNum_ <= 0) {
        LogError << GetError(APP_ERR_COMM_INVALID_PARAM) << "CLASS_NUM(" << classNum_
                 << ") must be a positive integer in the config of ResNet50PostProcessor.";
        return APP_ERR_COMM_INVALID_PARAM;
    }
    GetModelTensorsShape(modelDesc);
    if (checkModelFlag_) {
        ret = CheckModelCompatibility();
        if (ret != APP_ERR_OK) {
            LogError << GetError(ret) << "Fail to CheckModelCompatibility in ResNet50PostProcessor. "
                     << "Please check the compatibility between model and postprocessor";
            return ret;
        }
    } else {
        LogWarn << "Compatibility check for model is skipped as CHECK_MODEL is set as false, please ensure your model "
                << "is correct before running.";
    }
    LogInfo << "End to initialize ResNet50PostProcessor.";
    return APP_ERR_OK;
}
/**
 * De-initialize the post-processor.
 *
 * This post-processor releases nothing itself; the call only emits begin/end log lines.
 *
 * @return always APP_ERR_OK
 */
APP_ERROR ResNet50PostProcessor::DeInit()
{
    LogInfo << "Begin to deinitialize ResNet50PostProcessor.";
    // No owned resources to release here.
    LogInfo << "End to deinitialize ResNet50PostProcessor.";
    return APP_ERR_OK;
}
/**
 * Convert the inference output into classification results.
 *
 * For each index i of `tensors`, the output is copied to host memory. The argmax is taken over
 * the first classNum_ float scores of the first returned buffer. One MxpiClass (class id, label
 * name, score), tagged with headerVec[i], is then appended to the MxpiClassList held by
 * metaDataPtr. A new list is created when metaDataPtr is null.
 *
 * @param metaDataPtr            in/out holder of the MxpiClassList
 * @param postProcessorImageInfo not used by this post-processor
 * @param headerVec              meta headers; entry i tags the result of tensors[i]
 * @param tensors                model output tensors
 * @return APP_ERR_OK on success. APP_ERR_COMM_INVALID_PARAM if there are fewer headers than
 *         tensor groups or CLASS_NUM is not positive. APP_ERR_COMM_INVALID_POINTER if the host
 *         copy yields no data. Otherwise, the error returned by MemoryDataToHost().
 */
APP_ERROR ResNet50PostProcessor::Process(std::shared_ptr<void>& metaDataPtr,
    MxBase::PostProcessorImageInfo postProcessorImageInfo, std::vector<MxTools::MxpiMetaHeader>& headerVec,
    std::vector<std::vector<MxBase::BaseTensor>>& tensors)
{
    LogDebug << "Begin to process ResNet50PostProcessor.";
    if (metaDataPtr == nullptr) {
        metaDataPtr = std::static_pointer_cast<void>(std::make_shared<MxTools::MxpiClassList>());
    }
    std::shared_ptr<MxTools::MxpiClassList> classList = std::static_pointer_cast<MxTools::MxpiClassList>(metaDataPtr);
    if (!tensors.empty()) {
        // Result i is tagged with headerVec[i], so there must be a header for every tensor group.
        if (headerVec.size() < tensors.size()) {
            LogError << GetError(APP_ERR_COMM_INVALID_PARAM) << "The number of meta headers (" << headerVec.size()
                     << ") is less than the number of tensor groups (" << tensors.size() << ").";
            return APP_ERR_COMM_INVALID_PARAM;
        }
        // argmax over an empty range would dereference end().
        if (classNum_ <= 0) {
            LogError << GetError(APP_ERR_COMM_INVALID_PARAM) << "Invalid CLASS_NUM(" << classNum_
                     << ") in ResNet50PostProcessor.";
            return APP_ERR_COMM_INVALID_PARAM;
        }
    }
    for (size_t i = 0; i < tensors.size(); ++i) {
        std::vector<std::shared_ptr<void>> featLayerData;
        // Copy the inferred results data back to Host and do argmax for labeling.
        APP_ERROR ret = MemoryDataToHost(i, tensors, featLayerData);
        if (ret != APP_ERR_OK) {
            LogError << GetError(ret) << "Fail to copy device memory to host for ResNet50PostProcessor.";
            return ret;
        }
        if (featLayerData.empty() || featLayerData[0] == nullptr) {
            LogError << GetError(APP_ERR_COMM_INVALID_POINTER) << "No output data for ResNet50PostProcessor.";
            return APP_ERR_COMM_INVALID_POINTER;
        }
        // NOTE(review): assumes the buffer holds at least classNum_ floats; CheckModelCompatibility()
        // verifies the output shape only when CHECK_MODEL is enabled.
        const float* scores = static_cast<const float*>(featLayerData[0].get());
        const float* maxElement = std::max_element(scores, scores + classNum_);
        const size_t argmaxIndex = static_cast<size_t>(maxElement - scores);
        const auto& className = configData_.GetClassName(argmaxIndex);

        MxpiClass* mxpiClass = classList->add_classvec();
        mxpiClass->set_classid(argmaxIndex);
        mxpiClass->set_classname(className);
        mxpiClass->set_confidence(*maxElement);
        MxpiMetaHeader* mxpiMetaHeader = mxpiClass->add_headervec();
        mxpiMetaHeader->set_memberid(headerVec[i].memberid());
        mxpiMetaHeader->set_datasource(headerVec[i].datasource());
        LogDebug << "class Id and name of the most possible class: " << argmaxIndex << ", " << className;
    }
    LogDebug << "End to process ResNet50PostProcessor.";
    return APP_ERR_OK;
}
/**
 * Check that the model output layout matches what Process() expects.
 *
 * Requires at least one and at most MAX_POOLNUM output tensors. The first output must have
 * at least two dimensions, with dim 0 equal to the last entry of modelDesc_.batchSizes and
 * dim 1 no smaller than classNum_.
 *
 * @return APP_ERR_OK when compatible, otherwise APP_ERR_OUTPUT_NOT_MATCH
 */
APP_ERROR ResNet50PostProcessor::CheckModelCompatibility()
{
    // Guard every index/back() below against empty containers (those accesses would be UB).
    if (outputTensorShapes_.empty()) {
        LogError << "outputTensorShapes_ is empty.";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    if (outputTensorShapes_.size() > MAX_POOLNUM) {
        LogError << "outputTensorShapes_.size() > " << MAX_POOLNUM << ".";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    const auto& firstShape = outputTensorShapes_[0];
    // Dims 0 and 1 are compared below.
    if (firstShape.size() < 2) {
        LogError << "outputTensorShapes_[0] has fewer than 2 dimensions.";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    if (modelDesc_.batchSizes.empty()) {
        LogError << "modelDesc_.batchSizes is empty.";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    if (firstShape[0] != modelDesc_.batchSizes.back()) {
        LogError << "outputTensorShapes_[0][0] != modelDesc_.batchSizes.back().";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    if (firstShape[1] < classNum_) {
        LogError << "outputTensorShapes_[0][1] < classNum_(" << classNum_ << ").";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    return APP_ERR_OK;
}
/**
 * Plugin factory: create a fresh ResNet50PostProcessor and hand it back through the
 * base-class pointer.
 */
std::shared_ptr<MxPlugins::MxpiModelPostProcessorBase> GetInstance()
{
    LogInfo << "Begin to get ResNet50PostProcessor instance.";
    std::shared_ptr<MxPlugins::MxpiModelPostProcessorBase> processor(std::make_shared<ResNet50PostProcessor>());
    LogInfo << "End to get ResNet50PostProcessor instance.";
    return processor;
}
CMakeLists
# Builds the ResNet50 post-processing sample as a shared library against the SDK
# headers/libraries located two directories above this file.
cmake_minimum_required(VERSION 3.5.2)
project(resnet50postsample)

# Select the pre-C++11 libstdc++ ABI; presumably required to match the prebuilt SDK libraries — confirm.
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
# Rename the `google` namespace; presumably matches the SDK's privately namespaced protobuf — confirm.
add_definitions(-Dgoogle=mindxsdk_private)

set(PLUGIN_NAME "resnet50postsample")
set(TARGET_LIBRARY ${PLUGIN_NAME})

include_directories(${PROJECT_SOURCE_DIR}/../../include)
include_directories(${PROJECT_SOURCE_DIR}/../../opensource/include)
include_directories(${PROJECT_SOURCE_DIR}/../../opensource/include/gstreamer-1.0)
include_directories(${PROJECT_SOURCE_DIR}/../../opensource/include/glib-2.0)
include_directories(${PROJECT_SOURCE_DIR}/../../opensource/lib/glib-2.0/include)
link_directories(${PROJECT_SOURCE_DIR}/../../opensource/lib/)
link_directories(${PROJECT_SOURCE_DIR}/../../lib)

# -pie is deliberately not passed here: it is a link-time option for position-independent
# executables, has no effect on compilation and does not apply to a SHARED library.
# -fPIC is what the shared library needs.
add_compile_options(-std=c++11 -fPIC -fstack-protector-all -Wno-deprecated-declarations)
add_compile_options("-DPLUGIN_NAME=${PLUGIN_NAME}")
add_definitions(-DENABLE_DVPP_INTERFACE)

add_library(${TARGET_LIBRARY} SHARED ResNet50PostProcessor.cpp)
target_link_libraries(${TARGET_LIBRARY} glib-2.0 gstreamer-1.0 gobject-2.0 gstbase-1.0 gmodule-2.0)
target_link_libraries(${TARGET_LIBRARY} plugintoolkit mxpidatatype mxbase)
# Hardened linking: full RELRO (relro + now) and a non-executable stack; -s strips symbols.
target_link_libraries(${TARGET_LIBRARY} -Wl,-z,relro,-z,now,-z,noexecstack -s)

install(TARGETS ${TARGET_LIBRARY} PERMISSIONS OWNER_WRITE OWNER_READ GROUP_READ LIBRARY DESTINATION ${PROJECT_SOURCE_DIR}/../../lib)
父主题: 代码参考