关于D-FINE相关的内容可参考下面这篇博客:
论文解读:ICLR2025 | D-FINE: Redefine Regression Task in DETRs as Fine-grained Distribution Refinement(CSDN博客):https://blog.csdn.net/Jacknbv/article/details/148002595
D-FINE 是一款功能强大的实时物体检测器,它将 DETRs 中的边界框回归任务重新定义为细粒度分布细化(FDR),并引入了全局最优定位自蒸馏(GO-LSD),在不引入额外推理和训练成本的情况下表现出了最佳性能。

在D-FINE官方源码根目录下的 tools/inference/torch_inf.py 中进行如下修改:
"""
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
"""
import os
import sys
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from src.core import YAMLConfig
# Class names, indexed by the label id the model predicts. Add or remove
# entries to match your dataset, and keep COLORS in step with this list.
NAMES = [
    'car',
    'dog',
    'person',
    'cow',
]
# One box colour per class (same index as NAMES), in OpenCV BGR order.
COLORS = [
    (0, 0, 255),    # red
    (255, 0, 255),  # magenta
    (3, 97, 255),   # orange
    (255, 0, 0),    # blue
]
# TrueType font used for box labels; if it cannot be loaded, Pillow's
# default font is used instead.
FONT_PATH = "Arial.ttf"
# ---------------------------------------------------------------------------
# Fonts already loaded, keyed by size, so the font file is opened (and any
# missing-font warning printed) once per size instead of once per drawn box.
_FONT_CACHE = {}


def _load_font(size):
    """Return the FONT_PATH TrueType font at ``size``, or Pillow's default font if it cannot be loaded."""
    font = _FONT_CACHE.get(size)
    if font is None:
        try:
            font = ImageFont.truetype(FONT_PATH, size=size)
        except IOError:
            print(f"警告: 字体文件未找到于 '{FONT_PATH}'. 将使用默认字体。")
            font = ImageFont.load_default()
        _FONT_CACHE[size] = font
    return font


def plot_one_box(box, im, color=(128, 128, 128), txt_color=(255, 255, 255), label=None, line_width=3):
    """
    Draw one ``(x1, y1, x2, y2)`` box, and optionally a filled label, on a BGR image.

    The box is drawn with OpenCV. The label is rendered with Pillow so that a
    custom TrueType font (``FONT_PATH``) can be used. The label goes just above
    the box when there is room, otherwise inside the bottom of the box, and is
    then clamped so that its background stays within the image.

    Args:
        box: indexable with at least four numbers ``x1, y1, x2, y2`` in pixels.
        im: 3-channel BGR image array.
        color: box and label-background colour, in BGR order.
        txt_color: label text colour, in BGR order.
        label: text to draw; nothing is drawn for an empty or ``None`` label.
        line_width: box line thickness; a falsy value derives it from the image size.

    Returns:
        The annotated image. Without a label this is ``im`` itself, drawn in
        place (when it was already C-contiguous). With a label it is a new array.
    """
    # OpenCV drawing needs a C-contiguous buffer; this returns ``im`` itself
    # when it already is one, so in-place drawing is preserved.
    im = np.ascontiguousarray(im)
    lw = line_width or max(round(min(im.shape[:2]) * 0.003), 2)
    c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
    cv2.rectangle(im, c1, c2, color, thickness=lw, lineType=cv2.LINE_AA)
    if not label:
        return im

    img_h, img_w = im.shape[:2]
    # Pillow works in RGB while every colour argument here is BGR (OpenCV order);
    # convert both so a non-grey txt_color is not silently channel-swapped.
    color_rgb = (color[2], color[1], color[0])
    txt_color_rgb = (txt_color[2], txt_color[1], txt_color[0])
    font = _load_font(max(round(lw * 12), 30))  # font size scales with the line width

    im_pil = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(im_pil)
    if hasattr(font, 'getbbox'):  # Pillow >= 9.2
        _, _, txt_width, txt_height = font.getbbox(label)
    else:  # older Pillow
        txt_width, txt_height = draw.textsize(label, font)

    pad = 3  # horizontal gap between the background edge and the text
    # Default: label just above the box. If that would leave the top of the
    # image, move it inside the bottom of the box instead.
    label_y = c1[1] - txt_height - 5
    if label_y < 0:
        label_y = c2[1] - txt_height - 5
    # The background spans label_y - 1 .. label_y + txt_height + 5. Clamp it into
    # the image (small boxes at the top edge, boxes running past the bottom).
    label_y = max(1, min(label_y, img_h - txt_height - 5))
    # Likewise keep it from running off the right (or left) edge.
    bg_w = txt_width + 2 * pad
    bg_x0 = max(0, min(c1[0], img_w - bg_w))

    # The background is wide enough to cover the padded text on both sides.
    draw.rectangle([(bg_x0, label_y - 1), (bg_x0 + bg_w, label_y + txt_height + 5)], fill=color_rgb)
    draw.text((bg_x0 + pad, label_y), label, font=font, fill=txt_color_rgb)

    # Back to OpenCV's BGR layout.
    return cv2.cvtColor(np.array(im_pil), cv2.COLOR_RGB2BGR)
