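In object detection, Non-Maximum Suppression (NMS) is a key post-processing step that removes highly overlapping, redundant predicted boxes. In some scenes, however, the network predicts the same object as several similar classes at once, for example as both "bicycle" and "electric bike".
To address this, Grouped NMS places similar classes in one group and runs NMS over the group as a whole, which reduces this kind of class confusion.
In this article we show how to implement grouped NMS on top of an existing NMS function while keeping the efficiency of the original code.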
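Traditional NMS handles each class independently: given detection boxes of classes cls1 and cls2, IoU is computed and redundant boxes are suppressed only among boxes of the same class. When the model produces multi-class predictions for a single object (such as "bicycle" and "electric bike"), traditional NMS keeps one box per class, which leads to false positives or redundant detections.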
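The behavior described above can be reproduced with a small plain-Python sketch. It is not the tensor implementation used later in this article; `iou`, `per_class_nms`, the box coordinates, and the class ids (0 for "bicycle", 1 for "electric bike") are all illustrative assumptions.

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def per_class_nms(dets, iou_thres=0.45):
    """Greedy NMS that only compares boxes sharing the same class id.

    dets: list of (box, score, cls) tuples. Returns the kept detections.
    """
    kept = []
    for det in sorted(dets, key=lambda d: d[1], reverse=True):
        box, _, cls = det
        # A box is suppressed only by a kept box of the SAME class.
        if all(k[2] != cls or iou(box, k[0]) <= iou_thres for k in kept):
            kept.append(det)
    return kept


# One physical object predicted twice with different (hypothetical) class ids.
dets = [
    ((100, 100, 200, 200), 0.90, 0),  # "bicycle"
    ((102, 98, 203, 201), 0.80, 1),   # "electric bike"
]
kept = per_class_nms(dets)
print(len(kept))  # both boxes survive because their classes differ
```

The two boxes overlap almost completely, yet neither suppresses the other, which is exactly the redundancy that grouped NMS targets.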
The original NMS code follows the YOLO (YOLOv5) implementation:
import time

import torch
import torchvision

# xywh2xyxy, box_iou and LOGGER are helpers from the YOLOv5 repository
# (utils/general.py and utils/metrics.py there).


def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """
    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - nm - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    mi = 5 + nc  # mask start index
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box/Mask
        box = xywh2xyxy(x[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, mi:]  # zero columns if no masks

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = x[:, 5:mi].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
        else:
            x = x[x[:, 4].argsort(descending=True)]  # sort by confidence

        # -------- this is the part modified later --------
        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        # --------------------------------
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
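The line `c = x[:, 5:6] * (0 if agnostic else max_wh)` is what keeps classes apart: every box is shifted by `class_id * max_wh`, so boxes of different classes can no longer overlap and a single NMS call handles all classes. The sketch below illustrates the idea in plain Python; `MAX_WH`, `offset_box`, and the sample boxes are illustrative assumptions, not part of the YOLO code.

```python
MAX_WH = 7680  # same value as max_wh in the function above


def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def offset_box(box, cls, agnostic=False):
    """Shift all four coordinates by cls * MAX_WH, mirroring `x[:, :4] + c`."""
    shift = 0 if agnostic else cls * MAX_WH
    return tuple(v + shift for v in box)


box_a, cls_a = (100, 100, 200, 200), 0
box_b, cls_b = (102, 98, 203, 201), 1

iou_raw = iou(box_a, box_b)  # the boxes overlap heavily in image space
iou_shifted = iou(offset_box(box_a, cls_a), offset_box(box_b, cls_b))
iou_agnostic = iou(offset_box(box_a, cls_a, True), offset_box(box_b, cls_b, True))

print(iou_raw > 0.9, iou_shifted, iou_agnostic == iou_raw)
```

After the shift the two boxes have zero IoU, so NMS never compares them. Giving them the same class label (or setting `agnostic=True`) restores the overlap, which is the lever that grouped NMS uses.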
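We introduce a new parameter, class_groups, which specifies the groups of classes to be handled together. Every class in a group is mapped to a shared temporary class label, so NMS treats the group as a single class and the same object is no longer kept under several classes.
The key change is the following: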
# Grouped NMS handling
if class_groups is not None:
    y = x.clone()  # remap on a copy so x keeps the original class ids
    for group_idx, group in enumerate(class_groups):
        temp_cls = -(group_idx + 1)  # negative number used as the temporary id of the group
        for cls in group:
            y[:, 5:6][y[:, 5:6] == cls] = temp_cls
    c = y[:, 5:6] * (0 if agnostic else max_wh)  # classes
else:
    c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
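To make the effect of the remapping concrete, here is a minimal plain-Python sketch of the same idea. It assumes detections are `(box, score, cls)` tuples; `remap_to_groups` and `grouped_nms` are hypothetical helper names for illustration, not functions from the YOLO code.

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def remap_to_groups(cls_ids, class_groups):
    """Return a copy of cls_ids in which grouped classes share a negative id."""
    lookup = {}
    for group_idx, group in enumerate(class_groups):
        for cls in group:
            # setdefault: if a class is listed twice, the first group wins,
            # which matches the in-place loop in the snippet above.
            lookup.setdefault(cls, -(group_idx + 1))
    return [lookup.get(c, c) for c in cls_ids]


def grouped_nms(dets, class_groups, iou_thres=0.45):
    """Greedy NMS that compares boxes whose remapped class ids are equal."""
    nms_ids = remap_to_groups([d[2] for d in dets], class_groups)
    order = sorted(range(len(dets)), key=lambda k: dets[k][1], reverse=True)
    kept = []
    for k in order:
        if all(nms_ids[j] != nms_ids[k] or iou(dets[k][0], dets[j][0]) <= iou_thres
               for j in kept):
            kept.append(k)
    return [dets[k] for k in kept]  # original class ids are returned untouched


dets = [
    ((100, 100, 200, 200), 0.90, 0),  # class 0, grouped with class 1
    ((102, 98, 203, 201), 0.80, 1),   # class 1, same object as above
    ((101, 101, 199, 202), 0.70, 2),  # class 2, not in any group
]
kept = grouped_nms(dets, class_groups=[[0, 1]])
print([d[2] for d in kept])
```

Within the group only the higher-scoring box survives, while the ungrouped class 2 box is unaffected. Because the remapping happens on a copy, the surviving detections still report their original class ids, just as `output[xi] = x[i]` does in the tensor version.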
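Code walkthrough:
def group_nms(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False,
              multi_label=False, labels=(), class_groups=None, max_det=300):
    """
    Args:
        prediction (Tensor): Model output of shape [batch, n, 5 + num_classes], holding the box, objectness confidence, and class scores.
        conf_thres (float): Confidence threshold; detections below it are removed.
        iou_thres (float): IoU threshold used for overlap filtering in NMS.
        classes (list): Classes to keep; None (the default) keeps all classes.
        agnostic (bool): Whether to run NMS ignoring class information.
        multi_label (bool): Whether to allow multiple labels per box.
        labels (list): Manually supplied a priori labels.
        class_groups (list): List of class groups, for example [[0, 1], [2, 3]].
        max_det (int): Maximum number of detections kept per image.
    Returns:
        list: Detections for each image, in the format [xyxy, conf, cls].
    """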
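Suppose classes 0 and 1 are easily confused with each other, and likewise classes 2 and 3.
The improved function can then be called as follows: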
# Define the class groups
class_groups = [[0, 1], [2, 3]]
# Call the grouped NMS function
output = group_nms(prediction, conf_thres=0.3, iou_thres=0.5, class_groups=class_groups)
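Before passing `class_groups` to the function, it can help to check that the groups are well formed. The remapping loop silently lets the first group win when a class id appears in two groups. The helper below is a hypothetical sketch of such a check (the name `check_class_groups` and the `num_classes` argument are assumptions, not part of the original code).

```python
def check_class_groups(class_groups, num_classes):
    """Validate class groups and return a {class_id: group_index} mapping.

    Raises ValueError if a class id is out of range or listed in two groups.
    """
    owner = {}
    for group_idx, group in enumerate(class_groups):
        for cls in group:
            if not 0 <= cls < num_classes:
                raise ValueError(f"class id {cls} is outside 0..{num_classes - 1}")
            if cls in owner:
                raise ValueError(
                    f"class id {cls} appears in groups {owner[cls]} and {group_idx}")
            owner[cls] = group_idx
    return owner


mapping = check_class_groups([[0, 1], [2, 3]], num_classes=4)
print(mapping)
```

Calling it once at startup is enough, since the groups do not change between images.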
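By introducing grouped NMS, we address the weakness of traditional NMS on similar classes, especially in scenes with heavy class confusion (such as "bicycle" and "electric bike"). The method keeps the efficiency of the original code while removing duplicate boxes across confusable classes, and it applies to a wide range of object detection tasks.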