昇腾社区首页
中文
注册

老框架样例模型后处理代码

后处理头文件(ResNet50PostProcessor.h)

/*
 * Copyright (c) 2023. Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MXPLUGINS_RESNET50POSTPROCESSOR_H
#define MXPLUGINS_RESNET50POSTPROCESSOR_H

#include "MxPlugins/ModelPostProcessors/ModelPostProcessorBase/MxpiModelPostProcessorBase.h"

/**
 * Classification post-processor plugin built on MxPlugins::MxpiModelPostProcessorBase.
 *
 * For every inference result it takes the arg-max over the first CLASS_NUM scores and appends the
 * winning class (id, label name, score) to an MxTools::MxpiClassList.
 */
class ResNet50PostProcessor : public MxPlugins::MxpiModelPostProcessorBase {
public:
    /// Load config and labels, read CLASS_NUM and, if CHECK_MODEL is enabled, validate the model output.
    APP_ERROR Init(const std::string& configPath, const std::string& labelPath, MxBase::ModelDesc modelDesc) override;
    /// Release the post-processor (this class owns no resources beyond classNum_).
    APP_ERROR DeInit();
    /// Turn the inference tensors into arg-max classification results stored in metaDataPtr.
    APP_ERROR Process(std::shared_ptr<void>& metaDataPtr, MxBase::PostProcessorImageInfo postProcessorImageInfo,
        std::vector<MxTools::MxpiMetaHeader>& headerVec, std::vector<std::vector<MxBase::BaseTensor>>& tensors);

private:
    /// Check the model's output tensor count and shape against the batch size and CLASS_NUM.
    APP_ERROR CheckModelCompatibility();
    // Number of classes, read from CLASS_NUM in the config file. Zero-initialised so that a failed or
    // missing read is detectable instead of leaving an indeterminate value.
    int classNum_ = 0;
};

// Factory entry point of the plugin, exported with C linkage so its symbol name is the unmangled "GetInstance".
// NOTE(review): presumably resolved by the framework via dlopen/dlsym when the library is loaded — confirm.
extern "C" std::shared_ptr<MxPlugins::MxpiModelPostProcessorBase> GetInstance();
#endif // MXPLUGINS_RESNET50POSTPROCESSOR_H

后处理源文件(ResNet50PostProcessor.cpp)

/*
 * Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ResNet50PostProcessor.h"

using namespace MxBase;
using namespace MxTools;

namespace {
// Largest number of model output tensors that CheckModelCompatibility() accepts.
constexpr int MAX_POOLNUM = 2;
} // namespace

/**
 * Initialise the post-processor.
 *
 * Loads the config and label files, reads CLASS_NUM, records the model tensor shapes and, unless
 * CHECK_MODEL is disabled, verifies that the model output is compatible with Process().
 *
 * @param configPath Path to the post-processing config file (must provide a positive CLASS_NUM).
 * @param labelPath  Path to the label file.
 * @param modelDesc  Description of the loaded model, forwarded to GetModelTensorsShape().
 * @return APP_ERR_OK on success; otherwise the error of the failing step, or
 *         APP_ERR_COMM_INVALID_PARAM when CLASS_NUM is not a positive integer.
 */
APP_ERROR ResNet50PostProcessor::Init(const std::string& configPath, const std::string& labelPath,
    MxBase::ModelDesc modelDesc)
{
    LogInfo << "Begin to initialize ResNet50PostProcessor.";
    APP_ERROR ret = LoadConfigDataAndLabelMap(configPath, labelPath);
    if (ret != APP_ERR_OK) {
        LogError << GetError(ret) << "Fail to superInit in ResNet50PostProcessor.";
        return ret;
    }
    // Process() reads exactly classNum_ scores per result. A missing or non-positive CLASS_NUM would
    // make it read an indeterminate count or dereference an empty range, so fail here instead.
    ret = configData_.GetFileValue<int>("CLASS_NUM", classNum_);
    if (ret != APP_ERR_OK) {
        LogError << GetError(ret) << "Fail to read CLASS_NUM from config file in ResNet50PostProcessor.";
        return ret;
    }
    if (classNum_ <= 0) {
        LogError << GetError(APP_ERR_COMM_INVALID_PARAM) << "CLASS_NUM(" << classNum_
                 << ") must be a positive integer in ResNet50PostProcessor.";
        return APP_ERR_COMM_INVALID_PARAM;
    }
    GetModelTensorsShape(modelDesc);
    if (checkModelFlag_) {
        ret = CheckModelCompatibility();
        if (ret != APP_ERR_OK) {
            LogError << GetError(ret) << "Fail to CheckModelCompatibility in ResNet50PostProcessor. "
                     << "Please check the compatibility between model and postprocessor";
            return ret;
        }
    } else {
        LogWarn << "Compatibility check for model is skipped as CHECK_MODEL is set as false, please ensure your model "
                << "is correct before running.";
    }
    LogInfo << "End to initialize ResNet50PostProcessor.";
    return APP_ERR_OK;
}

