19 - SAFM模块

论文《Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution》

1、作用

这篇论文通过提出空间自适应特征调制(Spatially-Adaptive Feature Modulation, SAFM)机制,旨在解决图像超分辨率(Super-Resolution, SR)的高效设计问题。在图像超分辨率重建性能上取得了显著的成果,这些模型通常具有大型复杂的架构,不适用于低功耗设备,限于计算和存储资源。SAFM层通过独立计算学习多尺度特征表示,并动态聚合这些特征进行空间调制,克服了这些挑战。

2、机制

1、空间自适应特征调制(SAFM)层

SAFM层利用多尺度特征表示独立学习,并动态进行空间调制。SAFM着重于利用非局部特征依赖性,进一步引入卷积通道混合器(Convolutional Channel Mixer, CCM),以编码局部上下文信息并同时混合通道。

2、卷积通道混合器(CCM)

为了补充局部上下文信息,提出了基于FMBConv的CCM,用于编码局部特征并混合通道,增强了模型处理特征的能力。

3、独特优势

1、高效性和灵活性

SAFMN模型相比于现有的高效SR方法小3倍,如IMDN等,同时以更少的内存使用实现了可比的性能。

2、动态空间调制

通过利用多尺度特征表示进行动态空间调制,SAFMN能够高效地聚合特征,提升重建性能,同时保持低计算和存储成本。

3、局部和非局部特征的有效整合

通过SAFM层和CCM的结合,SAFMN有效整合了局部和非局部特征信息,实现了更精准的图像超分辨率重建。

4、代码

import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义SAFM类,继承自nn.Module
class SAFM(nn.Module):
    def __init__(self, dim, n_levels=4):
        super().__init__()
        # n_levels表示特征会被分割成多少个不同的尺度
        self.n_levels = n_levels
        # 每个尺度的特征通道数
        chunk_dim = dim // n_levels

        # Spatial Weighting:针对每个尺度的特征,使用深度卷积进行空间加权
        self.mfr = nn.ModuleList([nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim) for i in range(self.n_levels)])

        # Feature Aggregation:用于聚合不同尺度处理过的特征
        self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)

        # Activation:使用GELU激活函数
        self.act = nn.GELU()

    def forward(self, x):
        # x的形状为(B,C,H,W),其中B是批次大小,C是通道数,H和W是高和宽
        h, w = x.size()[-2:]

        # 将输入特征在通道维度上分割成n_levels个尺度
        xc = x.chunk(self.n_levels, dim=1)

        out = []
        for i in range(self.n_levels):
            if i > 0:
                # 计算每个尺度下采样后的大小
                p_size = (h // 2**i, w // 2**i)
                # 对特征进行自适应最大池化,降低分辨率
                s = F.adaptive_max_pool2d(xc[i], p_size)
                # 对降低分辨率的特征应用深度卷积
                s = self.mfr[i](s)
                # 使用最近邻插值将特征上采样到原始大小
                s = F.interpolate(s, size=(h, w), mode='nearest')
            else:
                # 第一尺度直接应用深度卷积,不进行下采样
                s = self.mfr[i](xc[i])
            out.append(s)

        # 将处理过的所有尺度的特征在通道维度上进行拼接
        out = torch.cat(out, dim=1)
        # 通过1x1卷积聚合拼接后的特征
        out = self.aggr(out)
        # 应用GELU激活函数并与原始输入相乘,实现特征调制
        out = self.act(out) * x
        return out

if __name__ == '__main__':
    # 创建一个SAFM实例并对一个随机输入进行处理
    x = torch.randn(1, 36, 224, 224)
    Model = SAFM(dim=36)
    out = Model(x)
    print(out.shape)

你可能感兴趣的:(深度学习算法实现,深度学习,计算机视觉,机器学习)