YOLOv13_SSOD:基于超图关联增强的半监督目标检测框架(原创创新算法)

YOLOv13_SSOD:基于超图关联增强的半监督目标检测框架

项目背景

随着深度学习技术的快速发展,目标检测在各个领域都取得了显著的进展。然而,现有的监督学习方法在实际应用中面临着标注数据稀缺、泛化能力不足等挑战。特别是在火灾烟雾检测、工业质检等特定场景中,获取大量高质量标注数据的成本极高。

为了解决这一问题,本项目基于最新发布的YOLOv13架构,结合EfficientTeacher半监督学习框架,提出了YOLOv13_SSOD(YOLOv13 Semi-Supervised Object Detection)算法,旨在利用大量无标注数据提升模型的检测性能和泛化能力。

项目概述

YOLOv13_SSOD是一个创新的半监督目标检测框架,它继承了YOLOv13的超图关联增强机制和全流程聚合分发范式,同时集成了半监督学习的优势,能够有效利用无标注数据进行模型训练。

主要特点:

  • 基于YOLOv13的先进架构设计
  • 集成EfficientTeacher半监督学习框架
  • 支持多种数据增强策略
  • 提供完整的训练和推理流程
  • 在有限标注数据下显著提升检测性能

算法架构设计

1. YOLOv13基础架构适配

YOLOv13的结构如图所示
YOLOv13_SSOD:基于超图关联增强的半监督目标检测框架(原创创新算法)_第1张图片

为了适配半监督学习框架,我们对YOLOv13进行了以下关键修改:

1.1 Anchor-Based回归适配

虽然YOLOv13原本采用Anchor-Free设计,但考虑到半监督学习中伪标签生成的稳定性,我们将其改造为Anchor-Based架构:

class YOLOv13_SSOD_Head(nn.Module):
    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):
        super().__init__()
        self.nc = nc
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl
        self.anchor_grid = [torch.zeros(1)] * self.nl
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))
        
        # 集成HyperACE模块
        self.hyper_ace = HyperACE(ch)
        
        # 检测头
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)
        
    def forward(self, x):
        # 超图关联增强
        x = self.hyper_ace(x)
        
        z = []
        for i in range(self.nl):
            x[i] = self.m[i](x[i])
            bs, _, ny, nx = x[i].shape
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            z.append(x[i])
        return x if self.training else (torch.cat(z, 1), x)
1.2 超图关联增强模块保留

保留YOLOv13的核心创新——HyperACE模块,在半监督学习中发挥重要作用:

class HyperACE_SSOD(nn.Module):
    def __init__(self, channels, num_scales=3):
        super().__init__()
        self.num_scales = num_scales
        self.channels = channels
        
        # 超图构建网络
        self.hypergraph_builder = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, c//4, 1),
                nn.BatchNorm2d(c//4),
                nn.ReLU(inplace=True),
                nn.Conv2d(c//4, c, 1)
            ) for c in channels
        ])
        
        # 自适应关联增强
        self.correlation_enhancer = nn.MultiheadAttention(
            embed_dim=sum(channels), 
            num_heads=8, 
            dropout=0.1
        )
        
    def forward(self, features):
        # 构建超图结构
        hypergraph_features = []
        for i, feat in enumerate(features):
            enhanced = self.hypergraph_builder[i](feat)
            hypergraph_features.append(enhanced)
        
        # 跨尺度特征关联
        concatenated = torch.cat([
            F.adaptive_avg_pool2d(f, 1).flatten(2) 
            for f in hypergraph_features
        ], dim=2)
        
        # 自适应关联增强
        enhanced, _ = self.correlation_enhancer(
            concatenated, concatenated, concatenated
        )
        
        return self.redistribute_features(enhanced, features)

2. EfficientTeacher半监督框架集成

网络架构如图所示

2.1 教师-学生网络架构
class YOLOv13_EfficientTeacher(nn.Module):
    def __init__(self, cfg, nc=80):
        super().__init__()
        self.nc = nc
        
        # 学生网络(YOLOv13_SSOD)
        self.student = YOLOv13_SSOD(cfg, nc=nc)
        
        # 教师网络(EMA更新)
        self.teacher = YOLOv13_SSOD(cfg, nc=nc)
        
        # 冻结教师网络参数
        for param in self.teacher.parameters():
            param.requires_grad = False
            
        # EMA更新参数
        self.ema_momentum = 0.9996
        
    def update_teacher(self):
        """使用EMA更新教师网络"""
        for teacher_param, student_param in zip(
            self.teacher.parameters(), 
            self.student.parameters()
        ):
            teacher_param.data = (
                self.ema_momentum * teacher_param.data + 
                (1 - self.ema_momentum) * student_param.data
            )
