Mean-pooling coordinate attention (fca) + max-pooling coordinate attention (fcm).
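As a compact way to state that fusion (the symbols below are introduced here for illustration and are not taken from the original post): for an input feature map $X$, the average branch produces directional attention maps $a_h^{\mathrm{avg}}, a_w^{\mathrm{avg}}$ and the max branch produces $a_h^{\mathrm{max}}, a_w^{\mathrm{max}}$; each is passed through a sigmoid $\sigma$, broadcast over the spatial grid, and the module output is the sum of the two re-weighted copies of $X$:

$$\mathrm{DCAFE}(X) = X \odot \sigma(a_h^{\mathrm{avg}}) \odot \sigma(a_w^{\mathrm{avg}}) \;+\; X \odot \sigma(a_h^{\mathrm{max}}) \odot \sigma(a_w^{\mathrm{max}})$$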
Quantitative performance
| Dataset | Accuracy | Improvement over SOTA |
|---|---|---|
| Urban Street (17 classes) | 91.12% | +6.94% |
| Medicinal Blossom (12 classes) | 91.18% | +5.71% |
Ablation study validation
Visualization evidence
import torch
import torch.nn as nn


class CoordAttMeanMax(nn.Module):
    def __init__(self, inp, oup, groups=32):
        super(CoordAttMeanMax, self).__init__()
        self.pool_h_mean = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w_mean = nn.AdaptiveAvgPool2d((1, None))
        self.pool_h_max = nn.AdaptiveMaxPool2d((None, 1))
        self.pool_w_max = nn.AdaptiveMaxPool2d((1, None))

        mip = max(8, inp // groups)

        self.conv1_mean = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1_mean = nn.BatchNorm2d(mip)
        self.conv2_mean = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

        self.conv1_max = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1_max = nn.BatchNorm2d(mip)
        self.conv2_max = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        n, c, h, w = x.size()

        # Mean pooling branch
        x_h_mean = self.pool_h_mean(x)
        x_w_mean = self.pool_w_mean(x).permute(0, 1, 3, 2)
        y_mean = torch.cat([x_h_mean, x_w_mean], dim=2)
        y_mean = self.conv1_mean(y_mean)
        y_mean = self.bn1_mean(y_mean)
        y_mean = self.relu(y_mean)
        x_h_mean, x_w_mean = torch.split(y_mean, [h, w], dim=2)
        x_w_mean = x_w_mean.permute(0, 1, 3, 2)

        # Max pooling branch
        x_h_max = self.pool_h_max(x)
        x_w_max = self.pool_w_max(x).permute(0, 1, 3, 2)
        y_max = torch.cat([x_h_max, x_w_max], dim=2)
        y_max = self.conv1_max(y_max)
        y_max = self.bn1_max(y_max)
        y_max = self.relu(y_max)
        x_h_max, x_w_max = torch.split(y_max, [h, w], dim=2)
        x_w_max = x_w_max.permute(0, 1, 3, 2)

        # Apply attention
        x_h_mean = self.conv2_mean(x_h_mean).sigmoid()
        x_w_mean = self.conv2_mean(x_w_mean).sigmoid()
        x_h_max = self.conv2_max(x_h_max).sigmoid()
        x_w_max = self.conv2_max(x_w_max).sigmoid()

        # Expand to original shape
        x_h_mean = x_h_mean.expand(-1, -1, h, w)
        x_w_mean = x_w_mean.expand(-1, -1, h, w)
        x_h_max = x_h_max.expand(-1, -1, h, w)
        x_w_max = x_w_max.expand(-1, -1, h, w)

        # Combine outputs
        attention_mean = identity * x_w_mean * x_h_mean
        attention_max = identity * x_w_max * x_h_max

        # Sum the attention outputs
        return attention_mean + attention_max


class DCAFE(nn.Module):
    def __init__(self, in_channels):
        super(DCAFE, self).__init__()
        self.coord_att = CoordAttMeanMax(in_channels, oup=in_channels)

    def forward(self, x):
        return self.coord_att(x)


if __name__ == "__main__":
    # Define the input tensor size (Batch, Channel, Height, Width)
    B, C, H, W = 16, 64, 40, 40
    input_tensor = torch.randn(B, C, H, W)  # randomly generated input tensor

    # Create a DCAFE instance
    block = DCAFE(in_channels=64)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sablock = block.to(device)
    print(sablock)
    input_tensor = input_tensor.to(device)

    # Run the forward pass
    output = sablock(input_tensor)

    # Print the input and output shapes
    print(f"Input: {input_tensor.shape}")
    print(f"Output: {output.shape}")
import torch
import torch.nn as nn


class CoordAttMeanMax(nn.Module):
    # Core class implementing the dual coordinate attention mechanism
    def __init__(self, inp, oup, groups=32):
        super(CoordAttMeanMax, self).__init__()
        # Pooling layers
        self.pool_h_mean = nn.AdaptiveAvgPool2d((None, 1))  # average pooling along the horizontal direction
        self.pool_w_mean = nn.AdaptiveAvgPool2d((1, None))  # average pooling along the vertical direction
        self.pool_h_max = nn.AdaptiveMaxPool2d((None, 1))   # max pooling along the horizontal direction
        self.pool_w_max = nn.AdaptiveMaxPool2d((1, None))   # max pooling along the vertical direction

        # Intermediate channel count (at least 8)
        mip = max(8, inp // groups)

        # Convolution layers of the average-pooling branch
        self.conv1_mean = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1_mean = nn.BatchNorm2d(mip)
        self.conv2_mean = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

        # Convolution layers of the max-pooling branch
        self.conv1_max = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1_max = nn.BatchNorm2d(mip)
        self.conv2_max = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x  # keep the original input
        n, c, h, w = x.size()

        # ===== Average-pooling branch =====
        # Horizontal pooling: (n, c, h, w) -> (n, c, h, 1)
        x_h_mean = self.pool_h_mean(x)
        # Vertical pooling: (n, c, h, w) -> (n, c, 1, w) -> transposed to (n, c, w, 1)
        x_w_mean = self.pool_w_mean(x).permute(0, 1, 3, 2)
        # Concatenate horizontal and vertical features
        y_mean = torch.cat([x_h_mean, x_w_mean], dim=2)  # (n, c, h+w, 1)
        # Feature transformation
        y_mean = self.conv1_mean(y_mean)
        y_mean = self.bn1_mean(y_mean)
        y_mean = self.relu(y_mean)
        # Split the features back apart
        x_h_mean, x_w_mean = torch.split(y_mean, [h, w], dim=2)
        x_w_mean = x_w_mean.permute(0, 1, 3, 2)  # transpose back to (n, mip, 1, w)

        # ===== Max-pooling branch ===== (same procedure as above)
        x_h_max = self.pool_h_max(x)
        x_w_max = self.pool_w_max(x).permute(0, 1, 3, 2)
        y_max = torch.cat([x_h_max, x_w_max], dim=2)
        y_max = self.conv1_max(y_max)
        y_max = self.bn1_max(y_max)
        y_max = self.relu(y_max)
        x_h_max, x_w_max = torch.split(y_max, [h, w], dim=2)
        x_w_max = x_w_max.permute(0, 1, 3, 2)

        # ===== Attention weight generation =====
        # Attention weights of the average branch
        x_h_mean = self.conv2_mean(x_h_mean).sigmoid()  # (n, oup, h, 1)
        x_w_mean = self.conv2_mean(x_w_mean).sigmoid()  # (n, oup, 1, w)
        # Attention weights of the max branch
        x_h_max = self.conv2_max(x_h_max).sigmoid()  # (n, oup, h, 1)
        x_w_max = self.conv2_max(x_w_max).sigmoid()  # (n, oup, 1, w)

        # ===== Attention map expansion =====
        # Broadcast the 1D attention maps over the 2D spatial grid
        x_h_mean = x_h_mean.expand(-1, -1, h, w)  # (n, oup, h, w)
        x_w_mean = x_w_mean.expand(-1, -1, h, w)  # (n, oup, h, w)
        x_h_max = x_h_max.expand(-1, -1, h, w)    # (n, oup, h, w)
        x_w_max = x_w_max.expand(-1, -1, h, w)    # (n, oup, h, w)

        # ===== Applying the attention =====
        # Weighting by the average branch
        attention_mean = identity * x_w_mean * x_h_mean
        # Weighting by the max branch
        attention_max = identity * x_w_max * x_h_max
        # Fuse the two branches
        return attention_mean + attention_max


class DCAFE(nn.Module):
    # Dual Coordinate Attention Feature Extraction module (DCAFE in the paper)
    def __init__(self, in_channels):
        super(DCAFE, self).__init__()
        self.coord_att = CoordAttMeanMax(in_channels, oup=in_channels)

    def forward(self, x):
        return self.coord_att(x)
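The post refers to visualization evidence; one way to pull the per-direction attention maps out of the module for such plots is a forward hook on a branch's final 1×1 convolution. The following is an illustrative sketch added here (the hook-based extraction is an assumption of this write-up, not code from the original post):

import torch

dcafe = DCAFE(in_channels=64)
maps = []  # collects the sigmoid-activated attention maps
dcafe.coord_att.conv2_mean.register_forward_hook(
    lambda module, inputs, output: maps.append(output.sigmoid().detach())
)
_ = dcafe(torch.randn(1, 64, 40, 40))
# conv2_mean runs twice per forward pass: first for the height map, then for the width map
print(maps[0].shape, maps[1].shape)  # torch.Size([1, 64, 40, 1]) torch.Size([1, 64, 1, 40])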
Dual-branch structure:
Coordinate attention mechanism:
# Horizontal direction
x_h = pool_h(x)  # shape: (n, c, h, 1)
# Vertical direction
x_w = pool_w(x).permute(0, 1, 3, 2)  # shape: (n, c, w, 1)
# Feature concatenation
y = torch.cat([x_h, x_w], dim=2)  # shape: (n, c, h+w, 1)
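To make the concatenation trick concrete, here is a small shape trace (the sizes are chosen arbitrarily for illustration; this snippet is not part of the original post):

import torch
import torch.nn as nn

x = torch.randn(2, 16, 40, 40)
x_h = nn.AdaptiveAvgPool2d((None, 1))(x)                      # (2, 16, 40, 1)
x_w = nn.AdaptiveAvgPool2d((1, None))(x).permute(0, 1, 3, 2)  # (2, 16, 40, 1)
y = torch.cat([x_h, x_w], dim=2)                              # (2, 16, 80, 1): both directions share one conv
print(x_h.shape, x_w.shape, y.shape)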
Lightweight design:
mip = max(8, inp // groups)  # channel compression
conv1 = nn.Conv2d(inp, mip, kernel_size=1)  # 1×1 convolution for dimensionality reduction
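For a rough sense of how small the module stays after this compression, one can simply count its parameters (an illustrative check added here, assuming the DCAFE class defined above; the exact figure depends on inp and groups):

dcafe = DCAFE(in_channels=64)
n_params = sum(p.numel() for p in dcafe.parameters())
print(f"DCAFE(64) parameters: {n_params}")  # only a handful of 1x1 convs and BN layers contribute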
Attention generation and application:
# Generate the spatial attention weights
att = conv2(features).sigmoid()  # values in [0, 1]
# Expand to the full spatial size
att = att.expand(-1, -1, h, w)
# Apply the attention
output = input * att_w * att_h  # multiply by the attention of both directions
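A side note on the explicit expand: PyTorch broadcasting would give the same result when an (h, 1) map and a (1, w) map multiply an (h, w) feature map, so the expand mainly makes the shapes explicit. A quick check (an observation about the code, not a claim from the paper):

import torch

att_h = torch.rand(1, 4, 8, 1)
att_w = torch.rand(1, 4, 1, 8)
feat = torch.rand(1, 4, 8, 8)
out_broadcast = feat * att_h * att_w
out_expanded = feat * att_h.expand(-1, -1, 8, 8) * att_w.expand(-1, -1, 8, 8)
print(torch.allclose(out_broadcast, out_expanded))  # True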
Spatial awareness:
Multi-scale fusion:
Computational efficiency:
Attention visualization:
Position-sensitive attention:
Dual-modality fusion:
Plug-and-play design:
# Typical usage inside a CNN
class CNNBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3)
        self.dcafe = DCAFE(out_ch)  # embed the DCAFE module

    def forward(self, x):
        x = self.conv(x)
        x = self.dcafe(x)  # enhance the feature representation
        return x
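A minimal usage sketch for the block above, continuing from the definitions and imports earlier in the post (the channel and input sizes are assumptions for illustration):

block = CNNBlock(in_ch=3, out_ch=64)
x = torch.randn(1, 3, 64, 64)
y = block(x)
print(y.shape)  # torch.Size([1, 64, 62, 62]): the 3x3 conv has no padding, DCAFE keeps the shape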