D-FINE：使用 .pth 权重批量推理并可视化图片

关于 D-FINE 的介绍可参考下面这篇博客：

论文解读：ICLR2025 | D-FINE: Redefine Regression Task in DETRs as Fine-grained Distribution Refinement（CSDN 博客）。D-FINE 是一款功能强大的实时物体检测器，它将 DETR 中的边界框回归任务重新定义为细粒度分布细化（FDR），并引入了全局最优定位自蒸馏（GO-LSD），在不引入额外推理和训练成本的情况下取得了最佳性能。链接：https://blog.csdn.net/Jacknbv/article/details/148002595?spm=1001.2014.3001.5501

下面在 D-FINE 官方源码根目录下的 tools/inference/torch_inf.py 的基础上进行修改，修改后的完整代码如下：

"""
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
"""

import os
import sys

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from src.core import YAMLConfig


# Class names, indexed by the label id the model predicts. Classes may be
# added or removed here; COLORS must be updated to match, since COLORS[i]
# is the color used for NAMES[i].
NAMES = [
    'car',
    'dog',
    'person',
    'cow',
]
# Per-class colors in OpenCV BGR order (blue, green, red).
COLORS = [
    (0, 0, 255),    # car    -> red
    (255, 0, 255),  # dog    -> magenta
    (3, 97, 255),   # person -> orange
    (255, 0, 0),    # cow    -> blue
]
# TrueType font used for label text; if it cannot be opened, Pillow's
# built-in default font is used instead.
FONT_PATH = "Arial.ttf"


# Fonts already loaded by _load_font, keyed by point size. Loading each size
# only once avoids re-reading the font file (and re-printing the "font not
# found" warning) for every single label that is drawn.
_FONT_CACHE = {}


def _load_font(font_size):
    """Return the FONT_PATH font at ``font_size``, or Pillow's default font if it cannot be opened."""
    font = _FONT_CACHE.get(font_size)
    if font is None:
        try:
            font = ImageFont.truetype(FONT_PATH, size=font_size)
        except IOError:
            print(f"警告: 字体文件未找到于 '{FONT_PATH}'. 将使用默认字体。")
            font = ImageFont.load_default()
        _FONT_CACHE[font_size] = font
    return font


def plot_one_box(box, im, color=(128, 128, 128), txt_color=(255, 255, 255), label=None, line_width=3):
    """
    Draw one xyxy box on a BGR image, optionally with a filled text label.

    Args:
        box: Four numbers (x1, y1, x2, y2) in pixels; each element only needs
            to support ``int()``.
        im: BGR image as a numpy array (as returned by ``cv2.imread``). The
            rectangle is drawn into it in place.
        color: Box and label-background color, BGR order.
        txt_color: Label text color, BGR order (same convention as ``color``).
        label: Text to draw; when falsy only the rectangle is drawn.
        line_width: Rectangle thickness; when falsy it is derived from the
            image size.

    Returns:
        The annotated image. With a label this is a new array (the image
        round-trips through Pillow); without one it is ``im`` itself.

    Raises:
        ValueError: if ``im`` is not a contiguous array.

    The label is placed just above the box. If that would cross the top of
    the image, it is moved inside the box at its bottom edge. It is also
    shifted horizontally so it does not run past the left/right image edge.
    """
    if not im.data.contiguous:
        raise ValueError('Image not contiguous. Apply np.ascontiguousarray(im) to plot_one_box() input image.')
    lw = line_width or max(round(min(im.shape[:2]) * 0.003), 2)

    c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
    cv2.rectangle(im, c1, c2, color, thickness=lw, lineType=cv2.LINE_AA)

    if not label:
        return im

    # Pillow draws in RGB, while colors in this file are BGR (OpenCV order).
    color_rgb = (color[2], color[1], color[0])
    txt_color_rgb = (txt_color[2], txt_color[1], txt_color[0])

    # Font size scales with line width, with a floor of 30.
    font = _load_font(max(round(lw * 12), 30))

    im_pil = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(im_pil)

    if hasattr(font, 'getbbox'):  # Pillow >= 9.2
        _, _, txt_width, txt_height = font.getbbox(label)
    else:  # older Pillow
        txt_width, txt_height = draw.textsize(label, font)

    # ---- vertical placement: above the box, or inside its bottom edge ----
    anchor_y = c1[1]
    if anchor_y - txt_height - 5 < 0:
        anchor_y = c2[1]
    label_y = anchor_y - txt_height - 5
    bg_y0 = anchor_y - txt_height - 6
    bg_y1 = anchor_y

    # ---- horizontal placement: keep the label inside the image ----
    pad = 3  # gap between the background's left edge and the text
    bg_width = txt_width + 2 * pad  # pad on both sides so the text is fully covered
    img_w = im.shape[1]
    bg_x0 = max(min(c1[0], img_w - bg_width), 0)

    draw.rectangle([(bg_x0, bg_y0), (bg_x0 + bg_width, bg_y1)], fill=color_rgb)
    draw.text((bg_x0 + pad, label_y), label, font=font, fill=txt_color_rgb)

    # Back to OpenCV's BGR layout.
    return cv2.cvtColor(np.array(im_pil), cv2.COLOR_RGB2BGR)