def draw_detections(image_np, labels, boxes, scores, thrh=0.4):
    """
    Draw every detection whose score exceeds ``thrh`` on a copy of an image.

    Args:
        image_np: BGR image array; it is copied, never modified.
        labels: per-detection class ids (tensor or array), aligned with ``boxes``/``scores``.
        boxes: per-detection ``(x1, y1, x2, y2)`` boxes in pixel coordinates.
        scores: per-detection confidence scores.
        thrh: only detections with ``score > thrh`` are drawn.

    Returns:
        The annotated copy of ``image_np``.
    """
    img_res = image_np.copy()
    keep = scores > thrh
    for label, box, score in zip(labels[keep], boxes[keep], scores[keep]):
        label_id = int(label.item())
        class_name = NAMES[label_id] if 0 <= label_id < len(NAMES) else f"ID:{label_id}"
        # NAMES and COLORS are edited by hand and can drift apart; fall back to
        # grey instead of raising IndexError when a class has no colour.
        color = COLORS[label_id] if 0 <= label_id < len(COLORS) else (128, 128, 128)
        # Fixed two-decimal score so labels look consistent ("0.90", not "0.9").
        label_text = f"{class_name} {score.item():.2f}"
        img_res = plot_one_box(box, img_res, label=label_text, color=color, txt_color=(255, 255, 255), line_width=3)
    return img_res
def process_image(model, device, image_path, output_path, conf_thres, input_size=(640, 640)):
    """
    Read one image, run the detector on it, draw the detections and save the result.

    Any error is reported and swallowed, so that one bad image does not abort a
    batch run.

    Args:
        model: callable as ``model(images, orig_target_sizes)`` whose output
            indexes as ``(labels, boxes, scores)``, each batched along dim 0.
        device: torch device the input tensors are moved to.
        image_path: path of the image to read with OpenCV.
        output_path: where the annotated image is written.
        conf_thres: only detections with ``score > conf_thres`` are drawn.
        input_size: network input size as ``(height, width)``, or a single int
            for a square input. Defaults to 640x640, which the script used to
            hard-code.
    """
    if isinstance(input_size, int):
        input_size = (input_size, input_size)
    try:
        im_bgr = cv2.imread(image_path)
        if im_bgr is None:
            print(f"Warning: Could not read image {image_path}, skipping.")
            return
        im_pil = Image.fromarray(cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB))
        w, h = im_pil.size
        # Boxes are drawn directly on the original image below, so the model's
        # postprocessor is expected to rescale them to this (w, h) size.
        orig_size = torch.tensor([[w, h]]).to(device)
        transforms = T.Compose([T.Resize(input_size), T.ToTensor()])
        im_data = transforms(im_pil).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(im_data, orig_size)
        # Output is (labels, boxes, scores), each batched; take the only image.
        labels, boxes, scores = output[0], output[1], output[2]
        labels_img, boxes_img, scores_img = labels[0], boxes[0], scores[0]

        result_image = draw_detections(im_bgr, labels_img, boxes_img, scores_img, thrh=conf_thres)
        # cv2.imwrite reports failure (e.g. bad directory or extension) by
        # returning False rather than raising, so check it explicitly.
        if not cv2.imwrite(output_path, result_image):
            print(f"Warning: Could not write result image {output_path}.")
    except Exception as e:
        print(f"Error processing {image_path}: {e}")
def _read_image_list(txt_file):
    """Return the whitespace-stripped, non-blank lines of ``txt_file`` (one image path per line)."""
    with open(txt_file, 'r') as f:
        return [line.strip() for line in f if line.strip()]


