[CVPR 2025] Low-Light Enhancement with RT-X Net (Thermal-Assisted Structural Guidance) -- Part 2: Code Walkthrough

[CVPR 2025] This article covers the paper RT-X NET: RGB-THERMAL CROSS ATTENTION NETWORK FOR LOW-LIGHT IMAGE ENHANCEMENT and its Transformer design.
Below is a detailed walkthrough of the code that implements the methodology.
The proposed network is similar to Retinexformer / SG-LLIE; if you are interested, you can find further material on my homepage.
Paper: arxiv
Code: github

Table of Contents

      • **Code Block 1: Imports and Helper Functions**
        • **Analysis:**
      • **Code Block 2: Basic Neural Network Modules**
        • **Analysis:**
      • **Code Block 3: The `Illumination_Estimator` Class**
        • **Analysis:**
      • **Code Block 4: The `IG_MSA` Class (Illumination-Guided Multi-Head Self-Attention)**
        • **Analysis:**
      • **Code Block 5: The `IGAB` Class (Illumination-Guided Attention Block)**
        • **Analysis:**
      • **Code Block 6: The `Denoiser` Class (U-Net Backbone)**
        • **Analysis:**
      • **Code Block 7: The `RTxNet_Single_Stage` and `RTxNet` Classes**
        • **Analysis:**



Code Block 1: Imports and Helper Functions

import torch.nn as nn
import torch
import torch.nn.functional as F
from einops import rearrange
import math
import warnings
from torch.nn.init import _calculate_fan_in_and_fan_out
from pdb import set_trace as stx
from sklearn.decomposition import PCA

# import cv2


# def apply_pca(features, desired_channels):
#     batch_size, height, width, channels = features.shape

#     # Flatten the features and reshape to 2D array for PCA

#     flattened_features = features.permute(0, 2, 3, 1).contiguous().view(-1, channels)

#     # Convert to numpy and perform PCA
#     pca = PCA(n_components=desired_channels)
#     reduced_features_np = pca.fit_transform(flattened_features.detach().cpu().numpy())
#     print("Reduced Features after applying PCA: ", reduced_features_np.shape)

#     # Convert back to tensor and reshape
#     reduced_features = torch.from_numpy(reduced_features_np).to(features.device)
#     print("Reduced Features after convertng to numpy: ", reduced_features.shape)
    
#     reduced_features = reduced_features.view(batch_size, height, width, desired_channels).permute(0, 3, 1, 2)
#     print("Final Reduced Features: ", reduced_features.shape)

#     return reduced_features

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)
    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    variance = scale / denom
    if distribution == "truncated_normal":
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
Analysis:
  • import statements: import all the libraries needed to build and run the network.
    • torch, torch.nn, torch.nn.functional: core PyTorch components for defining layers, tensor operations, and activation functions.
    • einops: a powerful library that simplifies and standardizes complex tensor manipulations (reshaping, transposing, etc.).
    • sklearn.decomposition.PCA: the principal component analysis module from scikit-learn, used later for feature dimensionality reduction.
  • Weight initialization functions (trunc_normal_, lecun_normal_): these define how the network weights are initialized. In deep networks, proper initialization is crucial for preventing vanishing or exploding gradients and keeping training stable. trunc_normal_ fills a weight tensor with samples drawn from a truncated normal distribution, ensuring the initial values do not stray too far from the mean; a usage sketch follows below.
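
A minimal usage sketch (not from the original repo; the Linear layer here is a hypothetical example). It fills a weight tensor the same way Denoiser._init_weights does later in this file. Note that with the defaults a=-2, b=2 the clamp is applied in absolute units, not in multiples of std.

```python
# Usage sketch (illustrative): initialize a Linear layer's weight with the
# truncated-normal helper defined above.
layer = nn.Linear(64, 64)
trunc_normal_(layer.weight, std=0.02)   # samples ~ N(0, 0.02^2), clamped to [-2, 2]
print(layer.weight.std().item())        # roughly 0.02

lecun_normal_(layer.weight)             # fan_in-scaled truncated normal (LeCun-style init)
```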

Code Block 2: Basic Neural Network Modules

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)


class GELU(nn.Module):
    def forward(self, x):
        return F.gelu(x)


def conv(in_channels, out_channels, kernel_size, bias=False, padding=1, stride=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias, stride=stride)
Analysis:
  • PreNorm: a wrapper class following the standard Transformer design pattern. It applies Layer Normalization to the input x before running the wrapped module (fn). This "pre-norm" structure helps stabilize the training of deep networks.
  • GELU: wraps the F.gelu activation (Gaussian Error Linear Unit) in a standard nn.Module. GELU is widely used in Transformer models and is smoother than ReLU.
  • conv function: a helper for quickly creating 2D convolution layers. Padding is always set to kernel_size // 2 (the padding argument is effectively ignored), so the feature map size is preserved when stride is 1; see the shape check below.
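
A quick shape check (illustrative, not from the original repo) confirming the two behaviors above:

```python
# conv() pads by kernel_size // 2, so stride=1 preserves the spatial size;
# PreNorm normalizes the last (channel) dimension before calling the wrapped fn.
x = torch.rand(1, 16, 32, 32)
print(conv(16, 32, kernel_size=3)(x).shape)   # torch.Size([1, 32, 32, 32])

tokens = torch.rand(1, 32 * 32, 16)           # [b, n, c] token layout
block = PreNorm(16, nn.Linear(16, 16))
print(block(tokens).shape)                    # torch.Size([1, 1024, 16])
```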

Code Block 3: The Illumination_Estimator Class

class Illumination_Estimator(nn.Module):
    def __init__(
            self, n_fea_middle, n_fea_in=4, n_fea_out=3):  # arguments in __init__ are internal attributes; the inputs to forward are the external inputs
        super(Illumination_Estimator, self).__init__()

        self.conv1 = nn.Conv2d(n_fea_in, n_fea_middle, kernel_size=1, bias=True)

        self.depth_conv = nn.Conv2d(
            n_fea_middle, n_fea_middle, kernel_size=5, padding=2, bias=True, groups=n_fea_in)

        self.conv2 = nn.Conv2d(n_fea_middle, n_fea_out, kernel_size=1, bias=True)

    def forward(self, img):
        # img:        b,c=3,h,w
        # mean_c:     b,c=1,h,w
        
        # illu_fea:   b,c,h,w
        # illu_map:   b,c=3,h,w
        
        mean_c = img.mean(dim=1).unsqueeze(1)
        # stx()
        input = torch.cat([img,mean_c], dim=1)

        x_1 = self.conv1(input)
        illu_fea = self.depth_conv(x_1)
        illu_map = self.conv2(illu_fea)
        return illu_fea, illu_map
Analysis:

This class is a direct implementation of the Retinex-based decomposition step described in the paper.

  • __init__:
    • Defines three convolution layers, resembling a small encoder-decoder.
    • depth_conv is a grouped 5x5 convolution (groups=n_fea_in=4), a depthwise-style convolution that captures spatial features efficiently.
  • forward:
    1. mean_c = img.mean(dim=1).unsqueeze(1): computes the mean over the channel dimension, producing a single-channel brightness map.
    2. input = torch.cat([img, mean_c], dim=1): concatenates the original image with the brightness map, giving the network both the raw color information and the overall luminance.
    3. illu_fea = self.depth_conv(x_1): the feature extracted by the grouped convolution after the initial 1x1 projection. This corresponds to the illumination feature $F_{illum}$ in the paper.
    4. illu_map = self.conv2(illu_fea): a 1x1 convolution maps the illumination feature to the final illumination map, corresponding to the illumination map $M$ in the paper.
    5. return illu_fea, illu_map: the module returns both the abstract illumination feature and the concrete illumination map for use in later stages; a shape walkthrough follows below.
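
A shape walkthrough (illustrative; the sizes are examples chosen here, not taken from the paper):

```python
# The estimator concatenates the per-pixel channel mean with the image
# (3 + 1 = 4 input channels) and returns an n_fea_middle-channel illumination
# feature plus a 3-channel illumination map at the same resolution.
estimator = Illumination_Estimator(n_fea_middle=40)
img = torch.rand(1, 3, 64, 64)
illu_fea, illu_map = estimator(img)
print(illu_fea.shape, illu_map.shape)  # [1, 40, 64, 64] and [1, 3, 64, 64]
```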

Code Block 4: The IG_MSA Class (Illumination-Guided Multi-Head Self-Attention)

class IG_MSA(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
    ):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Linear(dim_head * heads, dim, bias=True)
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
            GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
        )
        self.dim = dim

    def forward(self, x_in, illu_fea_trans):
        """
        x_in: [b,h,w,c]         # input_feature
        illu_fea_trans: [b,h,w,c]   # illumination feature, already permuted to b,h,w,c by the caller (IGAB)
        return out: [b,h,w,c]
        """
        b, h, w, c = x_in.shape
        x = x_in.reshape(b, h * w, c)
        q_inp = self.to_q(x)
        k_inp = self.to_k(x)
        v_inp = self.to_v(x)
        illu_attn = illu_fea_trans # illu_fea: b,c,h,w -> b,h,w,c
        q, k, v, illu_attn = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
                                 (q_inp, k_inp, v_inp, illu_attn.flatten(1, 2)))
        
        
        v = v  * illu_attn
        # q: b,heads,hw,c
        q = q.transpose(-2, -1)
        k = k.transpose(-2, -1)
        v = v.transpose(-2, -1)
        q = F.normalize(q, dim=-1, p=2)
        k = F.normalize(k, dim=-1, p=2)
        attn = (k @ q.transpose(-2, -1))   # A = K^T*Q
        attn = attn * self.rescale
        attn = attn.softmax(dim=-1)
        x = attn @ v   # b,heads,d,hw
        x = x.permute(0, 3, 1, 2)    # Transpose
        x = x.reshape(b, h * w, self.num_heads * self.dim_head)
        out_c = self.proj(x).view(b, h, w, c)
        out_p = self.pos_emb(v_inp.reshape(b, h, w, c).permute(
            0, 3, 1, 2)).permute(0, 2, 3, 1)
        out = out_c + out_p

        return out
Analysis:

This is the implementation of the network's core contribution, the illumination-guided attention mechanism, corresponding to Section 3.2 of the paper.

  • __init__: defines the linear layers that project the input feature into queries (Q), keys (K), and values (V), plus a small convolutional network pos_emb that learns positional information.
  • forward:
    1. q_inp, k_inp, v_inp: project the input feature x_in through linear layers to obtain Q, K, and V.
    2. illu_attn = illu_fea_trans: receives the illumination feature produced by Illumination_Estimator.
    3. rearrange(...): uses einops to reshape the tensors into the multi-head layout.
    4. v = v * illu_attn: the key step. The value tensor V is multiplied element-wise with the illumination feature, implementing the paper's idea of applying the illumination feature as a learnable re-weighting term on V, so the network can dynamically emphasize or suppress features in different regions according to the lighting.
    5. attn = (k @ q.transpose(-2, -1)): attention scores are then computed from the dot product of K and Q. Because Q, K, and V are transposed first, the attention matrix has shape dim_head x dim_head: attention is computed across the channel dimension (as in Retinexformer), so the cost grows linearly with the number of pixels.
    6. x = attn @ v: the attention scores are applied to the illumination-guided V, yielding the final weighted feature; a shape sketch follows below.
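
A shape sketch (illustrative; the sizes are examples, not the repo configuration):

```python
# The input and the illumination feature share the [b, h, w, c] layout, and the
# attention matrix is dim_head x dim_head (channel-wise), not (h*w) x (h*w).
msa = IG_MSA(dim=40, dim_head=10, heads=4)
feat = torch.rand(1, 32, 32, 40)   # [b, h, w, c] input feature
illu = torch.rand(1, 32, 32, 40)   # illumination feature, same layout
out = msa(feat, illu_fea_trans=illu)
print(out.shape)                   # torch.Size([1, 32, 32, 40])
```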

Code Block 5: The IGAB Class (Illumination-Guided Attention Block)

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1, 1, bias=False),
            GELU(),
            nn.Conv2d(dim * mult, dim * mult, 3, 1, 1,
                      bias=False, groups=dim * mult),
            GELU(),
            nn.Conv2d(dim * mult, dim, 1, 1, bias=False),
        )

    def forward(self, x):
        """
        x: [b,h,w,c]
        return out: [b,h,w,c]
        """
        out = self.net(x.permute(0, 3, 1, 2).contiguous())
        return out.permute(0, 2, 3, 1)


class IGAB(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
            num_blocks=2,
    ):
        super().__init__()
        self.blocks = nn.ModuleList([])
        for _ in range(num_blocks):
            self.blocks.append(nn.ModuleList([
                IG_MSA(dim=dim, dim_head=dim_head, heads=heads),
                PreNorm(dim, FeedForward(dim=dim))
            ]))

    def forward(self, x, illu_fea):
        """
        x: [b,c,h,w]
        illu_fea: [b,c,h,w]
        return out: [b,c,h,w]
        """
        x = x.permute(0, 2, 3, 1)
        for (attn, ff) in self.blocks:
            x = attn(x, illu_fea_trans=illu_fea.permute(0, 2, 3, 1)) + x
            x = ff(x) + x
        out = x.permute(0, 3, 1, 2)
        return out
Analysis:

This module combines the IG_MSA and FeedForward components above into a complete, reusable Illumination-Guided Attention Block (IGAB).

  • FeedForward: a standard feed-forward network that applies a non-linear transformation after the attention operation.
  • IGAB __init__: builds a module list self.blocks, where each entry contains an IG_MSA attention module and a FeedForward module wrapped in PreNorm.
  • IGAB forward:
    1. for (attn, ff) in self.blocks: loops over each attention / feed-forward pair.
    2. x = attn(...) + x: runs illumination-guided self-attention and adds the result to the input x. This residual connection is essential for building deep networks and preventing vanishing gradients.
    3. x = ff(x) + x: passes the result through the feed-forward network with another residual connection; a small example follows below.
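
An illustrative run of one IGAB stack (sizes are examples chosen here):

```python
# Both inputs use the [b, c, h, w] layout, and the residual connections keep
# the output shape identical to the input.
igab = IGAB(dim=40, dim_head=10, heads=4, num_blocks=2)
feat = torch.rand(1, 40, 32, 32)
illu = torch.rand(1, 40, 32, 32)
print(igab(feat, illu).shape)  # torch.Size([1, 40, 32, 32])
```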

Code Block 6: The Denoiser Class (U-Net Backbone)

class Denoiser(nn.Module):
    def __init__(self, in_dim=3, out_dim=3, dim=31, level=2, num_blocks=[2, 4, 4]):
        super(Denoiser, self).__init__()
        self.dim = dim
        self.level = level

        # Input projection
        self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)

        # Encoder
        self.encoder_layers = nn.ModuleList([])
        dim_level = dim
        for i in range(level):
            self.encoder_layers.append(nn.ModuleList([
                IGAB(
                    dim=dim_level, num_blocks=num_blocks[i], dim_head=dim, heads=dim_level // dim),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False)
            ]))
            dim_level *= 2

        # Bottleneck
        self.bottleneck = IGAB(
            dim=dim_level, dim_head=dim, heads=dim_level // dim, num_blocks=num_blocks[-1])

        # Decoder
        self.decoder_layers = nn.ModuleList([])
        for i in range(level):
            self.decoder_layers.append(nn.ModuleList([
                nn.ConvTranspose2d(dim_level, dim_level // 2, stride=2,
                                   kernel_size=2, padding=0, output_padding=0),
                nn.Conv2d(dim_level, dim_level // 2, 1, 1, bias=False),
                IGAB(
                    dim=dim_level // 2, num_blocks=num_blocks[level - 1 - i], dim_head=dim,
                    heads=(dim_level // 2) // dim),
            ]))
            dim_level //= 2

        # Output projection
        self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, illu_fea):
        """
        x:          [b,c,h,w]         x is treated as a feature here (the single-stage wrapper passes the lit-up image)
        illu_fea:   [b,c,h,w]
        return out: [b,c,h,w]
        """

        # Embedding
        fea = self.embedding(x)

        # Encoder
        fea_encoder = []
        illu_fea_list = []
        for (IGAB, FeaDownSample, IlluFeaDownsample) in self.encoder_layers:
            fea = IGAB(fea,illu_fea)  # bchw
            illu_fea_list.append(illu_fea)
            fea_encoder.append(fea)
            fea = FeaDownSample(fea)
            illu_fea = IlluFeaDownsample(illu_fea)

        # Bottleneck
        fea = self.bottleneck(fea,illu_fea)

        # Decoder
        for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers):
            fea = FeaUpSample(fea)
            fea = Fution(
                torch.cat([fea, fea_encoder[self.level - 1 - i]], dim=1))
            illu_fea = illu_fea_list[self.level-1-i]
            fea = LeWinBlcok(fea,illu_fea)

        # Mapping
        out = self.mapping(fea) + x

        return out
Analysis:

This class is the main body of the network. It adopts the classic U-Net architecture and performs the bulk of the feature processing and image reconstruction.

  • __init__: defines all the U-Net components:
    • encoder_layers (encoder): IGAB blocks for feature extraction plus stride-2 convolutions for downsampling (a second stride-2 convolution downsamples the illumination feature in parallel).
    • bottleneck: the lowest level of the U-Net, processing features at the lowest resolution.
    • decoder_layers (decoder): transposed convolutions (ConvTranspose2d) for upsampling, 1x1 fusion convolutions for the skip connections, and IGAB blocks to process the fused features.
  • forward:
    1. Encoder: data flows downward; at each level the feature map passes through an IGAB block and is then downsampled. The pre-downsampling feature fea is stored in the fea_encoder list (and the matching illumination feature in illu_fea_list).
    2. Decoder: data flows back up; at each level the feature map is upsampled and then concatenated with the saved encoder feature of the corresponding level (torch.cat([fea, fea_encoder[...]], dim=1)). These are the U-Net skip connections, which are crucial for image reconstruction because they combine shallow detail information with deep semantic information.
    3. out = self.mapping(fea) + x: finally, a convolution maps the features back to the image channel count, and a residual connection with the original input x produces the output; a shape check follows below.
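
A shape check (illustrative; the sizes are examples, and the 40-channel guidance width is an assumption that matches the PCA output used later):

```python
# The guidance feature must have `dim` channels, and the spatial size must be
# divisible by 2**level so the encoder/decoder skip connections line up.
denoiser = Denoiser(in_dim=3, out_dim=3, dim=40, level=2, num_blocks=[1, 1, 1])
lit_up = torch.rand(1, 3, 64, 64)     # lit-up image from the single-stage wrapper
guide = torch.rand(1, 40, 64, 64)     # illumination / fused guidance feature
print(denoiser(lit_up, guide).shape)  # torch.Size([1, 3, 64, 64])
```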

Code Block 7: The RTxNet_Single_Stage and RTxNet Classes

class RTxNet_Single_Stage(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_feat=31, level=2, num_blocks=[1, 1, 1]):
        super(RTxNet_Single_Stage, self).__init__()
        self.estimator_rgb = Illumination_Estimator(n_feat)
        self.estimator_thermal = Illumination_Estimator(n_feat)
        self.cross_attention = IGAB(dim=n_feat, dim_head=n_feat // 4, heads=4, num_blocks=1)
        self.denoiser = Denoiser(in_dim=in_channels, out_dim=out_channels, dim=n_feat, level=level, num_blocks=num_blocks)
    
    def forward(self, img):
        # img:        b,c=3,h,w
        
        # illu_fea:   b,c,h,w
        # illu_map:   b,c=3,h,w
        img_rgb, img_thermal = img
        # Estimate illumination for RGB image
        #img_rgb = img_rgb1.resize(640, 512)
        illu_fea_rgb, illu_map_rgb = self.estimator_rgb(img_rgb)
        #print("Illu_fea_rgb: ", illu_fea_rgb.shape)
        # Estimate illumination for thermal image
        illu_fea_thermal, illu_map_thermal = self.estimator_thermal(img_thermal)
        #print("Illu_fea_thermal: ", illu_fea_thermal.shape)


        #merged_features = torch.cat([illu_fea_rgb, illu_fea_thermal], dim=1)
        #print("Merged Features: ", merged_features.shape)
        #print("RGB Image: ", img_rgb.shape)
        #print("Illu map RGB: ", illu_map_rgb.shape)

        enhanced_fea_rgb = self.cross_attention(illu_fea_rgb, illu_fea_rgb)
        #print("Illu_fea_rgb: ", enhanced_fea_rgb.shape)

        enhanced_fea_thermal = self.cross_attention(illu_fea_thermal, illu_fea_thermal)
        #print("Illu_fea_thermal: ", enhanced_fea_thermal.shape)


        # Merge enhanced features
        merged_features = torch.cat([enhanced_fea_rgb, enhanced_fea_thermal], dim=1)
        #print("Merged Features: ", merged_features.shape)

        batch_size, channels, height, width = merged_features.size()

        # # Reshape to 2D array for PCA
        # flattened_features = merged_features.permute(0, 2, 3, 1).contiguous().view(-1, channels)
        desired_channels  = 40
        # # Apply PCA along the channel dimension
        # pca = PCA(n_components=desired_channels)
        # reduced_features_np = pca.fit_transform(flattened_features.cpu().detach().numpy())

        # # Reshape reduced features back to original shape
        # reduced_features = torch.from_numpy(reduced_features_np).view(batch_size, height, width, -1).permute(0, 3, 1, 2)
        flattened_features = merged_features.view(batch_size, channels, -1)
        flattened_features = flattened_features.permute(0, 2, 1).contiguous().view(-1, channels)

        # Apply PCA
        pca = PCA(n_components=desired_channels)
        reduced_features_np = pca.fit_transform(flattened_features.detach().cpu().numpy())
        reduced_features = torch.from_numpy(reduced_features_np).view(batch_size, -1, desired_channels).permute(0, 2, 1).contiguous().view(batch_size, desired_channels, height, width)
        
        #print("Reduced Features: ", reduced_features.shape)
        reduced_features = reduced_features.cuda()

        input_img = img_rgb * illu_map_rgb + img_rgb
        output_img = self.denoiser(input_img, reduced_features)

        return output_img


class RTxNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_feat=31, stage=3, num_blocks=[1,1,1]):
        super(RTxNet, self).__init__()
        self.stage = stage

        modules_body = [RTxNet_Single_Stage(in_channels=in_channels, out_channels=out_channels, n_feat=n_feat, level=2, num_blocks=num_blocks)
                        for _ in range(stage)]
        
        self.body = nn.Sequential(*modules_body)
    
    def forward(self, x):
        """
        x: [b,c,h,w]
        return out:[b,c,h,w]
        """
        out = self.body(x)

        return out
Analysis:

This part assembles all the components into the final, runnable model.

  • RTxNet_Single_Stage: defines the single-stage pipeline and contains the core logic of the network.
    • __init__: instantiates every required module: two Illumination_Estimators (one for RGB, one for thermal), one IGAB module, and a Denoiser backbone.
    • forward:
      1. illu_fea_rgb, ... = self.estimator_rgb(img_rgb): extracts illumination features and illumination maps from the RGB and thermal images separately.
      2. enhanced_fea_rgb = self.cross_attention(illu_fea_rgb, illu_fea_rgb): a key implementation detail. Although the module variable is named cross_attention, the call here performs self-attention on each modality independently, since the query, key, and value all come from the same source (illu_fea_rgb or illu_fea_thermal). This differs from the cross-attention mechanism described in Section 3.3 of the paper, where the query comes from one modality and the key/value from the other.
      3. merged_features = torch.cat(...): concatenates the two modalities' attention-refined features along the channel dimension.
      4. PCA: reduces the concatenated high-dimensional features with sklearn.decomposition.PCA. This is an unusual design choice: it inserts a CPU-based, data-driven dimensionality-reduction step inside the network to control the feature width and extract the most informative mixed representation. Note that desired_channels is hard-coded to 40, so the reduced feature only matches the Denoiser's expected width when n_feat is set to 40.
      5. input_img = img_rgb * illu_map_rgb + img_rgb: following Retinex theory, builds the main input for the Denoiser: the RGB image preliminarily brightened by its illumination map. This corresponds to the "lit-up image" in the paper.
      6. output_img = self.denoiser(input_img, reduced_features): calls the U-Net backbone with the preliminarily enhanced image and the PCA-reduced fused feature as guidance to produce the final output.
  • RTxNet:
    • The top-level model wrapper. It chains several RTxNet_Single_Stage modules with nn.Sequential to build a multi-stage architecture, so the image can be refined iteratively through repeated passes of the core network; a minimal single-stage smoke test is sketched below.
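
A minimal single-stage smoke test (a sketch, not from the repo). It assumes a CUDA device because the forward pass hard-codes reduced_features.cuda(), sets n_feat=40 so the PCA output (desired_channels=40) matches the Denoiser's feature width, and assumes a thermal frame replicated to 3 channels, as the estimator expects:

```python
# Single-stage smoke test (illustrative shapes and hyper-parameters).
if torch.cuda.is_available():
    stage = RTxNet_Single_Stage(in_channels=3, out_channels=3, n_feat=40,
                                level=2, num_blocks=[1, 1, 1]).cuda()
    img_rgb = torch.rand(1, 3, 64, 64).cuda()       # low-light RGB frame
    img_thermal = torch.rand(1, 3, 64, 64).cuda()   # thermal frame (replicated to 3 channels)
    with torch.no_grad():
        out = stage((img_rgb, img_thermal))
    print(out.shape)  # torch.Size([1, 3, 64, 64])
```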
