【Block总结】DCAFE,并行双坐标注意力机制,增强长程依赖与抗噪性|即插即用

论文信息

  • 标题:Flora-NET: Integrating dual coordinate attention with adaptive kernel based convolution network for medicinal flower identification
  • 中文标题:Flora-NET:集成双坐标注意力机制与自适应内核的药用花卉识别卷积网络
  • 链接:ScienceDirect
  • 代码链接:https://github.com/ersachingupta11/Flora-NET
  • 单位:印度马拉维亚国立理工学院计算机科学与工程系

创新点

  1. 双模块协同架构
    • DCAFE模块:并行双坐标注意力机制(平均池化+最大池化),增强长程依赖与抗噪性,保留更多上下文信息。
    • Inv-FR模块:串行自适应卷积核操作,优化空间特征表示。
  2. 领域专用设计:针对药用花卉的类内差异大、类间差异小、环境干扰(遮挡/光照/背景复杂)等难点优化。
  3. 轻量化改进:双坐标注意力机制避免显著增加计算成本。

【Block总结】DCAFE,并行双坐标注意力机制,增强长程依赖与抗噪性|即插即用_第1张图片

方法

  1. 整体流程
    • 输入:224×224 RGB图像(经尺寸调整+数据增强)。
    • 特征提取:CBS模块(卷积-批量归一化-SiLU)为基础,嵌入DCAFE与Inv-FR模块。
    • 分类:全连接层输出类别概率。
  2. 核心模块
    • DCAFE
      • 并行分支:平均池化坐标注意力(fca) + 最大池化坐标注意力(fcm)。
      • 输出:连接两分支结果,兼顾全局上下文与细节鲁棒性。
    • Inv-FR
      • 自适应卷积核动态调整空间特征,提升对形变/遮挡的适应性。

效果

  1. 定量性能

    数据集 准确率 超越SOTA
    Urban Street (17类) 91.12% +6.94%
    Medicinal Blossom (12类) 91.18% +5.71%
  2. 消融实验验证

    • 移除DCAFE+Inv-FR:准确率降至82.57%(Urban)和84.39%(Blossom)。
    • 证明双模块对性能提升的关键作用。
  3. 可视化证据

    • CAM图(图4):DCAFE与Inv-FR模块聚焦花瓣纹理等判别区域。
    • t-SNE(图6):处理后特征聚类分离度显著提升。
    • 混淆矩阵(图8):各类别错误率均衡,无显著偏差。

总结

  1. 贡献
    • 提出首个结合双坐标注意力与自适应卷积核的药用花卉分类模型,显著解决类内差异大、环境干扰强的问题。
    • 在公开数据集上达到**>91%准确率**,大幅领先现有方法。
  2. 应用价值
    • 医药制造:精准识别药用植物成分,保障药物生产质量。
    • 生物多样性保护:自动化监测濒危花卉物种。
  3. 局限与未来
    • 局限:对严重遮挡图像敏感,类别不平衡场景待优化。
    • 方向:融合GANs生成合成数据,探索多模态(如纹理+光谱)信息。

代码

import torch
import torch.nn as nn
import math
import torch.nn.functional as F