def _unique_output_path(output_dir, image_path, used_names):
    """
    Return the output path for ``image_path`` inside ``output_dir``, named after its basename.

    ``used_names`` collects the (case-normalised) names handed out so far. When
    two inputs share a basename (e.g. ``a/1.jpg`` and ``b/1.jpg``), the later one
    gets a ``_1``, ``_2``, ... suffix instead of overwriting the earlier result.
    """
    base_filename = os.path.basename(image_path)
    stem, ext = os.path.splitext(base_filename)
    candidate = base_filename
    n = 1
    # normcase makes the check case-insensitive where the OS filesystem is.
    while os.path.normcase(candidate) in used_names:
        candidate = f"{stem}_{n}{ext}"
        n += 1
    used_names.add(os.path.normcase(candidate))
    return os.path.join(output_dir, candidate)


def main(args):
    """
    Build the deploy-mode model from a config and checkpoint, then annotate every listed image.

    Each annotated image is written to ``args.output_dir`` under its original
    basename, suffixed with ``_1``, ``_2``, ... when several inputs share a name.

    Args:
        args: parsed CLI namespace providing ``config``, ``resume``, ``device``,
            ``txt_file``, ``output_dir`` and ``conf_thres``.

    Raises:
        AttributeError: if no checkpoint path (``args.resume``) is given.
    """
    cfg = YAMLConfig(args.config, resume=args.resume)
    # All weights come from the checkpoint below, so presumably there is no
    # need to fetch pretrained backbone weights.
    if "HGNetv2" in cfg.yaml_cfg:
        cfg.yaml_cfg["HGNetv2"]["pretrained"] = False
    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        # Use the checkpoint's "ema" state when present, otherwise "model".
        if "ema" in checkpoint:
            state = checkpoint["ema"]["module"]
        else:
            state = checkpoint["model"]
    else:
        raise AttributeError("Only support resume to load model.state_dict by now.")
    # Load the train-mode state, then convert to deploy mode (in Model below).
    cfg.model.load_state_dict(state)

    class Model(nn.Module):
        """Deploy-mode network followed by its postprocessor, callable as one module."""

        def __init__(self):
            super().__init__()
            self.model = cfg.model.deploy()
            self.postprocessor = cfg.postprocessor.deploy()

        def forward(self, images, orig_target_sizes):
            outputs = self.model(images)
            outputs = self.postprocessor(outputs, orig_target_sizes)
            return outputs

    device = torch.device(args.device)
    model = Model().to(device)
    model.eval()

    os.makedirs(args.output_dir, exist_ok=True)
    image_paths = _read_image_list(args.txt_file)
    print(f"Found {len(image_paths)} images to process.")

    used_names = set()
    for image_path in tqdm(image_paths, desc="Processing Images"):
        if not os.path.exists(image_path):
            print(f"Warning: Image path not found: {image_path}, skipping.")
            continue
        output_path = _unique_output_path(args.output_dir, image_path, used_names)
        process_image(model, device, image_path, output_path, args.conf_thres)
    print("All images processed.")
if __name__ == "__main__":
    # Command-line entry point: collect the options below and hand them to main().
    import argparse

    cli = argparse.ArgumentParser()
    add = cli.add_argument
    # Model definition and trained weights.
    add("-c", "--config", type=str, required=True)
    add("-r", "--resume", type=str, required=True)
    # Inputs and outputs.
    add(
        "--txt_file",
        type=str,
        required=True,
        help="Path to a .txt file containing image paths.",
    )
    add(
        "-o",
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save the output images.",
    )
    # Inference settings.
    add(
        "--conf_thres",
        type=float,
        default=0.4,
        help="Confidence threshold for detection filtering.",
    )
    add("-d", "--device", type=str, default="cpu")
    main(cli.parse_args())
终端运行命令:
# Put the paths of the images to run inference on into a .txt file, one path per line.
# (The comment must not sit between backslash-continued lines: it would end the command early.)
python tools/inference/torch_inf.py \
    -c configs/dfine/custom/dfine_hgnetv2_l_custom.yml \
    -r output/xxxxxx/xxxxxx.pth \
    --txt_file /data/xxx.txt \
    -o output/xxxxx \
    --conf_thres 0.35 \
    -d cuda:0