C++ 实现多类别 NMS(按类别分别做非极大值抑制)

修改前(不区分类别,所有框放在一起做 NMS):

cv::dnn::NMSBoxes(boxes, confidences, conf_threshold, nms_threshold, indices);

修改后:只在同一类别内部做 NMS,不同类别的框之间互不抑制:

/// Clip a box against the image area.
///
/// The top-left corner is clamped to be >= (0, 0) and the bottom-right corner
/// to be <= (img_width - 1, img_height - 1); width/height are recomputed from
/// the clamped corners.
///
/// @param box        Box to clip (x, y, width, height).
/// @param img_width  Image width in pixels.
/// @param img_height Image height in pixels.
/// @return The clipped box. If `box` lies completely outside the image, the
///         returned width and/or height is 0 rather than negative.
inline cv::Rect2d RemapBoxOnSrc(const cv::Rect2d &box, const int img_width, const int img_height)
{
    // Stay in double: cv::Rect2d stores doubles, so a float round-trip would
    // only throw precision away.
    const double xmin = box.x;
    const double ymin = box.y;
    const double xmax = xmin + box.width;
    const double ymax = ymin + box.height;

    cv::Rect2d remap_box;
    remap_box.x = std::max(0.0, xmin);
    remap_box.y = std::max(0.0, ymin);
    // Clamp the size at 0 so a box entirely outside the image never ends up
    // with a negative width/height (negative area breaks later IoU/ROI code).
    remap_box.width = std::max(0.0, std::min(img_width - 1.0, xmax) - remap_box.x);
    remap_box.height = std::max(0.0, std::min(img_height - 1.0, ymax) - remap_box.y);
    return remap_box;
}

/// Area of the overlap between two axis-aligned boxes.
///
/// Each box is described by its top-left corner (`left`, `top`) plus `width`
/// and `height`. Returns 0 when the boxes do not overlap; boxes that merely
/// touch along an edge also count as non-overlapping.
inline float intersectionArea(const bbox& box1, const bbox& box2){
    // Corners of the overlap rectangle.
    const float overlap_left   = std::max(box1.left, box2.left);
    const float overlap_top    = std::max(box1.top, box2.top);
    const float overlap_right  = std::min(box1.left + box1.width, box2.left + box2.width);
    const float overlap_bottom = std::min(box1.top + box1.height, box2.top + box2.height);

    // An empty or inverted overlap rectangle means the boxes are disjoint.
    if (!(overlap_left < overlap_right) || !(overlap_top < overlap_bottom))
        return 0.0f;

    const float overlap_w = overlap_right - overlap_left;
    const float overlap_h = overlap_bottom - overlap_top;
    return overlap_w * overlap_h;
}
 
// Area of a box: width * height, returned as float. No clamping is applied,
// so the caller is responsible for passing a well-formed box.
inline float boxArea(const bbox& box){
    const float area = box.width * box.height;
    return area;
}

inline void NMSBoxes(std::vector<bbox> input_boxes, const int conf_threshold, const float nms_threshold, std::vector<bbox> & nms_boxes)
{
    std::vector<bbox> sortedBoxes = input_boxes;
    printf("confidence=%f, cls=%d\n", sortedBoxes[0].confidence, sortedBoxes[0].cls_id);
    std::sort(sortedBoxes.begin(), sortedBoxes.end(), [](bbox a, bbox b) {return a.confidence > b.confidence; });
    printf("confidence=%f, cls=%d\n", sortedBoxes[0].confidence, sortedBoxes[0].cls_id);
    
    while(!sortedBoxes.empty()){
        // 每次取分数最高的合法锚框
        printf("ccccccccccccccc\n");
        const bbox currentBox = sortedBoxes.front();
        nms_boxes.push_back(currentBox);
        sortedBoxes.erase(sortedBoxes.begin()); // 取完后从候选锚框中删除

        // 计算剩余锚框与合法锚框的IOU
        std::vector<bbox>::iterator it = sortedBoxes.begin();
        printf("sortedBoxes=%d, conf=%f\n", sortedBoxes.size(), it[0].confidence);
        while (it != sortedBoxes.end()){
            const bbox candidateBox = *it; // 取当前候选锚框
            if (currentBox.cls_id==candidateBox.cls_id)
            {
                float intersection = intersectionArea(currentBox, candidateBox); // 计算候选框和合法框的交集面积
                float iou = intersection / (boxArea(currentBox) + boxArea(candidateBox) - intersection); // 计算iou
                printf("iou=%f\n", iou);
                if (iou >= nms_threshold) 
                {sortedBoxes.erase(it);
                printf("delet\n");}  // 根据阈值过滤锚框,过滤完it指向下一个锚框
                else it++; // 保留当前锚框,判断下一个锚框
            }
            else
            {
                it++; // 保留当前锚框,判断下一个锚框
            }

        }
        printf("input_boxes=%d, nms_boxes=%d, sortedBoxes=%d\n", input_boxes.size(), nms_boxes.size(), sortedBoxes.size());
    }
    
    return ;
};

你可能感兴趣的:(c++,java,算法)