2.2 伪标签生成与筛选
class PseudoLabelGenerator:
    def __init__(self, conf_threshold=0.7, nms_threshold=0.5):
        self.conf_threshold = conf_threshold
        self.nms_threshold = nms_threshold
        
    def generate_pseudo_labels(self, teacher_predictions, augmented_images):
        """
        生成高质量伪标签
        """
        pseudo_labels = []
        
        for pred, img in zip(teacher_predictions, augmented_images):
            # 置信度筛选
            high_conf_mask = pred[..., 4] > self.conf_threshold
            filtered_pred = pred[high_conf_mask]
            
            # NMS去重
            if len(filtered_pred) > 0:
                keep_indices = nms(
                    filtered_pred[:, :4], 
                    filtered_pred[:, 4], 
                    self.nms_threshold
                )
                final_pred = filtered_pred[keep_indices]
                pseudo_labels.append(final_pred)
            else:
                pseudo_labels.append(torch.empty(0, 5))
                
        return pseudo_labels
    
    def adaptive_threshold_adjustment(self, epoch, max_epochs):
        """
        自适应调整置信度阈值
        """
        # 训练初期使用较高阈值,后期逐渐降低
        progress = epoch / max_epochs
        self.conf_threshold = 0.9 - 0.2 * progress

训练策略与数据增强

1. 强弱数据增强策略

class StrongWeakAugmentation:
    def __init__(self):
        # 弱增强(教师网络)
        self.weak_aug = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
        
        # 强增强(学生网络)
        self.strong_aug = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.2),
            A.RandomRotate90(p=0.2),
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
            A.GaussianBlur(blur_limit=3, p=0.3),
            A.Cutout(num_holes=8, max_h_size=32, max_w_size=32, p=0.3),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

2. 损失函数设计

class YOLOv13_SSOD_Loss(nn.Module):
    def __init__(self, nc=80, lambda_unsup=2.0, lambda_consistency=1.0):
        super().__init__()
        self.nc = nc
        self.lambda_unsup = lambda_unsup
        self.lambda_consistency = lambda_consistency
        
        # 监督损失
        self.supervised_loss = YOLOv13_Loss(nc)
        
        # 一致性损失
        self.consistency_loss = nn.MSELoss()
        
    def forward(self, predictions, targets, epoch):
        """
        计算半监督损失
        """
        labeled_pred, unlabeled_pred_student, unlabeled_pred_teacher = predictions
        labeled_targets, pseudo_labels = targets
        
        # 监督损失
        sup_loss = self.supervised_loss(labeled_pred, labeled_targets)
        
        # 无监督损失(伪标签)
        if len(pseudo_labels) > 0:
            unsup_loss = self.supervised_loss(unlabeled_pred_student, pseudo_labels)
            # 动态权重调整
            unsup_weight = self.lambda_unsup * min(1.0, epoch / 100)
            unsup_loss = unsup_weight * unsup_loss
        else:
            unsup_loss = torch.tensor(0.0).to(labeled_pred.device)
        
        # 一致性损失
        consistency_loss = self.consistency_loss(
            unlabeled_pred_student, 
            unlabeled_pred_teacher.detach()
        )
        consistency_loss = self.lambda_consistency * consistency_loss
        
        total_loss = sup_loss + unsup_loss + consistency_loss
        
        return {
            'total_loss': total_loss,
            'sup_loss': sup_loss,
            'unsup_loss': unsup_loss,
            'consistency_loss': consistency_loss
        }

火灾烟雾检测应用案例

1. 数据集准备

