【CVPR 2025】This post works through the Transformer design of the paper "RT-X NET: RGB-Thermal Cross Attention Network for Low-Light Image Enhancement".
Below is a detailed walkthrough of the methodology as implemented in code.
The proposed network resembles Retinexformer / SG-LLIE; interested readers can find my analyses of those on my homepage.
Paper: arxiv
Code: github
```python
import torch.nn as nn
import torch
import torch.nn.functional as F
from einops import rearrange
import math
import warnings
from torch.nn.init import _calculate_fan_in_and_fan_out
from pdb import set_trace as stx
from sklearn.decomposition import PCA
# import cv2

# def apply_pca(features, desired_channels):
#     batch_size, height, width, channels = features.shape
#     # Flatten the features and reshape to 2D array for PCA
#     flattened_features = features.permute(0, 2, 3, 1).contiguous().view(-1, channels)
#     # Convert to numpy and perform PCA
#     pca = PCA(n_components=desired_channels)
#     reduced_features_np = pca.fit_transform(flattened_features.detach().cpu().numpy())
#     print("Reduced Features after applying PCA: ", reduced_features_np.shape)
#     # Convert back to tensor and reshape
#     reduced_features = torch.from_numpy(reduced_features_np).to(features.device)
#     print("Reduced Features after converting to numpy: ", reduced_features.shape)
#     reduced_features = reduced_features.view(batch_size, height, width, desired_channels).permute(0, 3, 1, 2)
#     print("Final Reduced Features: ", reduced_features.shape)
#     return reduced_features


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Fill `tensor` with samples from a normal distribution truncated to [a, b].
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)
    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        # Sample uniformly in CDF space, then map back through the inverse CDF.
        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    variance = scale / denom
    if distribution == "truncated_normal":
        # The constant ~0.8796 is the std of a standard normal truncated to [-2, 2].
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
```
The `import` statements pull in everything needed to build and run the network:

- `torch`, `torch.nn`, `torch.nn.functional`: PyTorch's core components, used to define layers, tensor operations, and activation functions.
- `einops`: a powerful library that simplifies and standardizes otherwise fiddly tensor manipulations (reshaping, transposing, etc.).
- `sklearn.decomposition.PCA`: scikit-learn's principal component analysis module, used later for feature dimensionality reduction.
- Weight initializers (`trunc_normal_`, `lecun_normal_`): this family of functions defines how the network weights are initialized. In deep networks, a suitable initialization is critical for avoiding vanishing or exploding gradients and keeping training stable. `trunc_normal_` fills a weight tensor with samples drawn from a truncated normal distribution, so the initial values never stray far from the mean (see the quick check below).
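As a quick sanity check (my own snippet, not part of the repository), we can initialize a linear layer with `trunc_normal_` and confirm every weight lands inside the truncation bounds:

```python
import torch
import torch.nn as nn

# Hypothetical sanity check: initialize a layer with trunc_normal_ and verify
# that all samples respect the requested truncation window [a, b].
layer = nn.Linear(64, 64)
trunc_normal_(layer.weight, mean=0., std=0.02, a=-0.04, b=0.04)
assert layer.weight.min() >= -0.04 and layer.weight.max() <= 0.04
print(layer.weight.std())  # close to (slightly below) 0.02, due to truncation
```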
```python
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)


class GELU(nn.Module):
    def forward(self, x):
        return F.gelu(x)


def conv(in_channels, out_channels, kernel_size, bias=False, padding=1, stride=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias, stride=stride)
```
- `PreNorm`: a wrapper class following the standard Transformer design pattern. It applies Layer Normalization to the input `x` before running the wrapped module (`fn`). This "pre-norm" arrangement helps stabilize the training of deep networks.
- `GELU`: wraps the `F.gelu` activation (Gaussian Error Linear Unit) in a standard `nn.Module`. GELU is the usual activation in Transformer models and is smoother than ReLU.
- `conv`: a helper that builds a 2D convolution and derives the padding automatically so the spatial size is preserved (for stride 1). Note that the `padding` argument is accepted but ignored; the padding actually used is always `kernel_size // 2`, as the sketch below shows.
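A tiny illustration of the two helpers (my own snippet, not from the repo): `PreNorm` normalizes over the last (channel) dimension before calling the wrapped function, and `conv` keeps the spatial size intact at stride 1 regardless of kernel size:

```python
# Hypothetical usage sketch (not from the repo).
x = torch.randn(1, 8, 8, 32)                    # [b, h, w, c] layout
block = PreNorm(32, nn.Linear(32, 32))          # LayerNorm over the last dim, then fn
print(block(x).shape)                           # torch.Size([1, 8, 8, 32])

feat = torch.randn(1, 16, 64, 64)               # [b, c, h, w] layout
print(conv(16, 16, kernel_size=5)(feat).shape)  # torch.Size([1, 16, 64, 64]); padding=5//2 keeps 64x64
```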
The Illumination_Estimator class

```python
class Illumination_Estimator(nn.Module):
    def __init__(
            self, n_fea_middle, n_fea_in=4, n_fea_out=3):
        # Attributes defined in __init__ are internal; the arguments of forward
        # are what the module actually receives from outside.
        super(Illumination_Estimator, self).__init__()
        self.conv1 = nn.Conv2d(n_fea_in, n_fea_middle, kernel_size=1, bias=True)
        self.depth_conv = nn.Conv2d(
            n_fea_middle, n_fea_middle, kernel_size=5, padding=2, bias=True, groups=n_fea_in)
        self.conv2 = nn.Conv2d(n_fea_middle, n_fea_out, kernel_size=1, bias=True)

    def forward(self, img):
        # img:      b, c=3, h, w
        # mean_c:   b, c=1, h, w
        # illu_fea: b, c,   h, w
        # illu_map: b, c=3, h, w
        mean_c = img.mean(dim=1).unsqueeze(1)
        # stx()
        input = torch.cat([img, mean_c], dim=1)
        x_1 = self.conv1(input)
        illu_fea = self.depth_conv(x_1)
        illu_map = self.conv2(illu_fea)
        return illu_fea, illu_map
```
This class is the direct code realization of the paper's Retinex-based decomposition step.

- `__init__`: `depth_conv` is a 5×5 grouped convolution (note `groups=n_fea_in` rather than `groups=n_fea_middle`, so it is grouped rather than strictly depthwise) that captures spatial context efficiently.
- `forward`:
  - `mean_c = img.mean(dim=1).unsqueeze(1)`: computes the mean of the input image over the channel dimension, yielding a single-channel brightness map.
  - `input = torch.cat([img, mean_c], dim=1)`: concatenates the original image with the brightness map, giving the network both the raw color information and the overall luminance.
  - `illu_fea = self.depth_conv(x_1)`: the features extracted by the grouped convolution after the initial 1×1 projection; these correspond to the illumination features $F_{illum}$ in the paper [cite: 59, 64].
  - `illu_map = self.conv2(illu_fea)`: a final 1×1 convolution turns the illumination features into the illumination map $M$ [cite: 59, 64].
  - `return illu_fea, illu_map`: the module returns both the abstract illumination features and the concrete illumination map for later use [cite: 64] (shape check below).
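A quick shape check of the estimator (my own snippet; the shapes follow the comments in `forward`):

```python
# Hypothetical shape check (not from the repo).
estimator = Illumination_Estimator(n_fea_middle=40)
rgb = torch.rand(1, 3, 64, 64)        # low-light RGB image in [0, 1]
illu_fea, illu_map = estimator(rgb)
print(illu_fea.shape)                 # torch.Size([1, 40, 64, 64]) -> F_illum
print(illu_map.shape)                 # torch.Size([1, 3, 64, 64])  -> M
```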
The IG_MSA class (Illumination-Guided Multi-Head Self-Attention)

```python
class IG_MSA(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
    ):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Linear(dim_head * heads, dim, bias=True)
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
            GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
        )
        self.dim = dim

    def forward(self, x_in, illu_fea_trans):
        """
        x_in: [b,h,w,c]           # input feature
        illu_fea_trans: [b,h,w,c] # illumination feature (why b,h,w,c? -- already permuted by the caller)
        return out: [b,h,w,c]
        """
        b, h, w, c = x_in.shape
        x = x_in.reshape(b, h * w, c)
        q_inp = self.to_q(x)
        k_inp = self.to_k(x)
        v_inp = self.to_v(x)
        illu_attn = illu_fea_trans  # illu_fea: b,c,h,w -> b,h,w,c
        q, k, v, illu_attn = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
                                 (q_inp, k_inp, v_inp, illu_attn.flatten(1, 2)))
        v = v * illu_attn  # illumination-guided reweighting of the values
        # q: b, heads, hw, d
        q = q.transpose(-2, -1)
        k = k.transpose(-2, -1)
        v = v.transpose(-2, -1)
        q = F.normalize(q, dim=-1, p=2)
        k = F.normalize(k, dim=-1, p=2)
        attn = (k @ q.transpose(-2, -1))  # A = K^T * Q, a (d x d) channel-wise map
        attn = attn * self.rescale
        attn = attn.softmax(dim=-1)
        x = attn @ v  # b, heads, d, hw
        x = x.permute(0, 3, 1, 2)  # back to b, hw, heads, d
        x = x.reshape(b, h * w, self.num_heads * self.dim_head)
        out_c = self.proj(x).view(b, h, w, c)
        out_p = self.pos_emb(v_inp.reshape(b, h, w, c).permute(
            0, 3, 1, 2)).permute(0, 2, 3, 1)
        out = out_c + out_p
        return out
```
This is the network's core contribution: the code realization of the illumination-guided attention mechanism, corresponding to Section 3.2 of the paper.

- `__init__`: defines the linear layers that project the input features to query (Q), key (K) and value (V), plus a depthwise convolutional branch `pos_emb` that learns positional information.
- `forward`:
  - `q_inp, k_inp, v_inp`: project the input feature `x_in` through the linear layers to obtain Q, K, V.
  - `illu_attn = illu_fea_trans`: receives the illumination features produced by `Illumination_Estimator`.
  - `rearrange(...)`: uses `einops` to reshape the tensors for multi-head attention.
  - `v = v * illu_attn`: the crucial step. The value vectors are multiplied element-wise by the illumination features, implementing the paper's idea of applying the illumination features as a learnable reweighting term on V, so the network can emphasize or suppress regions according to the lighting [cite: 77].
  - `attn = (k @ q.transpose(-2, -1))`: from here on it is the standard attention pipeline, computing attention scores from the dot product of K and Q [cite: 75, 76]. Because Q, K, V were transposed first, the attention map is d×d across channels rather than hw×hw across pixels, which keeps the cost linear in image size.
  - `x = attn @ v`: applies the attention scores to the illumination-guided values to produce the final weighted features [cite: 76]; the sketch below traces the shapes.
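To make the shape flow concrete, here is a minimal trace of the transposed ("channel-wise") attention for a single head (my own sketch, not from the repo; `hw` is the number of pixels, `d` the per-head channel dimension):

```python
# Hypothetical shape trace of the transposed attention used by IG_MSA.
b, hw, d = 1, 4096, 31          # e.g. a 64x64 feature map, dim_head=31, heads=1
q = torch.randn(b, 1, hw, d)
k = torch.randn(b, 1, hw, d)
v = torch.randn(b, 1, hw, d) * torch.rand(b, 1, hw, d)  # values reweighted by illumination

q, k, v = (t.transpose(-2, -1) for t in (q, k, v))      # -> [b, heads, d, hw]
attn = (F.normalize(k, dim=-1) @ F.normalize(q, dim=-1).transpose(-2, -1)).softmax(dim=-1)
print(attn.shape)               # torch.Size([1, 1, 31, 31]) -- d x d, independent of hw
print((attn @ v).shape)         # torch.Size([1, 1, 31, 4096])
```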
The IGAB class (Illumination-Guided Attention Block)

```python
class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1, 1, bias=False),
            GELU(),
            nn.Conv2d(dim * mult, dim * mult, 3, 1, 1,
                      bias=False, groups=dim * mult),
            GELU(),
            nn.Conv2d(dim * mult, dim, 1, 1, bias=False),
        )

    def forward(self, x):
        """
        x: [b,h,w,c]
        return out: [b,h,w,c]
        """
        out = self.net(x.permute(0, 3, 1, 2).contiguous())
        return out.permute(0, 2, 3, 1)


class IGAB(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
            num_blocks=2,
    ):
        super().__init__()
        self.blocks = nn.ModuleList([])
        for _ in range(num_blocks):
            self.blocks.append(nn.ModuleList([
                IG_MSA(dim=dim, dim_head=dim_head, heads=heads),
                PreNorm(dim, FeedForward(dim=dim))
            ]))

    def forward(self, x, illu_fea):
        """
        x: [b,c,h,w]
        illu_fea: [b,c,h,w]
        return out: [b,c,h,w]
        """
        x = x.permute(0, 2, 3, 1)
        for (attn, ff) in self.blocks:
            x = attn(x, illu_fea_trans=illu_fea.permute(0, 2, 3, 1)) + x
            x = ff(x) + x
        out = x.permute(0, 3, 1, 2)
        return out
```
This module combines the preceding `IG_MSA` and `FeedForward` into a complete, reusable illumination-guided attention block (IGAB).

- `FeedForward`: a standard feed-forward network (1×1 expansion, depthwise 3×3, 1×1 projection) that applies a non-linear transformation after the attention step.
- `IGAB.__init__`: builds a module list `self.blocks` in which every entry holds one `IG_MSA` attention module and one `FeedForward` module wrapped in `PreNorm`.
- `IGAB.forward`:
  - `for (attn, ff) in self.blocks`: iterates over each attention/feed-forward pair.
  - `x = attn(...) + x`: runs the illumination-guided self-attention and adds the result to the original input `x`. This residual connection is essential for building deep networks without vanishing gradients.
  - `x = ff(x) + x`: feeds the result through the feed-forward network with another residual connection (see the quick demo below).
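A minimal demo of one IGAB (my own snippet, not from the repo): both inputs are channel-first maps with the same width, and the output preserves the input shape, which is what lets the block be stacked freely inside the U-Net:

```python
# Hypothetical usage: one IGAB on a 32x32 feature map.
block = IGAB(dim=40, dim_head=40, heads=1, num_blocks=2)
fea = torch.randn(2, 40, 32, 32)    # image features       [b, c, h, w]
illu = torch.randn(2, 40, 32, 32)   # illumination features, same shape
print(block(fea, illu).shape)       # torch.Size([2, 40, 32, 32])
```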
The Denoiser class (U-Net backbone)

```python
class Denoiser(nn.Module):
    def __init__(self, in_dim=3, out_dim=3, dim=31, level=2, num_blocks=[2, 4, 4]):
        super(Denoiser, self).__init__()
        self.dim = dim
        self.level = level

        # Input projection
        self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)

        # Encoder: at each level an IGAB, plus two stride-2 convs that
        # downsample the image features and the illumination features in step.
        self.encoder_layers = nn.ModuleList([])
        dim_level = dim
        for i in range(level):
            self.encoder_layers.append(nn.ModuleList([
                IGAB(
                    dim=dim_level, num_blocks=num_blocks[i], dim_head=dim, heads=dim_level // dim),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False)
            ]))
            dim_level *= 2

        # Bottleneck
        self.bottleneck = IGAB(
            dim=dim_level, dim_head=dim, heads=dim_level // dim, num_blocks=num_blocks[-1])

        # Decoder: transposed-conv upsampling, 1x1 fusion of the skip
        # connection, then another IGAB.
        self.decoder_layers = nn.ModuleList([])
        for i in range(level):
            self.decoder_layers.append(nn.ModuleList([
                nn.ConvTranspose2d(dim_level, dim_level // 2, stride=2,
                                   kernel_size=2, padding=0, output_padding=0),
                nn.Conv2d(dim_level, dim_level // 2, 1, 1, bias=False),
                IGAB(
                    dim=dim_level // 2, num_blocks=num_blocks[level - 1 - i], dim_head=dim,
                    heads=(dim_level // 2) // dim),
            ]))
            dim_level //= 2

        # Output projection
        self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)

        # Activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, illu_fea):
        """
        x: [b,c,h,w]          # x is a feature map here, not an image
        illu_fea: [b,c,h,w]
        return out: [b,c,h,w]
        """
        # Embedding
        fea = self.embedding(x)

        # Encoder: save pre-downsampling features for the skip connections.
        fea_encoder = []
        illu_fea_list = []
        # Note: the loop variable IGAB shadows the class name (harmless here).
        for (IGAB, FeaDownSample, IlluFeaDownsample) in self.encoder_layers:
            fea = IGAB(fea, illu_fea)  # b,c,h,w
            illu_fea_list.append(illu_fea)
            fea_encoder.append(fea)
            fea = FeaDownSample(fea)
            illu_fea = IlluFeaDownsample(illu_fea)

        # Bottleneck
        fea = self.bottleneck(fea, illu_fea)

        # Decoder ("LeWinBlcok" is a misspelled name kept from the original
        # code; the module it holds is an IGAB).
        for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers):
            fea = FeaUpSample(fea)
            fea = Fution(
                torch.cat([fea, fea_encoder[self.level - 1 - i]], dim=1))
            illu_fea = illu_fea_list[self.level - 1 - i]
            fea = LeWinBlcok(fea, illu_fea)

        # Mapping
        out = self.mapping(fea) + x
        return out
```
This class is the main body of the network: a classic U-Net architecture responsible for the bulk of the feature processing and image reconstruction.

- `__init__` defines all of the U-Net's components:
  - `encoder_layers`: `IGAB` blocks for feature extraction plus stride-2 convolutions for downsampling (one for the features and one for the illumination features, so the two stay spatially aligned at every scale).
  - `bottleneck`: the lowest level of the U-Net, operating at the coarsest resolution.
  - `decoder_layers`: transposed convolutions (`ConvTranspose2d`) for upsampling plus `IGAB` blocks to process the fused features.
- `forward`:
  - Encoder: each level runs an `IGAB` and then downsamples; the pre-downsampling features `fea` are saved in the `fea_encoder` list (and the matching illumination features in `illu_fea_list`).
  - Decoder: each level upsamples and concatenates with the saved encoder features (`torch.cat([fea, fea_encoder[...]], dim=1)`). These are the U-Net "skip connections", crucial for reconstruction tasks because they combine shallow detail with deep semantics.
  - `out = self.mapping(fea) + x`: finally, a convolution maps the features back to the image channel count, and a residual connection with the original input `x` produces the output (shape check below).
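A shape check of the backbone (my own snippet, not from the repo): with `level=2` the spatial size must be divisible by 4, and the guiding illumination features must have `dim` channels at full resolution so they can reweight the values at every scale:

```python
# Hypothetical shape check of the U-Net backbone.
denoiser = Denoiser(in_dim=3, out_dim=3, dim=40, level=2, num_blocks=[1, 1, 1])
img = torch.randn(1, 3, 64, 64)     # H and W divisible by 2**level
guide = torch.randn(1, 40, 64, 64)  # illumination features, channels == dim
print(denoiser(img, guide).shape)   # torch.Size([1, 3, 64, 64])
```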
The RTxNet_Single_Stage and RTxNet classes

```python
class RTxNet_Single_Stage(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_feat=31, level=2, num_blocks=[1, 1, 1]):
        super(RTxNet_Single_Stage, self).__init__()
        self.estimator_rgb = Illumination_Estimator(n_feat)
        self.estimator_thermal = Illumination_Estimator(n_feat)
        self.cross_attention = IGAB(dim=n_feat, dim_head=n_feat // 4, heads=4, num_blocks=1)
        self.denoiser = Denoiser(in_dim=in_channels, out_dim=out_channels, dim=n_feat,
                                 level=level, num_blocks=num_blocks)

    def forward(self, img):
        # img: (img_rgb, img_thermal), each b,c=3,h,w
        # illu_fea: b,c,h,w
        # illu_map: b,c=3,h,w
        img_rgb, img_thermal = img

        # Estimate illumination for the RGB and thermal images separately.
        illu_fea_rgb, illu_map_rgb = self.estimator_rgb(img_rgb)
        illu_fea_thermal, illu_map_thermal = self.estimator_thermal(img_thermal)

        # Attention over each modality's illumination features. Each call feeds
        # the same tensor as both feature and guide, so despite the module name
        # this is per-modality self-attention.
        enhanced_fea_rgb = self.cross_attention(illu_fea_rgb, illu_fea_rgb)
        enhanced_fea_thermal = self.cross_attention(illu_fea_thermal, illu_fea_thermal)

        # Merge the enhanced features along the channel dimension.
        merged_features = torch.cat([enhanced_fea_rgb, enhanced_fea_thermal], dim=1)
        batch_size, channels, height, width = merged_features.size()

        # Reduce the merged channels with PCA (runs on CPU via scikit-learn).
        desired_channels = 40
        flattened_features = merged_features.view(batch_size, channels, -1)
        flattened_features = flattened_features.permute(0, 2, 1).contiguous().view(-1, channels)
        pca = PCA(n_components=desired_channels)
        reduced_features_np = pca.fit_transform(flattened_features.detach().cpu().numpy())
        reduced_features = torch.from_numpy(reduced_features_np).view(
            batch_size, -1, desired_channels).permute(0, 2, 1).contiguous().view(
            batch_size, desired_channels, height, width)
        reduced_features = reduced_features.cuda()  # assumes a CUDA device

        # Retinex-style "lit-up" image, denoised under the guidance of the
        # PCA-reduced fused features.
        input_img = img_rgb * illu_map_rgb + img_rgb
        output_img = self.denoiser(input_img, reduced_features)
        return output_img


class RTxNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_feat=31, stage=3, num_blocks=[1, 1, 1]):
        super(RTxNet, self).__init__()
        self.stage = stage
        modules_body = [RTxNet_Single_Stage(in_channels=in_channels, out_channels=out_channels,
                                            n_feat=n_feat, level=2, num_blocks=num_blocks)
                        for _ in range(stage)]
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        out = self.body(x)
        return out
```
This final part assembles all the components into the executable model.

- `RTxNet_Single_Stage`: defines the single-stage pipeline and holds the core logic of the network.
  - `__init__`: instantiates every required module: two `Illumination_Estimator`s (one for RGB, one for thermal), one `IGAB` module, and the `Denoiser` backbone.
  - `forward`:
    - `illu_fea_rgb, ... = self.estimator_rgb(img_rgb)`: extracts illumination features and illumination maps from the RGB and thermal images separately.
    - `enhanced_fea_rgb = self.cross_attention(illu_fea_rgb, illu_fea_rgb)`: a key implementation detail. Although the module is named `cross_attention`, each call here applies attention to one modality's features in isolation, with query, key and value all derived from the same source (`illu_fea_rgb` or `illu_fea_thermal`), i.e. self-attention. This differs from the cross-attention described in Section 3.3 of the paper, where the query comes from one modality and the key/value from the other [cite: 84].
    - `merged_features = torch.cat(...)`: concatenates the two modalities' attended features along the channel dimension.
    - PCA reduction: the merged features are reduced with `sklearn.decomposition.PCA`. This is an atypical design choice, embedding a CPU-bound, data-driven dimensionality-reduction step inside the network to control the feature width and distill the most informative fused representation [cite: 86, 87]. Note that `desired_channels = 40` must equal the denoiser width `n_feat` for the illumination-guided reweighting in `IG_MSA` to broadcast, which suggests `n_feat=40` in practice rather than the default of 31.
    - `input_img = img_rgb * illu_map_rgb + img_rgb`: following Retinex theory, builds the main input to the `Denoiser`: the RGB image preliminarily "lit up" by its illumination map. This corresponds to the paper's "lit-up image" [cite: 67].
    - `output_img = self.denoiser(input_img, reduced_features)`: feeds the lit-up image, guided by the PCA-reduced fused features, through the U-Net backbone to produce the final output.
- `RTxNet`: chains several `RTxNet_Single_Stage` modules with `nn.Sequential`, forming a **multi-stage** architecture in which the image can be refined repeatedly, stage by stage. Note, however, that a stage consumes an (RGB, thermal) pair while returning a single tensor, so running more than one stage as written would require adapting the inter-stage interface. A usage sketch follows.