基于Swin_Transformer的图像超分辨率系统

1.研究背景与意义

项目参考AAAI Association for the Advancement of Artificial Intelligence

研究背景与意义

随着科技的不断发展,图像超分辨率技术在计算机视觉领域中变得越来越重要。图像超分辨率是指通过使用计算机算法将低分辨率图像转换为高分辨率图像的过程。这项技术在许多领域都有广泛的应用,包括医学图像处理、监控摄像头、卫星图像处理等。

在过去的几十年里,图像超分辨率技术已经取得了显著的进展。早期的方法主要基于插值和滤波技术,但这些方法无法捕捉到图像中的细节和纹理。随着深度学习的兴起,基于深度学习的图像超分辨率方法开始受到关注。其中,基于卷积神经网络(CNN)的方法取得了很大的成功。

然而,传统的CNN方法在处理大尺寸图像时存在一些问题。首先,它们需要大量的计算资源和存储空间,限制了它们在实际应用中的可行性。其次,它们往往无法处理大尺寸图像中的细节和纹理,导致生成的高分辨率图像质量不佳。因此,寻找一种高效且准确的图像超分辨率方法是非常重要的。

近年来,Swin Transformer作为一种新兴的注意力机制模型,已经在自然语言处理和计算机视觉领域取得了显著的成果。Swin Transformer采用了一种分层的注意力机制,能够在处理大尺寸图像时保持较高的效率和准确性。因此,将Swin Transformer应用于图像超分辨率任务是非常有前景的研究方向。

基于Swin Transformer的图像超分辨率系统具有以下几个重要的意义:

首先,基于Swin Transformer的图像超分辨率系统可以提供更高质量的高分辨率图像。Swin Transformer的注意力机制能够更好地捕捉到图像中的细节和纹理,从而生成更加真实和清晰的图像。这对于许多应用领域,如医学图像处理和卫星图像处理,具有重要的意义。

其次,基于Swin Transformer的图像超分辨率系统可以提高计算效率。传统的CNN方法在处理大尺寸图像时需要大量的计算资源和存储空间,限制了它们在实际应用中的可行性。而Swin Transformer采用了一种分层的注意力机制,能够在处理大尺寸图像时保持较高的效率和准确性,从而降低了计算成本。

最后,基于Swin Transformer的图像超分辨率系统可以为其他相关领域的研究提供借鉴和参考。Swin Transformer作为一种新兴的注意力机制模型,已经在自然语言处理和计算机视觉领域取得了显著的成果。将其应用于图像超分辨率任务可以为其他领域的研究提供新的思路和方法。

综上所述,基于Swin Transformer的图像超分辨率系统具有重要的研究背景和意义。它可以提供更高质量的高分辨率图像,提高计算效率,并为其他相关领域的研究提供借鉴和参考。随着深度学习和注意力机制的不断发展,相信基于Swin Transformer的图像超分辨率系统将在未来取得更加广泛的应用和研究进展。

2.图片演示



3.视频演示

基于Swin_Transformer的图像超分辨率系统_哔哩哔哩_bilibili

4.图像超分辨重建原理

为了对图像超分辨率重建原理有更深入的理解,本小节将对高分辨率图像到低分辨率的退化过程进行详细介绍。受硬件设备的限制、环境因素的干扰和传输条件的限制,人们采集所得的实际图像的分辨率往往很难达到预期,而这些低分辨率图像通常由高分辨率图像经过多种退化过程所产生,包括光线干扰、运动模糊、噪声、压缩等退化因素。由于图像超分辨率重建是一个典型的逆向问题,其核心概念是建立对应的退化模型来学习从高分辨率图像到低分辨率图像的退化关系,进一步恢复低分辨率图像的纹理细节,因此,建立合适的退化模型是解决超分辨率重建问题的关键。
基于Swin_Transformer的图像超分辨率系统_第1张图片