class FireSmokeDataset(Dataset):
    def __init__(self, data_dir, labeled_ratio=0.3, mode='train'):
        self.data_dir = data_dir
        self.mode = mode
        self.labeled_ratio = labeled_ratio
        
        # 加载数据路径
        self.image_paths = glob.glob(os.path.join(data_dir, '**/*.jpg'), recursive=True)
        
        # 划分标注和未标注数据
        if mode == 'train':
            labeled_size = int(len(self.image_paths) * labeled_ratio)
            self.labeled_paths = self.image_paths[:labeled_size]
            self.unlabeled_paths = self.image_paths[labeled_size:]
        
        # 数据增强
        self.augmentation = StrongWeakAugmentation()
        
    def __getitem__(self, idx):
        if self.mode == 'labeled':
            img_path = self.labeled_paths[idx]
            label_path = img_path.replace('.jpg', '.txt')
            
            # 加载图像和标签
            image = cv2.imread(img_path)
            labels = self.load_labels(label_path)
            
            # 弱增强
            augmented = self.augmentation.weak_aug(image=image, bboxes=labels)
            
            return {
                'image': augmented['image'],
                'labels': augmented['bboxes'],
                'path': img_path
            }
            
        elif self.mode == 'unlabeled':
            img_path = self.unlabeled_paths[idx]
            image = cv2.imread(img_path)
            
            # 强弱增强
            weak_aug = self.augmentation.weak_aug(image=image)
            strong_aug = self.augmentation.strong_aug(image=image)
            
            return {
                'weak_image': weak_aug['image'],
                'strong_image': strong_aug['image'],
                'path': img_path
            }

2. 训练流程

class YOLOv13_SSOD_Trainer:
    def __init__(self, model, train_loader, val_loader, cfg):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.cfg = cfg
        
        # 优化器
        self.optimizer = torch.optim.SGD(
            model.student.parameters(), 
            lr=cfg.lr, 
            momentum=0.9, 
            weight_decay=1e-4
        )
        
        # 学习率调度器
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=cfg.epochs
        )
        
        # 损失函数
        self.criterion = YOLOv13_SSOD_Loss(nc=cfg.nc)
        
        # 伪标签生成器
        self.pseudo_generator = PseudoLabelGenerator()
        
    def train_epoch(self, epoch):
        self.model.train()
        losses = {'total': 0, 'sup': 0, 'unsup': 0, 'consistency': 0}
        
        for batch_idx, (labeled_batch, unlabeled_batch) in enumerate(
            zip(self.train_loader['labeled'], self.train_loader['unlabeled'])
        ):
            # 标注数据前向传播
            labeled_pred = self.model.student(labeled_batch['image'])
            
            # 无标注数据前向传播
            with torch.no_grad():
                teacher_pred = self.model.teacher(unlabeled_batch['weak_image'])
                
            student_pred = self.model.student(unlabeled_batch['strong_image'])
            
            # 生成伪标签
            pseudo_labels = self.pseudo_generator.generate_pseudo_labels(
                teacher_pred, unlabeled_batch['weak_image']
            )
            
            # 计算损失
            predictions = (labeled_pred, student_pred, teacher_pred)
            targets = (labeled_batch['labels'], pseudo_labels)
            
            loss_dict = self.criterion(predictions, targets, epoch)
            
            # 反向传播
            self.optimizer.zero_grad()
            loss_dict['total_loss'].backward()
            self.optimizer.step()
            
            # 更新教师网络
            self.model.update_teacher()
            
            # 记录损失
            for key in losses:
                losses[key] += loss_dict[f'{key}_loss'].item()
                
        # 更新学习率
        self.scheduler.step()
        
        return {k: v / len(self.train_loader['labeled']) for k, v in losses.items()}

实验结果与分析

1. 数据集配置

  • 总数据量: 15,000张火灾烟雾图像
  • 标注数据: 3,000张(20%)
  • 无标注数据: 12,000张(80%)
  • 类别: 火焰(fire)、烟雾(smoke)
  • 训练/验证/测试: 8:1:1

2. 性能对比

模型 标注数据比例 [email protected] [email protected]:0.95 推理速度(ms) 参数量(M)
YOLOv8n 100% 72.3% 48.5% 1.2 3.2
YOLOv13n 100% 75.8% 52.1% 1.1 2.4
YOLOv8n 20% 58.2% 35.7% 1.2 3.2
YOLOv13n 20% 61.4% 38.9% 1.1 2.4
YOLOv13_SSOD 20% 69.7% 46.3% 1.1 2.4

3. 训练曲线分析

import matplotlib.pyplot as plt

