YOLODataset 和 BaseDataset 是 Ultralytics YOLO 框架中用于加载和处理数据集的两个关键类。
YOLODataset类(ultralytics/data/dataset.py)继承于 BaseDataset类(ultralytics/data/base.py)
BaseDataset
是一个基础数据集类,提供了加载图像、缓存数据、预处理数据等核心功能。它是所有数据集类的父类,为子类提供了通用的数据加载和处理逻辑。
主要功能:
- 图像加载:从指定路径加载图像文件。
- 缓存机制:支持将图像缓存到内存或磁盘,以加速训练。
- 数据预处理:包括图像大小调整、填充等操作。
- 标签处理:加载和更新标签信息。
- 数据增强:通过
build_transforms
方法支持数据增强。
关键方法:
get_img_files
:从指定路径加载图像文件。load_image
:加载单个图像并返回其原始和调整后的尺寸。cache_images
:将图像缓存到内存或磁盘。set_rectangle
:设置矩形训练模式。get_image_and_label
:获取图像和标签信息。build_transforms
:构建数据增强和预处理管道(需子类实现)。
YOLODataset
继承自 BaseDataset
,并在此基础上扩展了 YOLO 特定任务的功能。它支持 YOLO 的目标检测、实例分割、姿态估计和旋转框(OBB)等任务。
主要功能:
- 标签加载:从磁盘或缓存中加载 YOLO 格式的标签。
- 任务支持:根据任务类型(检测、分割、姿态估计、OBB)加载相应的标签。
- 数据增强:扩展了 YOLO 特定的数据增强逻辑(如 Mosaic、MixUp 等)。
- 标签格式处理:将标签转换为 YOLO 训练所需的格式。
关键方法:
cache_labels
:缓存标签并检查图像和标签的完整性。get_labels
:加载标签并返回 YOLO 训练所需的格式。build_transforms
:构建 YOLO 特定的数据增强和预处理管道。update_labels_info
:更新标签格式以适应不同任务。collate_fn
:将数据样本整理为批次。
YOLODataset
继承自 BaseDataset
,并在此基础上扩展了 YOLO 特定任务的功能。具体关系如下:
class YOLODataset(BaseDataset):
def __init__(self, *args, data=None, task="detect", **kwargs):
super().__init__(*args, channels=self.data["channels"], **kwargs)
YOLODataset
通过 super().__init__()
调用 BaseDataset
的初始化方法,继承了 BaseDataset
的所有属性和方法。YOLODataset
在初始化时增加了 task
参数,用于指定任务类型(检测、分割、姿态估计、OBB)。方法重写
- 标签加载:
YOLODataset
重写了get_labels
方法,支持从 YOLO 格式的标签文件中加载数据。- 数据增强:
YOLODataset
重写了build_transforms
方法,增加了 YOLO 特定的数据增强逻辑(如 Mosaic、MixUp 等)。- 任务支持:
YOLODataset
根据任务类型加载相应的标签,并处理成 YOLO 训练所需的格式,update_labels_info。
__init__
)步骤:
1. 设置任务类型2. 调用父类初始化
- 通过
super().__init__(*args, channels=self.data["channels"], **kwargs)
调用父类BaseDataset
的初始化函数。- 父类会加载图像文件、标签文件,并进行缓存和预处理。
父类初始化流程:
1. 初始化参数。
2. 加载图像文件 (
get_img_files
),返回找到的图像文件路径列表。3. 加载标签文件 (
get_labels
),get_labels函数子类YOLODataset进行了复写!所以调用的是子类方法。4. 缓存图像 (
cache_images
),以加快训练速度。5. 构建数据增强方法 (
build_transforms
)。
cache_labels
)cache_labels
方法用于缓存标签数据,并检查图像和标签的完整性。
get_labels
)get_labels
方法用于从缓存或磁盘加载标签数据,并准备用于训练。
下面是YOLODataset复写的函数。
def get_labels(self):
"""
Returns dictionary of labels for YOLO training.
This method loads labels from disk or cache, verifies their integrity, and prepares them for training.
Returns:
(List[dict]): List of label dictionaries, each containing information about an image and its annotations.
"""
self.label_files = img2label_paths(self.im_files)
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
try:
cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
except (FileNotFoundError, AssertionError, AttributeError):
cache, exists = self.cache_labels(cache_path), False # run cache ops
# Display cache
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
if exists and LOCAL_RANK in {-1, 0}:
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
if cache["msgs"]:
LOGGER.info("\n".join(cache["msgs"])) # display warnings
# Read cache
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
labels = cache["labels"]
if not labels:
LOGGER.warning(f"No images found in {cache_path}, training may not work correctly. {HELP_URL}")
self.im_files = [lb["im_file"] for lb in labels] # update im_files
# Check if the dataset is all boxes or all segments
lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
if len_segments and len_boxes != len_segments:
LOGGER.warning(
f"Box and segment counts should be equal, but got len(segments) = {len_segments}, "
f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
)
for lb in labels:
lb["segments"] = []
if len_cls == 0:
LOGGER.warning(f"No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
return labels
方法流程:
1. 获取标签文件路径:调用img2label_paths
方法,根据图像文件路径生成对应的标签文件路径。self.label_files = img2label_paths(self.im_files)
2. 加载缓存文件:尝试从缓存文件加载标签数据。
cache, exists = load_dataset_cache_file(cache_path), True
3. 验证缓存文件:检查缓存文件的版本和哈希值是否匹配。
assert cache["version"] == DATASET_CACHE_VERSION assert cache["hash"] == get_hash(self.label_files + self.im_files)
4. 更新图像文件列表:从缓存中提取图像文件路径,并更新
self.im_files
。self.im_files = [lb["im_file"] for lb in labels]
5. 检查标签完整性:检查标签数据是否完整(如边界框和分割掩码的数量是否一致)。
if len_segments and len_boxes != len_segments: LOGGER.warning("Box and segment counts should be equal...")
6. 返回标签数据:返回处理后的标签数据,供训练使用。
return labels
get_labels
方法的主要功能包括:
- 加载标签文件:从磁盘或缓存中加载标签数据。
- 验证标签完整性:检查标签数据是否有效(如是否存在、是否损坏等)。
- 准备标签数据:将标签数据转换为模型训练所需的格式。
- 返回标签数据:返回处理后的标签数据,供训练使用。
get_labels
方法的返回值是一个包含标签数据的列表,列表中的每个元素是一个字典,表示一张图像的标签信息。每个字典的键值对如下:
键名 | 描述 |
---|---|
im_file |
图像文件的路径。 |
shape |
图像的原始尺寸(高度、宽度)。 |
cls |
目标的类别索引,形状为 (n, 1) ,其中 n 是目标的数量。 |
bboxes |
目标的边界框,形状为 (n, 4) ,格式为 [x_center, y_center, width, height] 。 |
segments |
目标的分割掩码(如果任务为分割),形状为 (n, k, 2) ,其中 k 是点的数量。 |
keypoints |
目标的关键点(如果任务为姿态估计),形状为 (n, k, 3) ,其中 k 是点的数量。 |
normalized |
布尔值,表示边界框和关键点是否已归一化。 |
bbox_format |
边界框的格式(如 "xywh" )。 |
输出示例:
[
{
"im_file": "path/to/images/image1.jpg",
"shape": (640, 640),
"cls": [[0], [1]], # 两个目标,类别分别为 0 和 1
"bboxes": [[0.5, 0.5, 0.2, 0.3], [0.7, 0.8, 0.1, 0.1]], # 两个目标的边界框
"segments": [], # 分割掩码(如果任务为分割)
"keypoints": [], # 关键点(如果任务为姿态估计)
"normalized": True,
"bbox_format": "xywh"
},
{
"im_file": "path/to/images/image2.jpg",
"shape": (640, 640),
"cls": [[0]], # 一个目标,类别为 0
"bboxes": [[0.3, 0.4, 0.1, 0.2]], # 一个目标的边界框
"segments": [], # 分割掩码(如果任务为分割)
"keypoints": [], # 关键点(如果任务为姿态估计)
"normalized": True,
"bbox_format": "xywh"
}
]
4. 构建数据增强 (
build_transforms
)
build_transforms
方法用于构建数据增强和预处理管道。
update_labels_info
)update_labels_info
方法用于更新标签数据的格式,以支持不同任务(如目标检测、分割等)。
在训练过程中,YOLODataset
会通过 __getitem__
方法加载数据,并应用数据增强和预处理。
在 YOLODataset
中,数据打包的过程是通过 collate_fn
方法实现的。collate_fn
方法将多个样本(每张图像及其标签)打包成一个批次,以便输入到模型中进行训练。
collate_fn
方法的主要功能是将多个样本的数据(如图像、标签、边界框等)打包成一个批次。
1. 初始化批次字典:创建一个空字典
new_batch
,用于存储批次数据。new_batch = {}
2. 排序样本数据:确保每个样本的键值顺序一致,以便后续处理。
batch = [dict(sorted(b.items())) for b in batch]
3. 提取样本键值:获取所有样本的键(如
img
、bboxes
、cls
等),并将对应的值打包成列表。keys = batch[0].keys() values = list(zip(*[list(b.values()) for b in batch]))
4. 处理不同类型的数据:根据数据类型(如图像、边界框、类别等),使用不同的方式将数据打包成张量。
for i, k in enumerate(keys): value = values[i] if k in {"img", "text_feats"}: value = torch.stack(value, 0) # 图像数据直接堆叠 elif k == "visuals": value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True) # 填充序列 elif k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}: value = torch.cat(value, 0) # 标签数据拼接 new_batch[k] = value
5. 处理批次索引:为每个样本添加批次索引,以便在训练过程中区分不同样本。
new_batch["batch_idx"] = list(new_batch["batch_idx"]) for i in range(len(new_batch["batch_idx"])): new_batch["batch_idx"][i] += i # 添加目标图像索引 new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
6. 返回批次数据:返回打包好的批次数据。
return new_batch
在使用 DataLoader
加载数据时,collate_fn
方法会自动将多个样本打包成一个批次。以下是 DataLoader
中读取一批次数据的具体内容。
字段名 | 描述 |
---|---|
img |
图像数据,形状为 (batch_size, channels, height, width) 。 |
bboxes |
边界框数据,形状为 (total_objects, 4) ,格式为 [x_center, y_center, width, height] 。 |
cls |
类别数据,形状为 (total_objects, 1) 。 |
segments |
分割掩码数据,形状为 (total_objects, k, 2) ,其中 k 是点的数量。 |
keypoints |
关键点数据,形状为 (total_objects, k, 3) ,其中 k 是点的数量。 |
batch_idx |
批次索引,形状为 (total_objects,) ,用于区分不同样本的目标。 |
从 DataLoader 中读取一批次数据的示例输出:
{
"img": tensor([[[[0.5, 0.5, 0.2, 0.3], [0.7, 0.8, 0.1, 0.1]], # 图像数据
[[0.3, 0.4, 0.1, 0.2], [0.6, 0.7, 0.2, 0.1]]]),
"bboxes": tensor([[0.5, 0.5, 0.2, 0.3], [0.7, 0.8, 0.1, 0.1], # 边界框数据
[0.3, 0.4, 0.1, 0.2], [0.6, 0.7, 0.2, 0.1]]),
"cls": tensor([[0], [1], [0], [1]]), # 类别数据
"segments": tensor([], dtype=torch.float32), # 分割掩码数据
"keypoints": tensor([], dtype=torch.float32), # 关键点数据
"batch_idx": tensor([0, 0, 1, 1]) # 批次索引
}