/**
 * Tear down the post-processor.
 *
 * This class holds no resources of its own (only classNum_), so there is nothing to release:
 * the method only traces entry and exit.
 *
 * @return Always APP_ERR_OK.
 */
APP_ERROR ResNet50PostProcessor::DeInit()
{
    const APP_ERROR status = APP_ERR_OK;
    LogInfo << "Begin to deinitialize ResNet50PostProcessor.";
    LogInfo << "End to deinitialize ResNet50PostProcessor.";
    return status;
}

/**
 * Convert raw classification scores into an MxpiClassList.
 *
 * For every entry of @p tensors, the output is copied to host memory and the arg-max over the first
 * classNum_ float scores is taken. One MxpiClass (class id, label name, score as confidence, and a
 * copy of the matching meta header) is then appended to the class list held by @p metaDataPtr.
 *
 * @param metaDataPtr  In/out. If null, a new MxpiClassList is created. Otherwise it is
 *                     static_pointer_cast to MxpiClassList without a type check, so it must already
 *                     point to one.
 * @param postProcessorImageInfo Not used by this post-processor.
 * @param headerVec    Meta headers; headerVec[i] is attached to the result of tensors[i], so it needs
 *                     at least tensors.size() entries.
 * @param tensors      Inference outputs, one entry per result.
 * @return APP_ERR_OK on success; APP_ERR_COMM_INVALID_PARAM for an invalid CLASS_NUM or too few headers;
 *         APP_ERR_COMM_INVALID_POINTER if no host data was produced; otherwise the copy error.
 */
APP_ERROR ResNet50PostProcessor::Process(std::shared_ptr<void>& metaDataPtr, MxBase::PostProcessorImageInfo postProcessorImageInfo,
        std::vector<MxTools::MxpiMetaHeader>& headerVec, std::vector<std::vector<MxBase::BaseTensor>>& tensors)
{
    LogDebug << "Begin to process ResNet50PostProcessor.";
    // Validate up front so that bad input leaves metaDataPtr untouched.
    if (classNum_ <= 0) {
        // An empty score range would make max_element return end(), which must not be dereferenced.
        LogError << GetError(APP_ERR_COMM_INVALID_PARAM) << "Invalid CLASS_NUM(" << classNum_
                 << ") in ResNet50PostProcessor.";
        return APP_ERR_COMM_INVALID_PARAM;
    }
    if (headerVec.size() < tensors.size()) {
        LogError << GetError(APP_ERR_COMM_INVALID_PARAM) << "headerVec size(" << headerVec.size()
                 << ") is smaller than tensors size(" << tensors.size() << ") in ResNet50PostProcessor.";
        return APP_ERR_COMM_INVALID_PARAM;
    }
    if (metaDataPtr == nullptr) {
        metaDataPtr = std::static_pointer_cast<void>(std::make_shared<MxTools::MxpiClassList>());
    }
    std::shared_ptr<MxTools::MxpiClassList> classList = std::static_pointer_cast<MxTools::MxpiClassList>(metaDataPtr);
    for (unsigned int i = 0; i < tensors.size(); ++i) {
        std::vector<std::shared_ptr<void>> featLayerData;
        // Copy the inferred results data back to Host and do argmax for labeling.
        APP_ERROR ret = MemoryDataToHost(i, tensors, featLayerData);
        if (ret != APP_ERR_OK) {
            LogError << GetError(ret) << "Fail to copy device memory to host for ResNet50PostProcessor.";
            return ret;
        }
        if (featLayerData.empty() || featLayerData[0] == nullptr) {
            LogError << GetError(APP_ERR_COMM_INVALID_POINTER)
                     << "No host data for inference output in ResNet50PostProcessor.";
            return APP_ERR_COMM_INVALID_POINTER;
        }
        // NOTE(review): assumes the first output holds at least classNum_ float scores. When CHECK_MODEL
        // is enabled, CheckModelCompatibility() verifies the second output dimension is >= classNum_.
        const float* scores = static_cast<const float*>(featLayerData[0].get());
        // Take the arg-max directly on the host buffer; no temporary copy of the scores is needed.
        const float* maxElement = std::max_element(scores, scores + classNum_);
        const size_t argmaxIndex = static_cast<size_t>(maxElement - scores);
        const auto& className = configData_.GetClassName(argmaxIndex);

        MxpiClass* mxpiClass = classList->add_classvec();
        mxpiClass->set_classid(argmaxIndex);
        mxpiClass->set_classname(className);
        mxpiClass->set_confidence(*maxElement);
        MxpiMetaHeader* mxpiMetaHeader = mxpiClass->add_headervec();
        mxpiMetaHeader->set_memberid(headerVec[i].memberid());
        mxpiMetaHeader->set_datasource(headerVec[i].datasource());
        LogDebug << "class Id and name of the most possible class: " << argmaxIndex << ", " << className;
    }
    LogDebug << "End to process ResNet50PostProcessor.";
    return APP_ERR_OK;
}

