Transformer实战-系列教程15:DETR 源码解读2(ConvertCocoPolysToMask类)

Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)

3、ConvertCocoPolysToMask类

位置:datasets/coco.py/ConvertCocoPolysToMask类

ConvertCocoPolysToMask类主要是进行数据预处理,主要在CocoDetection类中被调用

从的ConvertCocoPolysToMask类的代码来看,主要涉及到以下几种计算机视觉任务的数据预处理步骤:

  1. 物体检测(Object Detection)
    • 体现:通过处理bbox(边界框)信息。代码中提取和调整bbox坐标来适应物体检测任务的需求。
  2. 实例分割(Instance Segmentation)
    • 体现:如果return_masks为True,将COCO多边形标注(segmentation)转换为掩码(mask)。这对于实例分割任务来说是必要的,因为它需要精确地区分图像中各个对象的形状。
  3. 姿态估计(Pose Estimation)
    • 体现:通过处理keypoints信息。当标注中包含关键点数据时,代码会提取这些数据,这些数据对于识别和估计图像中人物的姿态非常有用。
class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False):
        self.return_masks = return_masks
    def __call__(self, image, target):
        w, h = image.size
        image_id = target["image_id"]
        image_id = torch.tensor([image_id])
        anno = target["annotations"]
        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
        boxes = [obj["bbox"] for obj in anno] # x y w h
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)
        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)
        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)
        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]
        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]
        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])
        return image, target
  1. 定义ConvertCocoPolysToMask类,用于处理COCO数据集的转换
  2. 类的初始化方法,参数return_masks用于控制是否返回标注的掩码信息
  3. return_masks
  4. 可调用方法,接收两个参数:image(PIL图像对象)和target(包含图像标注信息的字典)
  5. 图像w、h, 427 ∗ 640 427*640 427640,每张图片读进来的长宽都可能不一样
  6. 获取图像id,image id: 538686]
  7. 将id转化为Tensor,image id: tensor([538686])
  8. 获取标签的标注信息,包含面积、bbox框的长宽xy四个值、类别id、图像id、分割的标注信息
  9. 过滤标注信息,过滤掉有重叠框的,只保留对单个物体的框,包含重叠物体的不要,如果iscrowd为1,表示这个标注包含的是一个对象群,而不是单个对象
  10. 获取所有框,[[62.37, 135.48, 184.94, 364.52],…, [107.99, 46.17, 101.51, 157.66]]
  11. 框的数据转化为Tensor
  12. 将x、y、w、h
  13. 转化为
  14. x1、y1、x2、y2,tensor([[ 62.3700, 135.4800, 247.3100, 500.0000],…, [107.9900, 46.1700, 209.5000, 203.8300]])
  15. 获取当前图像的所有类别标签(可能对应有多个类别),[19, 21, 21, 1]
  16. 转化为Tensor,tensor([19, 21, 21, 1])
  17. 是否进行掩码转换
  18. 提取分割信息
  19. 调用函数将分割信息转化为掩码
  20. 初始化 keypoints (姿态估计任务使用)变量
  21. 判断标注信息中是否包含 keypoints 信息
  22. 提取所有标注的 keypoints 信息
  23. 将 keypoints 列表转换为PyTorch张量
  24. 获取 keypoints 的数量
  25. 判断是否存在 keypoints
  26. 重塑 keypoints 张量
  27. 过滤掉不合逻辑的边界框(即右下角坐标不大于左上角坐标的边界框),因为在标注数据的时候,外包人员如果没有按照标注说明去标,拉框不是从上面往下框住物体,而是从下往上,这会影响两个点的顺序判断
  28. 使用keep数组过滤边界框,保留有效的边界框
  29. 同样使用keep数组过滤类别ID,保留与有效边界框对应的类别ID
  30. 判断是否有掩码
  31. 使用keep数组过滤掩码,保留与有效边界框对应的掩码
  32. 如果 keypoints 信息存在
  33. 使用keep数组过滤 keypoints 信息
  34. 初始化一个新的字典target,用于存储处理后的标注信息
  35. 将过滤后的边界框信息添加到target字典
  36. 将过滤后的类别ID添加到target字典
  37. 判断是否有掩码
  38. 将过滤后的掩码添加到target字典
  39. 将图像ID添加到target字典
  40. 如果存在关键点信息
  41. 将过滤后的关键点信息添加到target字典
  42. 提取所有标注的面积信息,并转换为PyTorch张量
  43. 将过滤后的iscrowd信息添加到target字典,如果iscrowd为1,表示这个标注包含的是一个对象群,而不是单个对象
  44. 分别将原始图像的高度和宽度作为orig_size和size添加到target字典。这两个字段通常用于后续的处理或数据恢复步骤
  45. 返回处理后的图像和更新后的target字典

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)

你可能感兴趣的:(Transformer实战,transformer,深度学习,计算机视觉,DETR,物体检测,pytorch)