通过对上述图像退化模型的分析,可知通常的图像退化过程可以描述为高分辨率图像历经一系列的退化因素的影响,产生模糊甚至失真的低分辨率图像。假设x为原始高分辨率图像,J为退化后的低分辨率图像,则图像退化模型可以表示为:
在这里插入图片描述

其中,H()表示整个退化过程,D()表示下采样操作,B()表示模糊操作,L()表示光线干扰,n表示随机噪声,一般为高斯噪声或泊松噪声。图像超分辨率重建的本质就是从岁反向求解x的过程,如式所示:
在这里插入图片描述

即构造相应的图像恢复函数H-(-),对低分辨率图像y进行逆向推算,去除由于软硬件技术、环境因素和人为因素所带来的模糊、下采样和噪声等影响,尽可能恢复出原始高分辨率图像x。

5.核心代码讲解

5.1 main_test_swin2sr.py

class Swin2SR:
    def __init__(self, args):
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.define_model()
        self.model.eval()
        self.model = self.model.to(self.device)
        self.test_results = OrderedDict()
        self.test_results['psnr'] = []
        self.test_results['ssim'] = []
        self.test_results['psnr_y'] = []
        self.test_results['ssim_y'] = []
        self.test_results['psnrb'] = []
        self.test_results['psnrb_y'] = []
        self.psnr, self.ssim, self.psnr_y, self.ssim_y, self.psnrb, self.psnrb_y = 0, 0, 0, 0, 0, 0

    def define_model(self):
        # 001 classical image sr
        if self.args.task == 'classical_sr':
            model = net(upscale=self.args.scale, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        num_classes=3, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': self.args.scale})
        # 002 lightweight image sr
        elif self.args.task == 'lightweight_sr':
            model = net(upscale=self.args.scale, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        num_classes=3, embed_dim=48, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': self.args.scale})
        # 003 real image sr
        elif self.args.task == 'real_sr':
            if self.args.large_model:
                model = net(upscale=self.args.scale, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                            num_classes=3, embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24],
                            mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': self.args.scale})
            else:
                model = net(upscale=self.args.scale, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                            num_classes=3, embed_dim=48, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24],
                            mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': self.args.scale})
        # 004 grayscale denoising
        elif self.args.task == 'gray_dn':
            model = net(upscale=1, in_chans=1, img_size=self.args.training_patch_size, window_size=8,
                        num_classes=1, embed_dim=48, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': 1})
        # 005 color denoising
        elif self.args.task == 'color_dn':
            model = net(upscale=1, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        num_classes=3, embed_dim=48, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': 1})
        # 006 jpeg compression artifact reduction
        elif self.args.task == 'jpeg_car':
            model = net(upscale=1, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        num_classes=3, embed_dim=48, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': 1})
        # 007 color jpeg compression artifact reduction
        elif self.args.task == 'color_jpeg_car':
            model = net(upscale=1, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        num_classes=3, embed_dim=48, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        mlp_ratio=4, upsampler='pixelshuffle', upsampler_params={'scale': 1})
        else:
            raise NotImplementedError(f'Task [{self.args.task}] is not implemented.')

        return model

    def setup(self):
        folder, save_dir, border, window_size = self.args.folder_lq, './outputs/', 0, self.args.training_patch_size
        return folder, save_dir, border, window_size

    def get_image_pair(self, path):
        imgname = os.path.splitext(os.path.basename(path))[0]
        img_lq = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        img_gt = None
        if self.args.folder_gt is not None:
            img_gt = cv2.imread(os.path.join(self.args.folder_gt, f'{imgname}.png'), cv2.IMREAD_UNCHANGED)
        return imgname, img_lq, img_gt

    def test(self, img_lq, model, args, window_size):
        _, _, h_old, w_old = img_lq.size()
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
        output = model(img_lq)
        if args.task == 'compressed_sr':
            output = output[0][..., :h_old * args.scale, :w_old * args.scale]
        else:
            output = output[..., :h_old * args.scale, :w_old * args.scale]
        return output

    def evaluate(self, output, img_gt, border):
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HCW-BGR
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        cv2.imwrite(f'{save_dir}/{imgname}_Swin2SR.png', output)

        if img_gt is not None:
            img_gt = (img_gt * 255.0).round().astype(np.uint8)  # float32 to uint8
            img_gt = img_gt[:h_old * args.scale, :w_old * args.scale, ...]  # crop gt
            img_gt = np.squeeze(img_gt)

            psnr = util.calculate_psnr(output, img_gt, crop_border=border)
            ssim = util.calculate_ssim(output, img_gt, crop_border=border)
            self.test_results['psnr'].append(psnr)
            self.test_results['ssim'].append(ssim)
            if img_gt.ndim == 3:  # RGB image
                psnr_y = util.calculate_psnr(output, img_gt, crop_border=border, test_y_channel=True)
                ssim_y = util.calculate_ssim(output, img_gt, crop_border=border, test_y_channel=True)
                self.test_results['psnr_y'].append(psnr_y)
                self.test_results['ssim_y'].append(ssim_y)
            if args.task in ['jpeg_car', 'color_jpeg_car']:
                psnrb = util.calculate_psnrb(output, img_gt, crop_border=border, test_y_channel=False)
                self.test_results['psnrb'].append(psnrb)
                if args.task in ['color_jpeg_car']:
                    psnrb_y = util.calculate_psnrb(output, img_gt, crop_border=border, test_y_channel=True)
                    self.test_results['psnrb_y'].append(psnrb_y)
            print('Testing {:d} {:20s} - PSNR: {:.2f} dB; SSIM: {:.4f}; PSNRB: {:.2f} dB;'

该程序文件是一个用于图像超分辨率重建的测试程序。程序首先通过命令行参数解析器解析输入参数,包括任务类型、尺度因子、噪声水平、JPEG压缩因子等。然后加载模型并设置设备。接下来,程序设置文件夹和路径,并创建一个用于保存结果的文件夹。然后,程序遍历输入文件夹中的所有图像,读取图像并进行预处理。然后,程序使用模型对图像进行推理,并将结果保存为图像文件。最后,程序计算并打印出PSNR和SSIM等评估指标的平均值。

该程序文件依赖于其他模块和函数,包括argparse、cv2、glob、numpy、collections、os、torch、requests等。其中,models.network_swin2sr模块定义了Swin2SR模型,utils模块包含了计算PSNR和SSIM的函数。

总体而言,该程序文件实现了图像超分辨率重建的测试功能,包括加载模型、预处理图像、进行推理、保存结果和计算评估指标等步骤。

5.2 predict.py

class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        print("Loading pipeline...")

        self.device = "cuda:0"

        args = argparse.Namespace()
        args.scale = 4
        args.large_model = False

        tasks = ["classical_sr", "compressed_sr", "real_sr"]
        paths = [
            "weights/Swin2SR_ClassicalSR_X4_64.pth",
            "weights/Swin2SR_CompressedSR_X4_48.pth",
            "weights/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pth",
        ]
        sizes = [64, 48, 128]

        self.models = {}
        for task, path, size in zip(tasks, paths, sizes):
            args.training_patch_size = size
            args.task, args.model_path = task, path
            self.models[task] = define_model(args)
            self.models[task].eval()
            self.models[task] = self.models[task].to(self.device)

    def predict(
        self,
        image: Path = Input(description="Input image"),
        task: str = Input(
            description="Choose a task",
            choices=["classical_sr", "real_sr", "compressed_sr"],
            default="real_sr",
        ),
    ) -> Path:
        """Run a single prediction on the model"""

        model = self.models[task]

        window_size = 8
        scale = 4

        img_lq = cv2.imread(str(image), cv2.IMREAD_COLOR).astype(np.float32) / 255.0
        img_lq = np.transpose(
            img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1)
        )  # HCW-BGR to CHW-RGB
        img_lq = (
            torch.from_numpy(img_lq).float().unsqueeze(0).to(self.device)
        )  # CHW-RGB to NCHW-RGB

        # inference
        with torch.no_grad():
            # pad input image to be a multiple of window_size
            _, _, h_old, w_old = img_lq.size()
            h_pad = (h_old // window_size + 1) * window_size - h_old
            w_pad = (w_old // window_size + 1) * window_size - w_old
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[
                :, :, : h_old + h_pad, :
            ]
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[
                :, :, :, : w_old + w_pad
            ]

            output = model(img_lq)

            if task == "compressed_sr":
                output = output[0][..., : h_old * scale, : w_old * scale]
            else:
                output = output[..., : h_old * scale, : w_old * scale]

        # save image
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            output = np.transpose(
                output[[2, 1, 0], :, :], (1, 2, 0)
            )  # CHW-RGB to HCW-BGR
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        output_path = "/tmp/out.png"
        cv2.imwrite(output_path, output)

        return Path(output_path)

这个程序文件是一个用于图像超分辨率预测的预测器。它使用了Swin2SR模型来进行预测。文件中定义了一个名为Predictor的类,继承自BasePredictor类。在setup方法中,加载了模型并将其放入内存中以提高多次预测的效率。在predict方法中,通过传入一个输入图像和一个任务类型,可以运行单个预测。预测过程中,首先将输入图像进行预处理,然后使用模型进行推理,最后将输出图像保存到指定路径并返回。

5.3 ui.py


class Swin2SR:
    def __init__(self, args):
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.define_model()
        self.model.eval()
        self.model = self.model.to(self.device)
    
    def define_model(self):
        if self.args.task == 'classical_sr':
            model = net(upscale=self.args.scale, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                        mlp_ratio=2, upsampler='pixelshuffle', resi_connection='1conv')
            param_key_g = 'params'
        elif self.args.task in ['lightweight_sr']:
            model = net(upscale=self.args.scale, in_chans=3, img_size=64, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
                        mlp_ratio=2, upsampler='pixelshuffledirect', resi_connection='1conv')
            param_key_g = 'params'
        elif self.args.task == 'compressed_sr':
            model = net(upscale=self.args.scale, in_chans=3, img_size=self.args.training_patch_size, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                        mlp_ratio=2, upsampler='pixelshuffle_aux', resi_connection='1conv')
            param_key_g = 'params'
        elif self.args.task == 'real_sr':
            if not self.args.large_model:
                model = net(upscale=self.args.scale, in_chans=3, img_size=64, window_size=8,
                            img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                            mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv')
            else:
                model = net(upscale=self.args.scale, in_chans=3, img_size=64, window_size=8,
                            img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
                            num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
                            mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
            param_key_g = 'params_ema'
        elif self.args.task == 'jpeg_car':
            model = net(upscale=1, in_chans=1, img_size=126, window_size=7,
                        img_range=255., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                        mlp_ratio=2, upsampler='', resi_connection='1conv')
            param_key_g = 'params'
        elif self.args.task == 'color_jpeg_car':
            model = net(upscale=1, in_chans=3, img_size=126, window_size=7,
                        img_range=255., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                        mlp_ratio=2, upsampler='', resi_connection='1conv')
            param_key_g = 'params'
        pretrained_model = torch.load(self.args.model_path)
        model.load_state_dict(pretrained_model[param_key_g] if param_key_g in pretrained_model.keys() else pretrained_model,
                              strict=True)
        return model
    
    def setup(self):
        if self.args.task in ['classical_sr', 'lightweight_sr', 'compressed_sr']:
            save_dir = f'results/swin2sr_{self.args.task}_x{self.args.scale}'
            if self.args.save_img_only:
                folder = self.args.folder_lq
            else:
                folder = self.args.folder_gt
            border = self.args.scale
            window_size = 8
       

ui.py是一个用于图像超分辨率的PyQt5界面程序。它导入了PyQt5和其他一些必要的库,并定义了一些函数来加载模型、设置参数、获取图像对和进行测试。主要的函数是main()函数,它接受一个图像路径作为输入,并根据指定的参数加载模型并对图像进行超分辨率处理。处理结果将保存在指定的文件夹中。

5.4 models\network_swin2sr.py
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        pretrained_window_size (int): Window size in pre-training.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
            pretrained_window_size=to_2tuple(pretrained_window_size))

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, N, C).
        """
        B, N, C = x.shape
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, N, C)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        x = self.attn(x, mask=self.attn_mask)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

这是一个实现Swin Transformer模型的Python程序文件。Swin Transformer是一种用于压缩图像超分辨率和恢复的模型,具体细节可以参考论文https://arxiv.org/abs/2209.11345。

程序文件中定义了一些辅助函数和模块,包括Mlp、window_partition、window_reverse、WindowAttention和SwinTransformerBlock。

Mlp是一个多层感知机模块,用于对输入进行线性变换和激活函数处理。

window_partition和window_reverse函数用于将输入图像划分为窗口,并将窗口恢复为原始图像。

WindowAttention是一个基于窗口的多头自注意力模块,支持相对位置偏置。

SwinTransformerBlock是Swin Transformer的一个基本模块,包括窗口注意力和多层感知机。

整个程序文件实现了Swin Transformer模型的核心组件,可以用于图像超分辨率和恢复任务。

5.5 utils\plots.py


class ImageLoader:
    def __init__(self, debug=False, norm=True, resize=None):
        self.debug = debug
        self.norm = norm
        self.resize = resize
    
    def load_img(self, filename):
        img = cv2.imread(filename)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.norm:   
            img = img / 255.
            img = img.astype(np.float32)
        if self.debug:
            print (img.shape, img.dtype, img.min(), img.max())
            
        if self.resize:
            img = cv2.resize(img, (self.resize[0], self.resize[1]))
            
        return img
    
    def plot_all(self, images, axis='off', figsize=(16, 8)):
        fig = plt.figure(figsize=figsize, dpi=80)
        nplots = len(images)
        for i in range(nplots):
            plt.subplot(1,nplots,i+1)
            plt.axis(axis)
            plt.imshow(images[i])
        plt.show()

这个程序文件是一个用于绘制图像的工具文件,文件名为utils\plots.py。该文件包含了两个函数load_img和plot_all。

load_img函数用于加载图像文件。它接受一个文件名作为参数,并可选择是否进行调试、归一化和调整大小。函数首先使用OpenCV库的imread函数读取图像文件,然后将图像从BGR颜色空间转换为RGB颜色空间。如果选择进行归一化,则将图像的像素值除以255,并将其转换为32位浮点数类型。如果选择进行调试,则会打印图像的形状、数据类型、最小值和最大值。如果选择调整大小,则会使用OpenCV库的resize函数将图像调整为指定的大小。最后,函数返回加载和处理后的图像。

plot_all函数用于绘制多个图像。它接受一个图像列表作为参数,并可选择绘制轴的样式和图像的大小。函数首先创建一个matplotlib的Figure对象,并设置其大小和分辨率。然后,根据图像列表的长度,在Figure对象中创建相应数量的子图。对于每个子图,设置轴的样式,并使用imshow函数显示对应的图像。最后,调用show函数显示绘制的图像。

这个程序文件提供了方便的函数来加载和绘制图像,可以在图像处理和分析的过程中使用。

5.6 utils\util_calculate_psnr_ssim.py
import cv2
import torch
import numpy as np

class ImageMetrics:
    def __init__(self, input_order='HWC'):
        self.input_order = input_order

    def calculate_psnr(self, img1, img2, crop_border, test_y_channel=False):
        assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
        if self.input_order not in ['HWC', 'CHW']:
            raise ValueError(f'Wrong input_order {self.input_order}. Supported input_orders are ' '"HWC" and "CHW"')
        img1 = self.reorder_image(img1)
        img2 = self.reorder_image(img2)
        img1 = img1.astype(np.float64)
        img2 = img2.astype(np.float64)

        if crop_border != 0:
            img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
            img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

        if test_y_channel:
            img1 = self.to_y_channel(img1)
            img2 = self.to_y_channel(img2)

        mse = np.mean((img1 - img2) ** 2)
        if mse == 0:
            return float('inf')
        return 20. * np.log10(255. / np.sqrt(mse))

    ......

该程序文件是一个用于计算图像质量评估指标的工具文件。主要包含以下几个函数:

  1. calculate_psnr(img1, img2, crop_border, input_order=‘HWC’, test_y_channel=False):计算图像的峰值信噪比(PSNR)指标。

  2. _ssim(img1, img2):计算图像的结构相似性(SSIM)指标。

  3. calculate_ssim(img1, img2, crop_border, input_order=‘HWC’, test_y_channel=False):计算图像的结构相似性(SSIM)指标。

  4. _blocking_effect_factor(im):计算图像的块效应因子。

  5. calculate_psnrb(img1, img2, crop_border, input_order=‘HWC’, test_y_channel=False):计算图像的PSNR-B指标。

  6. reorder_image(img, input_order=‘HWC’):重新排列图像的通道顺序。

  7. to_y_channel(img):将图像转换为Y通道。

  8. bgr2ycbcr(img, y_only=False):将BGR图像转换为YCbCr图像。

这些函数可以用于评估图像处理算法的

6.系统整体结构

整体功能和构架概述:

该图像超分辨率系统的整体功能是实现图像的超分辨率重建。它使用了基于Swin Transformer的模型进行图像超分辨率处理。系统包含了多个程序文件,每个文件负责不同的功能模块。主要的程序文件包括:

  1. main_test_swin2sr.py:用于图像超分辨率重建的测试程序。负责加载模型、预处理图像、进行推理、保存结果和计算评估指标等步骤。

  2. predict.py:图像超分辨率预测的预测器。使用Swin2SR模型进行预测,定义了Predictor类,负责加载模型并进行预测。

  3. ui.py:图像超分辨率的PyQt5界面程序。通过界面输入图像路径和参数,调用predict.py中的函数进行超分辨率处理。

  4. models\network_swin2sr.py:实现Swin Transformer模型的核心组件。定义了多个辅助函数和模块,包括Mlp、WindowAttention和SwinTransformerBlock等。

  5. utils\plots.py:用于绘制图像的工具文件。包含load_img和plot_all函数,用于加载和绘制图像。

  6. utils\util_calculate_psnr_ssim.py:用于计算图像质量评估指标的工具文件。包含多个函数,用于计算PSNR、SSIM和其他指标。

  7. utils_init_.py:空文件,用于标识utils文件夹为Python模块。

下表整理了每个文件的功能:

文件路径 功能
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\main_test_swin2sr.py 图像超分辨率重建的测试程序,包括加载模型、预处理图像、进行推理、保存结果和计算评估指标等步骤
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\predict.py 图像超分辨率预测的预测器,定义了Predictor类,负责加载模型并进行预测
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\ui.py 图像超分辨率的PyQt5界面程序,通过界面输入图像路径和参数,调用predict.py中的函数进行超分辨率处理
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\models\network_swin2sr.py 实现Swin Transformer模型的核心组件,包括辅助函数和模块,如Mlp、WindowAttention和SwinTransformerBlock
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\utils\plots.py 用于绘制图像的工具文件,包含load_img和plot_all函数,用于加载和绘制图像
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\utils\util_calculate_psnr_ssim.py 用于计算图像质量评估指标的工具文件,包含多个函数,用于计算PSNR、SSIM和其他指标
E:\视觉项目\shop\基于Swin_Transformer的图像超分辨率系统\code\utils_init_.py 空文件,用于标识utils文件夹为Python模块

7.Swin_Transformer用于超分辨率重建

参考该博客提出的RefSR工作,主要观点是将Transformer作为一个attention,这样可以更好地将参考图像(Ref)的纹理信息转移到高质图像(HR)中。做法还是比较有意思的,如下图所示,将上采样的LR图像、依次向下/上采样的Ref图像、原始Ref图像中提取的纹理特征分别作为Q、K、V。纹理Transformer包含了4个结构:1)DNN实现的可学习的纹理提取器(learnable texture extractor)2)相关性嵌入模块( relevance embedding)3)用于纹理转换的硬注意力模块(hard-attention)4)用于纹理合成的软注意力模块(soft-attention)。此外整个纹理Transformer模块可以跨尺度的方式进一步堆叠,这使得能够从不同尺度(例如,从1x倍到4x倍放大率)恢复纹理。

基于Swin_Transformer的图像超分辨率系统_第2张图片

网络的整体架构

如下图所示,将多个纹理Transformer(即上图)堆叠、上采下采融合来实现超分。
其中RBS为多个残差Block,CSFI为跨尺度特征集成模块(ross-scale feature integration )

纹理Transformer

基于Swin_Transformer的图像超分辨率系统_第3张图片

即图,介绍一下他的四个组件。
1)DNN实现的可学习的纹理提取器。就是将图像送入DNN,然后DNN可以训练
2)相关性嵌入模块。使用归一化内积计算Q、K之间的相关性。获得矩阵r i , j r_{i,j}r
3)硬注意力。通过h i = a r g m a x ( r i , j ) h_{i}=argmax(r_{i,j})h
4)软注意力。获得软注意力图s i = a r g m a x ( r i , j ) s_{i}=argmax(r_{i,j})s

再分析一下这个公式,当S大的时候,说明当前块和T的相关性大,所以用更多的T的特征,如果S小,则使用更少的参考帧特征。
在这里插入图片描述

损失函数

L1 loss + GAN loss + Percepture Loss

网络结构

基于Swin_Transformer的图像超分辨率系统_第4张图片
1)Shallow Feature Extraction 为一层3x3卷积。
2)HQ Image Reconstruction在SR任务中采用sub-pixel Conv,就是unpixelShuffle。denoise和JPEG去伪影用一层卷积。
3)对STL,就是Transformer的Encoder结构。将输入划分为M ∗ M M*MM∗M个块X,然后每个X映射为QKV,通过多头attention后将输出concat。MLP通过两层FC实现。作者还进行了划窗来避免图像块之间的信息不融合问题。步长为M / 2 M/2M/2

EMHA

主要是在获得QKV之后,将QKV特征分为s组,每组分别进行attention获得输出O,然后将输出Concat,这样可以将大矩阵相乘拆分为多个小矩阵相乘。这也是Transformer常见的减少参数操作。
在这里插入图片描述

HFM

此外该博客的作者还用了一个High-frequencyFiltering Module (HFM)提取高频信息,结构如下,仅供参考。

基于Swin_Transformer的图像超分辨率系统_第5张图片

Microsoft Bing Turing ISR(T-ISR)

Introducing Turing Image Super Resolution: AI powered image enhancements for Microsoft Edge and Bing Maps
这篇不算论文,是微软介绍自家用于Microsoft Edge和Bing Maps上ISR的技术博客。但是效果非常Amazing啊,但缺点是有些地方没有仔细介绍。

设计原则

1)人类视觉为基准(Human eyes as the north star)
广泛使用的指标如PSNR,SSIM并不总是和人眼视觉的直观感受匹配的,同时也需要GT图。我们构建了一个并行评估工具匹配人眼判断,并将这个工具作为north star metric来引导模型训练。(可是作者没介绍这个工具是啥55555)
2)噪声建模(Noise modeling)
开始作者也是将HR图像降质然后构建HR-LR图相对训练。但这样有些case效果好,但是对真实的LR图像不鲁棒。因此随机对输入图像用blurring, compression 和 gaussian noise进行破坏可以恢复细节。
3)Perceptual and GAN loss
仅pixel loss不够,要引入感知和GAN loss,并用权重结合。
4)Transformers for vision
CNN和Transformer各有优缺点,因此未利用他们各自优点,将网络分为Enhance和Zoom,前者使用Transformer,后者使用CNN。(其实这段也没详细介绍各自优缺点是什么。整体四准则很对我胃口啊,果然英雄所见略同hhhh)

DeepEnhance – Cleaning and Enhancing Images

在处理高度压缩和从远程卫星拍摄的航拍照片等very noise图像时,Transformer清理噪声做的很好。如人脸的噪声和处理包含很多纹理的森林的特征就很不同。这是因为大数据集和Transformer卓越的远程记忆能力。我们先使用了一个稀疏Transformer,将其放大以支持非常大的序列长度来“Enhance”图像,产生干净的,crisper和更具吸引力,尺寸相同的图像。有些场景不需要放大图像,那到这里就可以停止了。

基于Swin_Transformer的图像超分辨率系统_第6张图片

8.系统整合

下图完整源码&环境部署视频教程&自定义UI界面

基于Swin_Transformer的图像超分辨率系统_第7张图片

参考博客《基于Swin_Transformer的图像超分辨率系统》

9.参考文献


[1]盘展鸿,朱鉴,迟小羽,等.基于特征融合和注意力机制的图像超分辨率模型[J].计算机应用研究.2022,39(3).DOI:10.19734/j.issn.1001-3695.2021.07.0288 .

[2]邓焱文.基于深度学习的超分辨率重建在人脸识别中的应用[D].2019.

[3]Yu-Qi Liu,Xin Du,Hui-Liang Shen,等.Estimating Generalized Gaussian Blur Kernels for Out-of-Focus Image Deblurring[J].IEEE Transactions on Circuits & Systems for Video Technology.2020,31(3).829-843.DOI:10.1109/TCSVT.2020.2990623 .

[4]Shengxiang Zhang,Gaobo Liang,Shuwan Pan,等.A Fast Medical Image Super Resolution Method Based on Deep Learning Network[J].IEEE Access.2018.712319-12327.DOI:10.1109/ACCESS.2018.2871626 .

[5]Huihui Song,Qingshan Liu,Guojie Wang,等.Spatiotemporal Satellite Image Fusion Using Deep Convolutional Neural Networks[J].IEEE journal of selected topics in applied earth observations & remote sensing.2018,11(3).821-829.DOI:10.1109/JSTARS.2018.2797894 .

[6]Park, S.,Serpedin, E.,Qaraqe, K..Gaussian Assumption: The Least Favorable but the Most Useful [Lecture Notes][J].IEEE Signal Processing Magazine.2013,30(3).183-186.

[7]Mittal, A.,Soundararajan, R.,Bovik, A.C..Making a “Completely Blind” Image Quality Analyzer[J].Signal Processing Letters, IEEE.2013,20(3).209-212.DOI:10.1109/LSP.2012.2227726 .

[8]Ogawa, T.,Haseyama, M..Missing Intensity Interpolation Using a Kernel PCA-Based POCS Algorithm and its Applications[J].IEEE Transactions on Image Processing.2011,20(2).

[9]Yang, J.Wright, J.Huang, T.Ma, Y..Image Super-Resolution Via Sparse Representation[J].IEEE Transactions on Image Processing.2010,19(11).2861-2873.

[10]Bovik A.C.,Zhou Wang,Simoncelli E.P.,等.Image quality assessment: from error visibility to structural similarity[J].IEEE Transactions on Image Processing.2004,13(4).

你可能感兴趣的:(transformer,深度学习,人工智能)