【论文+代码】ZS-N2N实现小样本零网络图像去噪

01、引言

本文方法源于Youssef Mansour和Reinhard Heckel撰写的论文《Zero-Shot Noise2Noise: Efficient Image Denoising without any Data》,该文作者探索了一种不需要任何数据且高效的高效图像去噪方法。

该方法使用两个固定的内核对噪声图像进行卷积,以创建一对降采样的图像。然后用一致性损失训练一个简单的2层CNN,将一个下采样的图像映射到另一个。

本文涵盖主题:方法介绍、代码复现两个主题。

本文主要结果汇总仪表板:

【论文+代码】ZS-N2N实现小样本零网络图像去噪_第1张图片

本期内容『数据+代码』已上传百度网盘。

有需要的朋友可以关注公众号【小Z的科研日常】,后台回复关键词[图像去噪]获取

02、简要介绍

ZS-N2N使用简单的2层网络,在没有任何训练数据或噪声分布知识的情况下,可以以低计算成本实现高质量的图像去噪,对像素独立噪声的去噪效果良好。适合于数据稀缺和计算资源有限的情况下使用。

该方法只对噪声统计进行了最小的假设(像素级的独立性),并且不需要训练数据。不需要明确的噪声模型,因此适用于各种噪声类型,并且可以在噪声分布或水平未知的情况下使用。

对噪声的唯一假设是,它是无结构的,而且平均值为零。

该论文提出的方法是无需数据集和噪声模型,与现有的方法相比,在泛化、去噪质量和计算资源之间实现了更好的权衡。

【论文+代码】ZS-N2N实现小样本零网络图像去噪_第2张图片

如上图所示。我们与标准的零拍基线,包括BM3D,以及最近的基于神经网络的算法DIP [UVL18]和S2S [Qua+20]进行比较。

只有BM3D比我们的方法快,但在非高斯噪声上取得的效果很差。

只有S2S有时优于我们的方法,但速度要慢几个数量级,在低噪声水平上经常失败[KLS22],并且需要集合才能达到可接受的性能。

03、代码复现

首先,我们加载一张测试图像并展示:

img_path = "C:/Users/asus/Desktop/Z/Kodak24/kodim07.png"
clean_img = Image.open(img_path)
clean_img = clean_img.convert("RGB")
clean_img = torch.from_numpy(np.array(clean_img)).permute(2, 0, 1).unsqueeze(0).float() / 255.0

#展示加载的测试图像
show_img = clean_img.squeeze(0).permute((1, 2, 0)).numpy()
plt.imshow(show_img)
plt.axis('off')
plt.title('Original Picture')
plt.show()

原图如下:

【论文+代码】ZS-N2N实现小样本零网络图像去噪_第3张图片

可以看到原图是一张清晰的彩色图像。接下来,我们添加噪声:

noise_type = 'gauss'  # 可选择'gauss' 或 'poiss'
noise_level = 50  # 高斯的像素范围是0-255,裂变的像素范围是0-1。

def add_noise(x, noise_level):
    if noise_type == 'gauss':
        noisy = x + torch.normal(0, noise_level / 255, x.shape)
        noisy = torch.clamp(noisy, 0, 1)

    elif noise_type == 'poiss':
        noisy = torch.poisson(noise_level * x) / noise_level

    return noisy

noisy_img = add_noise(clean_img, noise_level)

# 展示添加噪声后的图像
show_noisy_img = noisy_img.squeeze(0).permute((1, 2, 0)).numpy()
plt.imshow(show_noisy_img)
plt.axis('off')
plt.title(f'{noise_type.capitalize()} Noise ({noise_level})')
plt.show()

可以看到,我们成功地为图像添加了高斯噪声。接下来,我们将图像张量PyTorch 变量 device 指定的设备中,并定义一个2层的CNN网络,用于降噪:

device = 'cuda'
clean_img = clean_img.to(device)
noisy_img = noisy_img.to(device)

class network(nn.Module):
    def __init__(self, n_chan, chan_embed=48):
        super(network, self).__init__()

        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv1 = nn.Conv2d(n_chan, chan_embed, 3, padding=1)
        self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
        self.conv3 = nn.Conv2d(chan_embed, n_chan, 1)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = self.conv3(x)

        return x

n_chan = clean_img.shape[1]
model = network(n_chan)
model = model.to(device)

这是该方法的神经网络模型,其中 chan_embed 是嵌入通道数,默认为48。该网络采用了 LeakyReLU 激活函数和卷积操作。

接下来介绍图像对下采样器,它通过对非重叠斑块中的对角线像素进行平均,输出两个空间分辨率为一半的下采样图像,如下图所示。它是通过将图像与两个固定的内核进行卷积实现。其中卷积的跨度为2,并分别应用于每个图像通道。

【论文+代码】ZS-N2N实现小样本零网络图像去噪_第4张图片

接下来,我们定义一个能够对输入图像进行下采样的函数:

