知识点回顾:
CBAM 注意力模块介绍
从 SE 到 CBAM:注意力机制的演进
之前我们介绍了 SE(Squeeze-and-Excitation)通道注意力模块,其本质是对特征进行增强处理。现在,我们进一步探讨一种更强大的注意力机制:CBAM(Convolutional Block Attention Module)。
CBAM 的核心目标与价值
CBAM 是一种轻量级、可即插即用的注意力模块,能够无缝集成到各种卷积神经网络(CNN)架构中。它的核心原理是让模型自动学习特征图在通道和空间两个维度上的重要性权重,并据此对特征图进行自适应的调整——强化关键特征,弱化非关键特征。这一过程显著提升了模型的特征表征能力和最终性能。形象地说,CBAM 如同为模型配备了“智能眼镜”,使其能更精准地聚焦于输入图像中最具信息量的部分。
CBAM 的组成结构:通道与空间注意力的结合
CBAM 由两个顺序连接的子模块构成:
通道注意力模块(Channel Attention Module):分析并确定哪些特征通道包含更关键的信息(例如,识别与特定物体颜色或纹理相关的通道)。
空间注意力模块(Spatial Attention Module):定位并确定关键信息在特征图空间维度上的具体位置(例如,识别物体在图像中的区域)。
处理流程为:输入特征图 → 通道注意力模块 → 空间注意力模块 → 输出增强后的特征图。
相较于 SE 模块的突破
SE 的局限:SE 仅聚焦于学习“哪些通道重要”,而忽略了“重要信息在空间中的具体位置”。
CBAM 的突破:
通过通道注意力明确“关注哪些特征通道”。
通过空间注意力定位“在空间中的哪个位置关注”。
二者协同工作,使模型能够同时学习“关注什么内容”和“在何处关注”,从而实现了更全面、更强大的特征表达能力提升。
CBAM 的优势特点
轻量高效:仅通过全局池化操作和少量简单卷积层实现,计算开销小。
即插即用:无需改动原有网络的主体结构,可方便地插入卷积层之间。
双重优化:同时提升特征在通道维度和空间维度的质量,在处理复杂任务(如小目标检测、语义分割)时效果尤为显著。
通用性强:与 SE 等注意力模块类似,CBAM 也属于“即插即用”型模块,适用于广泛的 CNN 架构(如 ResNet, YOLO 等)。
import torch
import torch.nn as nn
class ChannelAttention(nn.Module):
def __init__(self, in_channels, reduction_ratio=16):
"""
通道注意力机制初始化
参数:
in_channels: 输入特征通道数
reduction_ratio: 通道压缩比例,默认16
"""
super().__init__()
# 特征压缩操作:使用自适应池化获取通道统计信息
self.channel_avg = nn.AdaptiveAvgPool2d(1) # 提取通道均值特征
self.channel_max = nn.AdaptiveMaxPool2d(1) # 提取通道峰值特征
# 通道关系学习模块:共享权重的多层感知机
self.channel_mlp = nn.Sequential(
nn.Linear(in_channels, in_channels // reduction_ratio, bias=False),
nn.ReLU(inplace=True), # 原位激活减少内存占用
nn.Linear(in_channels // reduction_ratio, in_channels, bias=False)
)
# 权重归一化层
self.weight_activ = nn.Sigmoid()
def forward(self, feature_map):
"""
特征图通道增强处理
参数:
feature_map: 输入特征张量 [B, C, H, W]
返回:
通道增强后的特征张量
"""
batch_size, num_channels, _, _ = feature_map.size()
# 通道统计信息计算
avg_stats = self.channel_avg(feature_map).view(batch_size, num_channels)
max_stats = self.channel_max(feature_map).view(batch_size, num_channels)
# 双路特征融合
channel_weights = self.channel_mlp(avg_stats) + self.channel_mlp(max_stats)
# 生成通道注意力掩码
attention_mask = self.weight_activ(channel_weights).view(batch_size, num_channels, 1, 1)
# 应用注意力权重
return feature_map * attention_mask # 广播机制自动扩展维度
import torch
import torch.nn as nn
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
"""
空间注意力机制初始化
参数:
kernel_size: 卷积核尺寸,控制感受野大小(必须为奇数)
"""
super().__init__()
# 特征融合卷积层:将双通道特征图融合为单通道空间权重
# 使用奇数卷积核确保对称填充 (padding=kernel_size//2)
self.spatial_conv = nn.Conv2d(
in_channels=2,
out_channels=1,
kernel_size=kernel_size,
padding=kernel_size // 2, # 保持特征图尺寸不变
bias=False # 无偏置项
)
# 空间权重归一化
self.weight_activ = nn.Sigmoid()
def forward(self, feature_map):
"""
特征图空间增强处理
参数:
feature_map: 输入特征张量 [B, C, H, W]
返回:
空间增强后的特征张量
"""
# 通道维度压缩:生成空间显著性特征
# 平均池化:捕获全局上下文信息 [B, 1, H, W]
channel_avg = torch.mean(feature_map, dim=1, keepdim=True)
# 最大池化:捕获显著局部特征 [B, 1, H, W]
channel_max, _ = torch.max(feature_map, dim=1, keepdim=True)
# 双通道特征拼接:融合不同视角的空间信息 [B, 2, H, W]
spatial_stats = torch.cat([channel_avg, channel_max], dim=1)
# 空间特征学习:通过卷积生成注意力热图
attention_heatmap = self.spatial_conv(spatial_stats)
# 生成空间注意力掩码 [B, 1, H, W]
spatial_mask = self.weight_activ(attention_heatmap)
# 应用空间注意力权重
return feature_map * spatial_mask # 广播机制自动扩展通道维度
@浙大疏锦行