空间转换网络——Spatial Transformer Networks(论文笔记)

本文是对论文《Spatial Transformer Networks》的学习梳理总结,主要参考文献为论文原文——《Spatial Transformer Networks》以及一篇博客——论文笔记:Spatial Transformer Networks(空间变换网络)

一、CNN存在的问题

CNN定义了非常强大的分类模型,但是仍然受到缺乏在计算和参数效率上对输入数据空间不变性能力的限制。即,当输入图像因随机平移、缩放、旋转、混乱而失真时,CNN模型的分类准确率将会下降。

二、本文的任务

通过对现有的CNN结构进行改进,从而解决上面提到的问题。

三、解决问题的方法

Spatial Transformer

文中通过在CNN模型中插入一个新的可训练的模块——Spatial Transformer,该模块在特定输入的条件下对传入的特征图进行相应的空间变换(通过训练确定特定输入对应的空间变换),导致模型学习平移、缩放、旋转和更普遍的变形的不变性,从而在多个基准和许多类别的变换上获得最先进的性能。

Spatial Transformer模块包含3个部分:Localisation Network,Parameterised Sampling Grid,Differentiable Image Sampling。见下图
空间转换网络——Spatial Transformer Networks(论文笔记)_第1张图片
它完成的是一个将输入特征图进行一定的变换的过程,而具体如何变换,是通过在训练过程中学习来的,更通俗地将,该模块在训练阶段学习如何对输入数据进行变换更有益于模型的分类,然后在测试阶段应用已经训练好的网络对输入数据进行执行相应的变换,从而提高模型的识别率。下面具体介绍Spatial Transformer的三个部分。

1. Localisation Network

该部分获取输入特征图U,并通过多个隐藏层回归输出应用于该特征图的空间变换参数 θ \theta θ——这给出了以输入为条件的变换 θ \theta θ=f(U)。
θ的规模取决于具体的变换,当变换取二维仿射变换时,θ是一个6维(2*3)的矩阵:
放射变换对应的变换参数

2. Parameterised Sampling Grid

该层利用Localisation 层输出的变换参数 θ \theta θ,将输入的特征图进行变换,以仿射变换【放射变换和下面要用到的双线性插值参考博客 论文笔记:Spatial Transformer Networks中的仿射变换和双线性插值】为例,即输出特征图上某一位置(xti,yti)根据变换参数θ映射到输入特征图上某一位置(xsi,ysi),具体如下:
在这里插入图片描述
但此时往往(xsi,ysi)会落在原始输入特征图的几个像素点中间部分,所以需要利用双线性插值来计算出对应该点的灰度值。文中给出了该层两个例子,如下图:
空间转换网络——Spatial Transformer Networks(论文笔记)_第2张图片

3. Differentiable Image Sampling

为了对输入特征图进行空间变换,采样器必须与输入特征图U一起获取采样点T θ \theta θ(G)的集合,并生成采样的输出特征图。该部分文中对输出特征图上的任一个点和输入特征图上的像素点灰度值建立的具体联系进行公式表达(采用双线性插值法):
在这里插入图片描述
另外,我们前面强调该ST(Spatial Transformer)模块是可训练的。因此,上述公式对Ucnm和(xsi,ysi)是可导的,也就是说,ST的变换过程是可以对损失进行反向传播从而在网络中不断训练来修正参数的。具体的求导结果见论文,此处不再赘述。

四、实验

在文章最后,他们利用以上研究成果分别进行了扭曲的MNIST手写数字识别、街景门牌号码识别、细粒度的鸟类分类实验。实验结果均表明,通过插入Spatial Transformer模块可使CNN模型分类器具有更好的空间不变性。此外实验表明,Spatial Transformer模块还起到了一定的注意力机制的作用。

参考文献

1.《Spatial Transformer Networks》
2. 论文笔记:Spatial Transformer Networks(空间变换网络)
3. 论文笔记:Spatial Transformer Networks中的仿射变换和双线性插值

你可能感兴趣的:(学习笔记,计算机视觉,深度学习)