/**
 * Verify that the model output matches what Process() expects.
 *
 * Requirements:
 * - there are at most MAX_POOLNUM output tensors;
 * - the first output shape has at least two dimensions;
 * - its first dimension equals the last entry of modelDesc_.batchSizes;
 * - its second dimension is at least classNum_.
 *
 * @return APP_ERR_OK if compatible, APP_ERR_OUTPUT_NOT_MATCH otherwise.
 */
APP_ERROR ResNet50PostProcessor::CheckModelCompatibility()
{
    if (outputTensorShapes_.size() > static_cast<size_t>(MAX_POOLNUM)) {
        LogError << "outputTensorShapes_.size() > " << MAX_POOLNUM << ".";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    // Guard the indexing below. An empty shape list, a rank < 2 first shape or an empty batch-size
    // list would otherwise cause out-of-bounds accesses.
    if (outputTensorShapes_.empty() || outputTensorShapes_[0].size() < 2) {
        LogError << "outputTensorShapes_ is empty or outputTensorShapes_[0] has fewer than 2 dimensions.";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    if (modelDesc_.batchSizes.empty()) {
        LogError << "modelDesc_.batchSizes is empty.";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    if (outputTensorShapes_[0][0] != modelDesc_.batchSizes.back()) {
        LogError << "outputTensorShapes_[0][0] != modelDesc_.batchSizes.back().";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    if (outputTensorShapes_[0][1] < classNum_) {
        LogError << "outputTensorShapes_[0][1] < classNum_(" << classNum_ << ").";
        return APP_ERR_OUTPUT_NOT_MATCH;
    }
    return APP_ERR_OK;
}

/**
 * Plugin factory: build a fresh ResNet50PostProcessor and hand it back through the base-class pointer.
 * Init() is not called here; the returned object is uninitialised.
 */
std::shared_ptr<MxPlugins::MxpiModelPostProcessorBase> GetInstance()
{
    LogInfo << "Begin to get ResNet50PostProcessor instance.";
    std::shared_ptr<MxPlugins::MxpiModelPostProcessorBase> processor = std::make_shared<ResNet50PostProcessor>();
    LogInfo << "End to get ResNet50PostProcessor instance.";
    return processor;
}

CMakeLists

# Builds the ResNet50 post-processor as a shared-library plugin and installs it into the SDK lib directory.
cmake_minimum_required(VERSION 3.5.2)
project(resnet50postsample)

# Select the pre-C++11 libstdc++ ABI.
# NOTE(review): presumably required to match the prebuilt SDK libraries — confirm.
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
# NOTE(review): presumably remaps the `google` (protobuf) namespace to the SDK's private copy — confirm.
add_definitions(-Dgoogle=mindxsdk_private)

set(PLUGIN_NAME "resnet50postsample")
set(TARGET_LIBRARY ${PLUGIN_NAME})

# SDK and bundled third-party headers, resolved relative to this sample's directory.
include_directories(${PROJECT_SOURCE_DIR}/../../include)
include_directories(${PROJECT_SOURCE_DIR}/../../opensource/include)
include_directories(${PROJECT_SOURCE_DIR}/../../opensource/include/gstreamer-1.0)
include_directories(${PROJECT_SOURCE_DIR}/../../opensource/include/glib-2.0)
include_directories(${PROJECT_SOURCE_DIR}/../../opensource/lib/glib-2.0/include)

link_directories(${PROJECT_SOURCE_DIR}/../../opensource/lib/)
link_directories(${PROJECT_SOURCE_DIR}/../../lib)

# -fPIC is what a shared library needs. `-pie` is intentionally absent: it is a link option for
# position-independent executables, does nothing at compile time and does not apply to a SHARED library.
add_compile_options(-std=c++11 -fPIC -fstack-protector-all -Wno-deprecated-declarations)
add_compile_options("-DPLUGIN_NAME=${PLUGIN_NAME}")

add_definitions(-DENABLE_DVPP_INTERFACE)
add_library(${TARGET_LIBRARY} SHARED ResNet50PostProcessor.cpp)

target_link_libraries(${TARGET_LIBRARY} glib-2.0 gstreamer-1.0 gobject-2.0 gstbase-1.0 gmodule-2.0)
target_link_libraries(${TARGET_LIBRARY} plugintoolkit mxpidatatype mxbase)
# Linker hardening (full RELRO, non-executable stack) and stripping of symbols (-s).
target_link_libraries(${TARGET_LIBRARY} -Wl,-z,relro,-z,now,-z,noexecstack -s)

install(TARGETS ${TARGET_LIBRARY} PERMISSIONS OWNER_WRITE OWNER_READ GROUP_READ LIBRARY DESTINATION ${PROJECT_SOURCE_DIR}/../../lib)