def pair_downsampler(img):
    # img has shape B C H W
    c = img.shape[1]

    filter1 = torch.FloatTensor([[[[0, 0.5], [0.5, 0]]]]).to(img.device)
    filter1 = filter1.repeat(c, 1, 1, 1)

    filter2 = torch.FloatTensor([[[[0.5, 0], [0, 0.5]]]]).to(img.device)
    filter2 = filter2.repeat(c, 1, 1, 1)

    output1 = F.conv2d(img, filter1, stride=2, groups=c)
    output2 = F.conv2d(img, filter2, stride=2, groups=c)

    return output1, output2
    

显示噪声图像和其相应的降采样对,请注意降采样的图像的空间分辨率是一半:

img1, img2 = pair_downsampler(noisy_img)

img0 = noisy_img.cpu().squeeze(0).permute(1,2,0)
img1 = img1.cpu().squeeze(0).permute(1,2,0)
img2 = img2.cpu().squeeze(0).permute(1,2,0)

fig, ax = plt.subplots(1, 3,figsize=(15, 15))

ax[0].imshow(img0)
ax[0].set_title('Noisy Img')

ax[1].imshow(img1)
ax[1].set_title('First downsampled')

ax[2].imshow(img2)
ax[2].set_title('Second downsampled')

【论文+代码】ZS-N2N实现小样本零网络图像去噪_第5张图片

接下来,定义一个损失函数,它对输入图像和网络输出进行比较,并计算它们之间的均方误差。这个均方误差表示了噪声被降低的程度:

def mse(gt: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    loss = torch.nn.MSELoss()
    return loss(gt, pred)


def loss_func(noisy_img):
    noisy1, noisy2 = pair_downsampler(noisy_img)

    pred1 = noisy1 - model(noisy1)
    pred2 = noisy2 - model(noisy2)

    loss_res = 1 / 2 * (mse(noisy1, pred2) + mse(noisy2, pred1))

    noisy_denoised = noisy_img - model(noisy_img)
    denoised1, denoised2 = pair_downsampler(noisy_denoised)

    loss_cons = 1 / 2 * (mse(pred1, denoised1) + mse(pred2, denoised2))

    loss = loss_res + loss_cons

    return loss

现在,开始训练神经网络,使用Adam优化器并迭代2000次:

def train(model, optimizer, noisy_img):
    loss = loss_func(noisy_img)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


def test(model, noisy_img, clean_img):
    with torch.no_grad():
        pred = torch.clamp(noisy_img - model(noisy_img), 0, 1)
        MSE = mse(clean_img, pred).item()
        PSNR = 10 * np.log10(1 / MSE)

    return PSNR


def denoise(model, noisy_img):
    with torch.no_grad():
        pred = torch.clamp(noisy_img - model(noisy_img), 0, 1)

    return pred

max_epoch = 2000     # training epochs
lr = 0.001           # learning rate
step_size = 1500     # number of epochs at which learning rate decays
gamma = 0.5          # factor by which learning rate decays

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

#开始去噪
for epoch in range(max_epoch):
    # 初始化进度条
    progress_bar = "[" + "-" * 20 + "]"

    # 训练
    train(model, optimizer, noisy_img)

    # 更新进度条
    progress_bar = list(progress_bar)
    progress_bar[(epoch * 20) // max_epoch] = "#"
    progress_bar = "".join(progress_bar)

    # 打印进度条和当前时间点
    print(f"\r{progress_bar} Epoch: {epoch + 1}/{max_epoch}", end="")

训练完成后,我们可以将网络应用于原始噪声图像并显示输出结果:

# 去噪后图像的PSNR
PSNR = test(model, noisy_img, clean_img)
print(PSNR)

#显示原图、加入噪音和去噪的图像
denoised_img = denoise(model, noisy_img)

denoised = denoised_img.cpu().squeeze(0).permute(1,2,0)
clean = clean_img.cpu().squeeze(0).permute(1,2,0)
noisy = noisy_img.cpu().squeeze(0).permute(1,2,0)
fig, ax = plt.subplots(1, 3,figsize=(15, 15))

ax[0].imshow(clean)
ax[0].set_xticks([])
ax[0].set_yticks([])
ax[0].set_title('Ground Truth')

ax[1].imshow(noisy)
ax[1].set_xticks([])
ax[1].set_yticks([])
ax[1].set_title('Noisy Img')
noisy_psnr = 10*np.log10(1/mse(noisy_img,clean_img).item())
ax[1].set(xlabel= str(round(noisy_psnr,2)) + ' dB')

ax[2].imshow(denoised)
ax[2].set_xticks([])
ax[2].set_yticks([])
ax[2].set_title('Denoised Img')
ax[2].set(xlabel= str(round(PSNR,2)) + ' dB')
fig.savefig('output.png', dpi=300, bbox_inches='tight')

【论文+代码】ZS-N2N实现小样本零网络图像去噪_第6张图片

可以看到,神经网络成功地降低了图像中的噪声。

你可能感兴趣的:(opencv,图像处理,python,深度学习)