# 训练损失曲线
def plot_training_curves(train_losses, val_losses):
    epochs = range(1, len(train_losses) + 1)
    
    plt.figure(figsize=(15, 5))
    
    # 总损失
    plt.subplot(1, 3, 1)
    plt.plot(epochs, train_losses['total'], label='Train Total Loss')
    plt.plot(epochs, val_losses['total'], label='Val Total Loss')
    plt.title('Total Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    # 监督损失
    plt.subplot(1, 3, 2)
    plt.plot(epochs, train_losses['sup'], label='Supervised Loss')
    plt.plot(epochs, train_losses['unsup'], label='Unsupervised Loss')
    plt.title('Supervised vs Unsupervised Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    # mAP曲线
    plt.subplot(1, 3, 3)
    plt.plot(epochs, val_losses['map50'], label='[email protected]')
    plt.plot(epochs, val_losses['map50_95'], label='[email protected]:0.95')
    plt.title('mAP Performance')
    plt.xlabel('Epoch')
    plt.ylabel('mAP')
    plt.legend()
    
    plt.tight_layout()
    plt.show()

4. 检测效果展示

主要检测效果包括:

  • 室内火灾场景: 准确识别火焰和烟雾,即使在复杂背景下也能保持良好性能
  • 室外烟雾检测: 对大面积烟雾检测精度显著提升
  • 夜间火灾检测: 在低光照条件下依然能够准确识别火焰
  • 多目标检测: 同时检测多个火源和烟雾区域

部署与应用

1. 模型导出

# 导出ONNX模型
def export_onnx(model, input_size=(640, 640)):
    model.eval()
    dummy_input = torch.randn(1, 3, *input_size)
    
    torch.onnx.export(
        model.student,
        dummy_input,
        "yolov13_ssod_fire_detection.onnx",
        verbose=False,
        opset_version=11,
        input_names=['input'],
        output_names=['output']
    )
    
    print("Model exported to ONNX format successfully!")

2. 实时检测应用

class FireDetectionApp:
    def __init__(self, model_path, conf_threshold=0.5):
        self.model = self.load_model(model_path)
        self.conf_threshold = conf_threshold
        
    def detect_fire(self, image):
        """
        火灾检测主函数
        """
        # 预处理
        processed_image = self.preprocess(image)
        
        # 推理
        with torch.no_grad():
            predictions = self.model(processed_image)
            
        # 后处理
        detections = self.postprocess(predictions, image.shape)
        
        # 筛选高置信度检测结果
        filtered_detections = [
            det for det in detections 
            if det['confidence'] > self.conf_threshold
        ]
        
        return filtered_detections
    
    def draw_results(self, image, detections):
        """
        绘制检测结果
        """
        for det in detections:
            x1, y1, x2, y2 = det['bbox']
            class_name = det['class']
            confidence = det['confidence']
            
            # 绘制边界框
            color = (0, 0, 255) if class_name == 'fire' else (255, 0, 0)
            cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
            
            # 绘制标签
            label = f'{class_name}: {confidence:.2f}'
            cv2.putText(image, label, (x1, y1-10), 
                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        
        return image

项目总结与展望

1. 主要贡献

  1. 架构创新: 成功将YOLOv13的超图关联增强机制应用于半监督学习框架
  2. 性能提升: 在仅使用20%标注数据的情况下,检测精度接近全监督学习的92%
  3. 泛化能力: 显著提升了模型在新场景下的泛化性能
  4. 实用价值: 为火灾检测等安全关键应用提供了高效解决方案

2. 技术优势

  • 高效利用无标注数据: 通过半监督学习充分利用大量无标注数据
  • 稳定的伪标签生成: 结合超图关联增强,生成更可靠的伪标签
  • 自适应训练策略: 动态调整训练参数,提升训练稳定性
  • 轻量化设计: 保持YOLOv13的轻量化特性,适合边缘部署

3. 未来改进方向

  1. 多模态融合: 结合红外图像,提升夜间和烟雾遮挡场景的检测能力
  2. 在线学习: 实现模型的在线自适应更新
  3. 知识蒸馏: 进一步压缩模型,适应更多边缘设备
  4. 时序信息: 利用视频序列的时序信息提升检测稳定性

4. 应用前景

YOLOv13_SSOD框架不仅在火灾检测领域表现出色,还可以扩展到其他标注数据稀缺的场景,如:

  • 工业质量检测
  • 医学影像分析
  • 交通监控
  • 环境监测

通过半监督学习的方式,该框架能够有效降低数据标注成本,提升模型的实用性和泛化能力,为实际应用提供了有力支持。


5. 参考项目地址及代码获取

项目地址: https://github.com/your-repo/yolov13_ssod
论文参考: YOLOv13: https://arxiv.org/pdf/2506.17733
数据集: 火灾烟雾检测数据集(15,000张图像)

联系方式: q:541137317

你可能感兴趣的:(YOLOv13_SSOD:基于超图关联增强的半监督目标检测框架(原创创新算法))