基于分组 NMS 的检测模型后处理改进

引言

在目标检测任务中,后处理阶段的 非极大值抑制(Non-Maximum Suppression, NMS) 是至关重要的一环,主要用于去除高度重叠的冗余预测框。然而,在某些场景中,不同类别的目标可能会被网络同时预测为多个相近的类别,例如:

  • 交通工具检测场景:同一辆车可能被误检测为 “自行车” 和 “电动车”。
  • 动物检测场景:同一只动物可能被误检测为 “狼” 和 “狗”。
  • 家电检测场景:同一台设备可能被误检测为 “微波炉” 和 “电烤箱”。

为了解决这一问题,可以通过 分组 NMS(Grouped NMS) 的方法,将相近类别的目标归为一组,统一执行 NMS,从而减少这种类别混淆的情况。

在本文中,我们介绍了如何在现有 NMS 函数的基础上,改进实现了分组 NMS,并且延续了原有代码的高效性。
基于分组 NMS 的检测模型后处理改进_第1张图片

原始 NMS 的局限性

传统的 NMS 处理不同类别的目标是独立进行的,假设输入检测框的类别为 cls1 和 cls2,只会在相同类别的框中计算 IoU 并抑制冗余框。但如果模型对同一目标产生了多类别预测(如 “自行车” 和 “电动车”),传统 NMS 并不能很好地处理这种情况,而是会保留两个类别的预测框,从而导致误检或冗余。
原始NMS代码参考yolo的实现:

def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - nm - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    mi = 5 + nc  # mask start index
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box/Mask
        box = xywh2xyxy(x[:, :4])  # center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, mi:]  # zero columns if no masks

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = x[:, 5:mi].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
        else:
            x = x[x[:, 4].argsort(descending=True)]  # sort by confidence
		
		# --------这里是后续修改的部分--------
        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        # --------------------------------
        
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output

改进后的分组 NMS 逻辑

我们引入了一个新的参数 cls_groups,用于指定需要统一处理的类别组。通过对属于同一个组的类别进行归一化的类别标记,可以将它们视为同一类别来执行 NMS,从而避免同一目标被识别为多个类别。

核心实现与代码解析

以下是关键的改动部分:

# 分组 NMS 的处理方式
if class_groups is not None:
    y = x.clone()
    for i, group in enumerate(class_groups):
        temp_cls = -(i + 1)  # 使用负数表示分组类别的临时编号
        for cls in group:
            y[:, 5:6][y[:, 5:6] == cls] = temp_cls
    c = y[:, 5:6] * (0 if agnostic else max_wh)  # classes
else:
    c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes

代码解读:

  • 1. 参数 class_groups:
    新增参数 class_groups,用于定义需要分组的类别。例如,class_groups=[[0, 1], [2, 3]] 表示类别 0 和 1 归为一组,类别 2 和 3 归为另一组。
  • 2. 临时类别编号:
    为了在 NMS 中处理分组类别,我们为每一组类别分配一个临时类别编号:
    使用负数(如 -1, -2 等)作为临时类别编号,避免与原始类别编号冲突。
    遍历每一组类别,用 y[:, 5:6] 替换对应的类别编号。
  • 3. 保持代码兼容性:
    如果未指定 class_groups,代码会按照原始方式执行 NMS,确保改动对现有功能无影响。
  • 4. 效率优化:
    通过在检测框信息 x 的基础上复制出 y 并仅对类别进行操作,减少了对其他部分数据的影响,延续了代码的高效性。

改进后的函数接口

def group_nms(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, 
              multi_label=False, labels=(), class_groups=None, max_det=300):
    """
    Args:
        prediction (Tensor): 模型输出的预测结果 [n, 6+num_classes],包括边界框、置信度和类别。
        conf_thres (float): 置信度阈值,低于此值的检测框将被移除。
        iou_thres (float): IoU 阈值,用于 NMS 的重叠过滤。
        classes (list): 指定需要保留的类别,默认为 None 表示保留所有类别。
        agnostic (bool): 是否忽略类别信息执行 NMS。
        multi_label (bool): 是否支持多标签模式。
        labels (list): 手动添加的先验标签。
        class_groups (list): 类别分组列表,例如 [[0, 1], [2, 3]]。
        max_det (int): 每张图片保留的最大检测框数。

    Returns:
        list: 每张图片的检测结果,格式为 [xyxy, conf, cls]。
    """

使用示例

假设我们有以下场景:

  • 检测类别包括 0: 自行车、1: 电动车、2: 汽车、3: 卡车。
  • 希望将 自行车 和 电动车 归为一组,汽车 和 卡车 归为另一组。

可以通过如下方式调用改进后的函数:

# 定义类别分组
class_groups = [[0, 1], [2, 3]]

# 调用分组 NMS 函数
output = group_nms(prediction, conf_thres=0.3, iou_thres=0.5, class_groups=class_groups)

总结

通过引入 分组 NMS,我们解决了传统 NMS 在处理相近类别目标时的不足,尤其是在类别混淆较高的场景中(如 “自行车” 和 “电动车”)。该方法在保持代码高效性的同时,显著提升了检测质量,适用于多种目标检测任务。

你可能感兴趣的:(目标检测,算法与优化,目标检测,深度学习,python)