【科研】YOLOv8中anchor_points可视化(更新中)

目录

  • 写在前面
  • anchor-point可视化
    • YOLOv8官方源码解读
      • predictor

写在前面

感叹一下:如果GPT能在我刚上大学的时候出来,也许我能学的比现在好太多,毕竟大学有一个比自己优秀太多的人引导着是多么地捷径。

anchor-point可视化

YOLOv8官方源码解读

predictor

ultralytics/ultralytics/models/yolo/obb/predict.py中源码有一个OBBPredictor,继承原始的DetectionPredictor,此可视化主要是修改这个类的返回变量,加入一个feat_ind(nms后的有用anchor_point的索引值)。

def postprocess(self, preds, img, orig_imgs, return_feat_ind=False):
        """Post-processes predictions and returns a list of Results objects."""
        if return_feat_ind:
            preds, feat_ind = ops.non_max_suppression(
                preds,
                self.args.conf,
                self.args.iou,
                agnostic=self.args.agnostic_nms,
                max_det=self.args.max_det,
                nc=len(self.model.names),
                classes=self.args.classes,
                rotated=True,
                return_feat_ind=return_feat_ind,
            )
        else:
            preds = ops.non_max_suppression(
                preds,
                self.args.conf,
                self.args.iou,
                agnostic=self.args.agnostic_nms,
                max_det=self.args.max_det,
                nc=len(self.model.names),
                classes=self.args.classes,
                rotated=True,
                return_feat_ind=return_feat_ind,
            )

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
            rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
            rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
            # xywh, r, conf, cls
            obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
            anchor_points = self.count_anchor_point(feat_ind)
            results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb, anchor_points=anchor_points))
        if return_feat_ind:
            return results, feat_ind
        else:
            return results

可以看到首先要在nms函数里修改,当我们要求返回feat_ind时,将return_feat_ind改为true。

def non_max_suppression(
    prediction,
    conf_thres=0.25,

你可能感兴趣的:(YOLO)