def draw_detections(image_np, labels, boxes, scores, thrh=0.4):
    """
    Draw every detection whose score is strictly greater than ``thrh``.

    Args:
        image_np: BGR image (numpy array). It is copied, not modified.
        labels, boxes, scores: Outputs for ONE image, indexable with the
            boolean mask ``scores > thrh`` (e.g. torch tensors). Presumably
            shaped (Q,), (Q, 4) xyxy and (Q,); TODO confirm against the
            postprocessor.
        thrh: Score threshold.

    Returns:
        The annotated copy of ``image_np``.

    Class ids within NAMES get their name, plus COLORS[id] when it exists.
    Other ids are labelled "ID:<id>", and they are drawn gray, as is any
    class that COLORS has no entry for.
    """
    unknown_color = (128, 128, 128)
    img_res = image_np.copy()
    keep = scores > thrh

    for lab, box, score in zip(labels[keep], boxes[keep], scores[keep]):
        label_id = lab.item()
        if 0 <= label_id < len(NAMES):
            class_name = NAMES[label_id]
            # Guard against COLORS being shorter than NAMES (e.g. a class was
            # added to NAMES only); previously this raised IndexError and the
            # whole image was skipped.
            color = COLORS[label_id] if label_id < len(COLORS) else unknown_color
        else:
            class_name = f"ID:{label_id}"
            color = unknown_color

        label_text = f"{class_name} {round(score.item(), 2)}"
        # White label text.
        img_res = plot_one_box(box, img_res, label=label_text, color=color, txt_color=(255, 255, 255), line_width=3)

    return img_res


def process_image(model, device, image_path, output_path, conf_thres, input_size=(640, 640)):
    """
    Run the detector on one image file and save the annotated result.

    Args:
        model: Callable taking (image batch, original sizes). Its output is
            indexed as labels = output[0], boxes = output[1],
            scores = output[2], each with a leading batch dimension.
        device: torch device the inputs are moved to.
        image_path: Image to read with ``cv2.imread``.
        output_path: Destination for the annotated image (``cv2.imwrite``).
        conf_thres: Detections with score <= conf_thres are not drawn.
        input_size: (height, width) passed to ``T.Resize`` before inference.
            Defaults to the previously hard-coded 640x640.

    Any exception is reported (message on stdout, traceback on stderr) and
    the image is skipped, so one bad file does not abort a batch run.
    """
    try:
        im_bgr = cv2.imread(image_path)
        if im_bgr is None:
            print(f"Warning: Could not read image {image_path}, skipping.")
            return

        im_pil = Image.fromarray(cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB))
        # PIL reports size as (width, height); it is passed on in that order.
        w, h = im_pil.size
        orig_size = torch.tensor([[w, h]]).to(device)

        transforms = T.Compose([T.Resize(input_size), T.ToTensor()])
        im_data = transforms(im_pil).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(im_data, orig_size)

        labels, boxes, scores = output[0], output[1], output[2]
        # Batch size is 1, so use the first entry of each output.
        result_image = draw_detections(im_bgr, labels[0], boxes[0], scores[0], thrh=conf_thres)

        # cv2.imwrite signals many failures (e.g. unwritable directory) by
        # returning False rather than raising; don't let them pass silently.
        if not cv2.imwrite(output_path, result_image):
            print(f"Warning: Failed to write {output_path}.")

    except Exception as e:
        import traceback

        print(f"Error processing {image_path}: {e}")
        traceback.print_exc()