class CoordAttMeanMax(nn.Module):
    def __init__(self, inp, oup, groups=32):
        super(CoordAttMeanMax, self).__init__()
        self.pool_h_mean = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w_mean = nn.AdaptiveAvgPool2d((1, None))
        self.pool_h_max = nn.AdaptiveMaxPool2d((None, 1))
        self.pool_w_max = nn.AdaptiveMaxPool2d((1, None))

        mip = max(8, inp // groups)

        self.conv1_mean = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1_mean = nn.BatchNorm2d(mip)
        self.conv2_mean = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

        self.conv1_max = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1_max = nn.BatchNorm2d(mip)
        self.conv2_max = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        n, c, h, w = x.size()

        # Mean pooling branch
        x_h_mean = self.pool_h_mean(x)
        x_w_mean = self.pool_w_mean(x).permute(0, 1, 3, 2)
        y_mean = torch.cat([x_h_mean, x_w_mean], dim=2)
        y_mean = self.conv1_mean(y_mean)
        y_mean = self.bn1_mean(y_mean)
        y_mean = self.relu(y_mean)
        x_h_mean, x_w_mean = torch.split(y_mean, [h, w], dim=2)
        x_w_mean = x_w_mean.permute(0, 1, 3, 2)

        # Max pooling branch
        x_h_max = self.pool_h_max(x)
        x_w_max = self.pool_w_max(x).permute(0, 1, 3, 2)
        y_max = torch.cat([x_h_max, x_w_max], dim=2)
        y_max = self.conv1_max(y_max)
        y_max = self.bn1_max(y_max)
        y_max = self.relu(y_max)
        x_h_max, x_w_max = torch.split(y_max, [h, w], dim=2)
        x_w_max = x_w_max.permute(0, 1, 3, 2)

        # Apply attention
        x_h_mean = self.conv2_mean(x_h_mean).sigmoid()
        x_w_mean = self.conv2_mean(x_w_mean).sigmoid()
        x_h_max = self.conv2_max(x_h_max).sigmoid()
        x_w_max = self.conv2_max(x_w_max).sigmoid()

        # Expand to original shape
        x_h_mean = x_h_mean.expand(-1, -1, h, w)
        x_w_mean = x_w_mean.expand(-1, -1, h, w)
        x_h_max = x_h_max.expand(-1, -1, h, w)
        x_w_max = x_w_max.expand(-1, -1, h, w)

        # Combine outputs
        attention_mean = identity * x_w_mean * x_h_mean
        attention_max = identity * x_w_max * x_h_max

        # Sum the attention outputs
        return attention_mean + attention_max

class DCAFE(nn.Module):
    def __init__(self, in_channels):
        super(DCAFE, self).__init__()
        self.coord_att = CoordAttMeanMax(in_channels, oup=in_channels)

    def forward(self, x):
        return self.coord_att(x)



if __name__ == "__main__":
    # 定义输入张量大小(Batch、Channel、Height、Wight)
    B, C, H, W = 16, 64, 40, 40
    input_tensor = torch.randn(B,C,H,W)  # 随机生成输入张量
    dim=C
    # 创建 ARConv 实例
    block = DCAFE(in_channels=64)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sablock = block.to(device)
    print(sablock)
    input_tensor = input_tensor.to(device)
    # 执行前向传播
    output = sablock(input_tensor)
    # 打印输入和输出的形状
    print(f"Input: {input_tensor.shape}")
    print(f"Output: {output.shape}")

运行结果:
【Block总结】DCAFE,并行双坐标注意力机制,增强长程依赖与抗噪性|即插即用_第2张图片

代码详解

import torch
import torch.nn as nn
import math
import torch.nn.functional as F

class CoordAttMeanMax(nn.Module):
    # 实现双坐标注意力机制的核心类
    def __init__(self, inp, oup, groups=32):
        super(CoordAttMeanMax, self).__init__()
        
        # 池化层定义
        self.pool_h_mean = nn.AdaptiveAvgPool2d((None, 1))  # 水平方向平均池化
        self.pool_w_mean = nn.AdaptiveAvgPool2d((1, None))  # 垂直方向平均池化
        self.pool_h_max = nn.AdaptiveMaxPool2d((None, 1))   # 水平方向最大池化
        self.pool_w_max = nn.AdaptiveMaxPool2d((1, None))   # 垂直方向最大池化
        
        # 中间层通道数计算(至少为8)
        mip = max(8, inp // groups)
        
        # 平均池化分支的卷积层
        self.conv1_mean = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1_mean = nn.BatchNorm2d(mip)
        self.conv2_mean = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        
        # 最大池化分支的卷积层
        self.conv1_max = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1_max = nn.BatchNorm2d(mip)
        self.conv2_max = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x  # 保留原始输入
        n, c, h, w = x.size()
        
        # ===== 平均池化分支处理 =====
        # 水平方向池化 (n, c, h, w) -> (n, c, h, 1)
        x_h_mean = self.pool_h_mean(x)
        # 垂直方向池化 (n, c, h, w) -> (n, c, 1, w) -> 转置为 (n, c, w, 1)
        x_w_mean = self.pool_w_mean(x).permute(0, 1, 3, 2)
        
        # 拼接水平和垂直特征
        y_mean = torch.cat([x_h_mean, x_w_mean], dim=2)  # (n, c, h+w, 1)
        
        # 特征转换
        y_mean = self.conv1_mean(y_mean)
        y_mean = self.bn1_mean(y_mean)
        y_mean = self.relu(y_mean)
        
        # 分离特征
        x_h_mean, x_w_mean = torch.split(y_mean, [h, w], dim=2)
        x_w_mean = x_w_mean.permute(0, 1, 3, 2)  # 转置回 (n, c, 1, w)
        
        # ===== 最大池化分支处理 =====(流程同上)
        x_h_max = self.pool_h_max(x)
        x_w_max = self.pool_w_max(x).permute(0, 1, 3, 2)
        y_max = torch.cat([x_h_max, x_w_max], dim=2)
        y_max = self.conv1_max(y_max)
        y_max = self.bn1_max(y_max)
        y_max = self.relu(y_max)
        x_h_max, x_w_max = torch.split(y_max, [h, w], dim=2)
        x_w_max = x_w_max.permute(0, 1, 3, 2)
        
        # ===== 注意力权重生成 =====
        # 平均分支的注意力权重
        x_h_mean = self.conv2_mean(x_h_mean).sigmoid()  # (n, oup, h, 1)
        x_w_mean = self.conv2_mean(x_w_mean).sigmoid()  # (n, oup, 1, w)
        
        # 最大分支的注意力权重
        x_h_max = self.conv2_max(x_h_max).sigmoid()     # (n, oup, h, 1)
        x_w_max = self.conv2_max(x_w_max).sigmoid()     # (n, oup, 1, w)
        
        # ===== 注意力图扩展 =====
        # 将1D注意力扩展到2D空间
        x_h_mean = x_h_mean.expand(-1, -1, h, w)  # (n, oup, h, w)
        x_w_mean = x_w_mean.expand(-1, -1, h, w)  # (n, oup, h, w)
        x_h_max = x_h_max.expand(-1, -1, h, w)    # (n, oup, h, w)
        x_w_max = x_w_max.expand(-1, -1, h, w)    # (n, oup, h, w)
        
        # ===== 注意力应用 =====
        # 平均分支加权
        attention_mean = identity * x_w_mean * x_h_mean
        # 最大分支加权
        attention_max = identity * x_w_max * x_h_max
        
        # 双分支结果融合
        return attention_mean + attention_max

class DCAFE(nn.Module):
    # 双坐标注意力特征提取模块(论文中的DCAFE)
    def __init__(self, in_channels):
        super(DCAFE, self).__init__()
        self.coord_att = CoordAttMeanMax(in_channels, oup=in_channels)
    
    def forward(self, x):
        return self.coord_att(x)

关键机制解析

  1. 双分支结构

    • 平均池化分支:捕获全局上下文信息,对噪声更鲁棒
    • 最大池化分支:聚焦显著特征,增强判别能力
    • 设计理念:同时利用两种池化的优势,兼顾全局信息和局部显著特征
  2. 坐标注意力机制

    # 水平方向处理
    x_h = pool_h(x)  # 形状: (n, c, h, 1)
    
    # 垂直方向处理
    x_w = pool_w(x).permute(0, 1, 3, 2)  # 形状: (n, c, w, 1)
    
    # 特征拼接
    y = torch.cat([x_h, x_w], dim=2)  # 形状: (n, c, h+w, 1)
    
    • 将空间维度分解为水平和垂直两个方向
    • 分别进行特征聚合后拼接,保留位置信息
    • 符合论文描述:“结合位置特征空间到通道注意力机制中”
  3. 轻量化设计

    mip = max(8, inp // groups)  # 通道压缩
    conv1 = nn.Conv2d(inp, mip, kernel_size=1)  # 1×1卷积降维
    
    • 使用1×1卷积压缩通道数(mip计算)
    • 减少计算量同时保持表达能力
    • 符合论文"避免额外计算成本"的设计目标
  4. 注意力生成与应用

    # 生成空间注意力权重
    att = conv2(features).sigmoid()  # 值域[0,1]
    
    # 扩展为完整空间尺寸
    att = att.expand(-1, -1, h, w)
    
    # 应用注意力
    output = input * att_w * att_h  # 双方向注意力相乘
    
    • 通过Sigmoid生成空间注意力图
    • 水平/垂直注意力相乘实现交叉增强
    • 最终双分支结果相加融合

模块特性总结

  1. 空间感知能力

    • 显式建模水平/垂直方向关系
    • 精确定位特征空间位置
    • 增强对花卉纹理/结构的捕捉能力
  2. 多尺度融合

    输入特征
    平均池化分支
    最大池化分支
    全局上下文
    局部显著特征
    特征融合
    输出
    • 双分支互补提供多尺度特征表示
    • 特别适合处理类内差异大的药用花卉
  3. 计算效率

    • 通道压缩减少参数量
    • 1×1卷积降低计算复杂度
    • 适合嵌入式/移动端部署
  4. 注意力可视化

    • 生成的注意力图可直接解释
    • 对应论文图4的可视化效果
    • 展示模型关注的花卉关键区域

实现亮点

  1. 位置敏感注意力

    • 通过坐标分离保留空间信息
    • 解决传统通道注意力丢失位置信息的问题
  2. 双模态融合

    • 同时利用平均池化和最大池化的优势
    • 平均池化:背景上下文信息
    • 最大池化:前景显著特征
  3. 即插即用设计

    # 在CNN中的典型用法
    class CNNBlock(nn.Module):
        def __init__(self, in_ch):
            super().__init__()
            self.conv = nn.Conv2d(in_ch, out_ch, 3)
            self.dcafe = DCAFE(out_ch)  # 嵌入DCAFE模块
            
        def forward(self, x):
            x = self.conv(x)
            x = self.dcafe(x)  # 增强特征表达
            return x
    
    • 可无缝集成到现有CNN架构中
    • 在YOLO、ResNet等网络中均可使用

你可能感兴趣的:(人工智能,计算机视觉,YOLO,目标检测)