def _unique_output_name(image_path, used_names):
    """
    Return the file name under which the result for ``image_path`` is saved.

    Uses the input's basename. If that name was already handed out in this
    run (same file name in different input folders), ``_1``, ``_2``, ... is
    inserted before the extension so earlier results are not overwritten.
    The chosen name is added to ``used_names``.
    """
    base_filename = os.path.basename(image_path)
    stem, ext = os.path.splitext(base_filename)
    name = base_filename
    suffix = 1
    while name in used_names:
        name = f"{stem}_{suffix}{ext}"
        suffix += 1
    used_names.add(name)
    return name


def main(args):
    """
    Run a trained D-FINE checkpoint over a list of images and save visualizations.

    Args:
        args: Namespace with ``config`` (YAML config path), ``resume``
            (checkpoint path), ``txt_file`` (text file, one image path per
            line), ``output_dir``, ``conf_thres`` and ``device``.

    Raises:
        AttributeError: if ``args.resume`` is empty (no checkpoint to load).
    """
    cfg = YAMLConfig(args.config, resume=args.resume)

    # Switch off the backbone's "pretrained" flag; the weights are restored
    # from the checkpoint below.
    if "HGNetv2" in cfg.yaml_cfg:
        cfg.yaml_cfg["HGNetv2"]["pretrained"] = False

    if not args.resume:
        raise AttributeError("Only support resume to load model.state_dict by now.")
    checkpoint = torch.load(args.resume, map_location="cpu")
    # Prefer the EMA weights when the checkpoint contains them.
    state = checkpoint["ema"]["module"] if "ema" in checkpoint else checkpoint["model"]

    # Load train-mode weights; deploy-mode copies are built in Model below.
    cfg.model.load_state_dict(state)

    class Model(nn.Module):
        """Deploy-mode model followed by the deploy-mode postprocessor."""

        def __init__(self):
            super().__init__()
            self.model = cfg.model.deploy()
            self.postprocessor = cfg.postprocessor.deploy()

        def forward(self, images, orig_target_sizes):
            outputs = self.model(images)
            outputs = self.postprocessor(outputs, orig_target_sizes)
            return outputs

    device = torch.device(args.device)
    model = Model().to(device)
    model.eval()

    os.makedirs(args.output_dir, exist_ok=True)

    # One image path per non-blank line; surrounding whitespace is stripped.
    with open(args.txt_file, 'r') as f:
        image_paths = [line.strip() for line in f if line.strip()]

    print(f"Found {len(image_paths)} images to process.")

    used_names = set()
    for image_path in tqdm(image_paths, desc="Processing Images"):
        if not os.path.exists(image_path):
            print(f"Warning: Image path not found: {image_path}, skipping.")
            continue

        # Disambiguate duplicate basenames so results are never overwritten.
        output_path = os.path.join(args.output_dir, _unique_output_name(image_path, used_names))

        process_image(model, device, image_path, output_path, args.conf_thres)

    print("All images processed.")


if __name__ == "__main__":
    import argparse

    # Command-line interface: config, checkpoint, image list and output
    # directory are mandatory; threshold and device have defaults.
    cli = argparse.ArgumentParser()

    # Model definition (YAML config) and the checkpoint to load.
    cli.add_argument("-c", "--config", type=str, required=True)
    cli.add_argument("-r", "--resume", type=str, required=True)

    # Where the inputs come from and where the results go.
    cli.add_argument("--txt_file", type=str, required=True, help="Path to a .txt file containing image paths.")
    cli.add_argument("-o", "--output_dir", type=str, required=True, help="Directory to save the output images.")

    # Inference options.
    cli.add_argument("--conf_thres", type=float, default=0.4, help="Confidence threshold for detection filtering.")
    cli.add_argument("-d", "--device", type=str, default="cpu")

    main(cli.parse_args())

终端运行命令（先将需要推理的图片路径逐行写入一个 txt 文件，再通过 --txt_file 传入）：

python tools/inference/torch_inf.py \
       -c configs/dfine/custom/dfine_hgnetv2_l_custom.yml \
       -r output/xxxxxx/xxxxxx.pth \
       --txt_file /data/xxx.txt \
       -o output/xxxxx \
       --conf_thres 0.35 \
       -d cuda:0

注意：不要在以 \ 续行的命令中间插入 # 注释，否则命令会在注释处被截断，后面的参数（--txt_file、-o 等）不会传给脚本。

你可能感兴趣的:(代码调试,深度学习,人工智能,python,目标检测,计算机视觉)