在由浅入深理解神经网络 1 一个简单到极致的神经网络中, 我们已经发现了训练神经网络最重要的一件事, 那就是求梯度, 然后优化算法利用梯度来调整网络参数. 我们重写一下前面提到的一个通用的神经网络:
y = f ( x ; θ ) y=f(x;\theta) y=f(x;θ)
考虑到训练集 { X , Y } \{X,Y\} {X,Y}和损失函数 l o s s loss loss: 我们的目标是最小化损失:
l o s s ( Y , Y ^ ) = l o s s ( Y , f ( X , θ ) ) loss(Y,\hat Y)=loss(Y, f(X,\theta)) loss(Y,Y^)=loss(Y,f(X,θ))
注意损失函数一定是个标量函数, 即是说它的输出一定是一个实数,因为只有实数才可以直接比较大小, 才谈得上最小化.
为了这个最小化的目标, 我们需要计算下列的偏导:
∂ l o s s ( Y , f ( X , θ ) ) ∂ θ \frac{\partial loss(Y, f(X,\theta))}{\partial \theta} ∂θ∂loss(Y,f(X,θ))
这是本节内容要探讨的对神经网络来说至关重要的问题:如何计算任意神经网络的这个偏导。
如果你对自己的数学没有信心,可以直接跳转到结语部分。
张量(Tensor)的概念在神经网络的领域非常重要, 这其中的缘由是我们面对的问题不全是上一篇文章中说到的案例中的输入一个数字输出一个数字的情况. 比如说一张黑白照片, 数据可以是一个128 × \times × 128的灰度矩阵, 一个视频可以是一个长 × \times ×宽 × \times ×时间帧 × \times ×RGB通道的数据. 所以我们需要用一个抽象的概念来统合这些数据, 它们全部是张量. 张量可以理解为任意维矩阵, 我们常用的标量是0维张量, 数列是1维张量, 矩阵是2维张量, 当然在神经网络这个领域, 我们可以处理的张量是任意维度的。显然我们常用的函数实际上都可以被归类成张量函数, 它们输入若干个张量,输出若干个张量。
神经网络最终的呈现结果是一个运算图,这个图的每个节点对应的是一个运算(算子,层,模块,函数,他们的本质都是一个运算),这个图的每根线都对应着一个张量,它标示了每个张量从哪个运算节点输出, 输入给哪个运算节点。
比如说下面这种形式的一个运算图,这个图输入的是一个尺寸为 1 × 3 × 224 × 224 1\times3\times224\times224 1×3×224×224的四维张量,
每个方块都对应着一个特定的运算, 这些运算具体是在执行什么,后面的文章会提到。目前我们只需要知道这种图中的方块是运算,带箭头的连线代表的是张量。
前面说过,神经网络是一个带参数的张量函数,而神经网络又是由若干个小的张量函数组成的,这些小的张量函数有许许多多的名字。它们常常被称之为层(layer),比如全连接层,卷积层这些,这种叫法在论文中很常见;有时候它们会被称之为节点(node)或者算子(operator),这种叫法常见于onnx这类的框架,因为它们把神经网络当作一个运算图看待。就像上一节的说法,运算图就是用节点和张量流组成的。
现在我们举个简单的例子:
y = f ( x ; θ ) = f 2 ( f 1 ( x ; θ 1 ) ; θ 2 ) y=f(x;\theta)=f_2(f_1(x;\theta_1);\theta_2) y=f(x;θ)=f2(f1(x;θ1);θ2)
这可以表示由两个节点组合出的简易神经网络,为了计算出y,我们需要先计算 f 1 ( x ; θ 1 ) f_1(x;\theta_1) f1(x;θ1)它的结果作为 f 2 f_2 f2的输入。这玩意说白了就是初高中学的复合函数, 比如 y = sin ( x 2 ) y=\sin(x^2) y=sin(x2)就是符合上面这个例子的函数。
到这里我们可以再给神经网络下一个新的定义,神经网络是一个带参数的复合张量函数。
值得注意的是,像前面这个例子一样逐级嵌套张量函数,并不能形式化的表示所有神经网络,现代的神经网络往往有着更加复杂的形式,它的每个节点(比如 f 1 f_1 f1和 f 2 f_2 f2)都可以有多个输入和多个输出,这些输入和输出可以连接到更多的节点,就像上一章中的图一样。
好了,现在回到一开始的最重要的任务, 我们分析这么多的目的是我们想要求出来一个复杂的运算图,它的损失函数关于所有参数的偏导,用梯度下降根据这些偏导调整参数的值,就是训练的过程。让我们先回忆一些基础的微分知识,神经网络的整个知识体系涉及的数学并不高深,除了这部分的微分知识,还有后面会提到的矩阵论。
先让我们复习一下初高中简单的符合函数求导问题,求 y = sin ( x 2 ) y=\sin(x^2) y=sin(x2)的导数, 当时我们学到的解法是先求 x 2 x^2 x2的导数 2 x 2x 2x, 求 sin ′ ( x ) = cos ( x ) \sin '(x)=\cos(x) sin′(x)=cos(x), 用所谓的链式法则
g = x 2 y = sin ( g ) ∂ y ∂ x = ∂ y ∂ g ∂ g ∂ x = cos ( g ) 2 x = 2 x cos ( x 2 ) g=x^2\\ y=\sin (g)\\ \frac{\partial y}{\partial x}=\frac{\partial y}{\partial g}\frac{\partial g}{\partial x}=\cos(g)2x=2x\cos(x^2) g=x2y=sin(g)∂x∂y=∂g∂y∂x∂g=cos(g)2x=2xcos(x2)
显然的,链式法则将复合函数的求导运算拆分成了若干个简单函数的求导,这个对神经网络这种动则嵌套几十次上百次的超大复合函数是至关重要的,链式法则的存在,使得深层的神经网络的训练成为可能。
神经网络的链式法则要稍微复杂一点,因为它的节点不一定是单输入单输出的。对于多个输入的多元函数一个一般的例子:
y = f ( g ( x ) , h ( x ) ) y=f(g(x),h(x)) y=f(g(x),h(x))
这个函数的偏导:
∂ y ∂ x = ∂ y ∂ g ∂ g ∂ x + ∂ y ∂ h ∂ h ∂ x \frac{\partial y}{\partial x}=\frac{\partial y}{\partial g}\frac{\partial g}{\partial x}+\frac{\partial y}{\partial h}\frac{\partial h}{\partial x} ∂x∂y=∂g∂y∂x∂g+∂h∂y∂x∂h
在完成最终的偏导运算之前,我们先引进前向传播的概念。还是前面的例子,要获得y的结果,我们要先由输入x出发,经过第一个算子 f 3 f_3 f3得到h,再由h经过算子 f 2 f_2 f2得到g,最后得到y,这种一级一级获得中间结果,最终得到输出的过程叫做前向传播,整个网络的中间结果就像是在由输入端慢慢向输出传播一样。很显然,前向传播就是在完成我们前面说的推理。
反向传播恰恰相反,它将梯度的计算从输出节点传播到输入,逐步完成所有偏导的计算。我们重新定义一个这样的简单含参复合函数:
y = f 1 ( f 2 ( f 3 ( x ; θ 3 ) ; θ 2 ) ; θ 1 ) y=f_1(f_2(f_3(x;\theta_3);\theta_2);\theta_1) y=f1(f2(f3(x;θ3);θ2);θ1)我们需要求出
∂ y ∂ θ 1 , ∂ y ∂ θ 2 , ∂ y ∂ θ 3 \frac{\partial y}{\partial \theta_1},\frac{\partial y}{\partial \theta_2},\frac{\partial y}{\partial \theta_3} ∂θ1∂y,∂θ2∂y,∂θ3∂y这样梯度下降算法就可以利用他们来修正各自的参数。
类似上面的例子,我们不妨令
h = f 2 ( x ; θ 3 ) g = f 2 ( h ; θ 2 ) y = f 1 ( g ; θ 1 ) h=f_2(x;\theta_3)\\ g=f_2(h;\theta_2)\\ y=f_1(g;\theta_1) h=f2(x;θ3)g=f2(h;θ2)y=f1(g;θ1)
前面提到了前向传播,所以在x给出之后,h g y都是已经算出来了, 这对后面的计算有用。利用前面的知识, 不难发现:
∂ y ∂ θ 1 = ∂ f 1 ( g , θ 1 ) ∂ θ 1 \frac{\partial y}{\partial \theta_1}=\frac{\partial f_1(g,\theta_1)}{\partial\theta_1} ∂θ1∂y=∂θ1∂f1(g,θ1)
∂ y ∂ θ 2 = ∂ f 1 ( g , θ 1 ) ∂ g ∂ f 2 ( h , θ 2 ) ∂ θ 2 \frac{\partial y}{\partial \theta_2}=\frac{\partial f_1(g,\theta_1)}{\partial g}\frac{\partial f_2(h,\theta_2)}{\partial \theta_2} ∂θ2∂y=∂g∂f1(g,θ1)∂θ2∂f2(h,θ2)
∂ y ∂ θ 3 = ∂ f 1 ( g , θ 1 ) ∂ g ∂ f 2 ( h , θ 2 ) ∂ h ∂ f 3 ( h , θ 3 ) θ 3 \frac{\partial y}{\partial \theta_3}=\frac{\partial f_1(g,\theta_1)}{\partial g}\frac{\partial f_2(h,\theta_2)}{\partial h}\frac{\partial f_3(h,\theta_3)}{\theta_3} ∂θ3∂y=∂g∂f1(g,θ1)∂h∂f2(h,θ2)θ3∂f3(h,θ3)
从上面的公式可以发现,我们实际运算的时候只需要把每个节点 f 1 , f 2 , f 3 f_1,f_2,f_3 f1,f2,f3自身的输入和参数的偏导算好,计算全图的偏导时,从输出端不断向前检索直到检索到输入,整个路径上的偏导联乘即可. 实际上如果你仔细研究pytorch或者tensorflow的自动求导器, 可以看到和这个思路对应的代码。所谓反向传播这个名字也是来源于这种独特的链式求导。
最后我们要计入损失函数,损失函数一般是不带参数的,我们可以把它当作是整个图的最后一个节点,由于它不含参数,因此对于 l o s s ( x ) loss(x) loss(x)
我们只需要计算 ∂ l o s s ( x ) ∂ x \frac{\partial loss(x)}{\partial x} ∂x∂loss(x), 由于反向传播的特点,作为最后一个节点的它会作为所有参数偏导链条的第一项。
前面我们举的例子都停留在前馈网络的范围,还不涉及多输入多输出,现在我们来举一个多输入多输出的简单例子,来展示一下如何用反向传播的形式完成链式法则的递推。我们在图中插入一个新的二输入二输出函数:
( y 1 , y 2 ) = f m ( x 1 , x 2 ; m , n ) = ( m x 1 x 2 , n x 1 2 x 2 2 ) (y_1,y_2)=f_m(x_1,x_2;m,n)=(mx_1x_2,nx_1^2x_2^2) (y1,y2)=fm(x1,x2;m,n)=(mx1x2,nx12x22)
这个函数包含两个输入 x 1 , x 2 x_1,x_2 x1,x2,两个输出 y 1 , y 2 y_1,y_2 y1,y2, 以及两个参数 m , n m,n m,n
构建下面这样的 一个简易的运算图。
我们把这个图描述的运算公式化:
t 1 = a x + b t 2 = c x + d t 3 = m t 1 t 2 t 4 = n t 1 2 t 2 2 y = t 3 + t 4 t_1=ax+b\\ t_2=cx+d\\ t_3=mt_1t_2\\ t_4=nt_1^2t_2^2\\ y=t_3+t_4 t1=ax+bt2=cx+dt3=mt1t2t4=nt12t22y=t3+t4
现在我们分析每个节点的运算:
对于最后一个加法运算,它是一个双输入单输出的函数,没有参数,因此不需要计算参数的偏导
∂ y ∂ t 3 = 1 , ∂ y ∂ t 4 = 1 \frac{\partial y}{\partial t_3}=1,\frac{\partial y}{\partial t_4}=1 ∂t3∂y=1,∂t4∂y=1
对于 f m f_m fm函数,需要计算的东西比较多:
∂ t 4 ∂ t 1 = 2 n t 1 t 2 2 , ∂ t 4 ∂ t 2 = 2 n t 2 t 1 2 ∂ t 3 ∂ t 1 = m t 2 , ∂ t 3 ∂ t 2 = m t 1 ∂ t 4 ∂ n = t 1 2 t 2 2 , ∂ t 3 ∂ m = t 1 t 2 \frac{\partial t_4}{\partial t_1}=2nt_1t_2^2, \frac{\partial t_4}{\partial t_2}=2nt_2t_1^2\\ \frac{\partial t_3}{\partial t_1}=mt_2,\frac{\partial t_3}{\partial t_2}=mt_1\\ \frac{\partial t_4}{\partial n}=t_1^2 t_2^2,\frac{\partial t_3}{\partial m}=t_1t_2\\ ∂t1∂t4=2nt1t22,∂t2∂t4=2nt2t12∂t1∂t3=mt2,∂t2∂t3=mt1∂n∂t4=t12t22,∂m∂t3=t1t2
最前方的两个简单的线性函数:
∂ t 2 ∂ c = x , ∂ t 2 ∂ d = 1 ∂ t 1 ∂ a = x , ∂ t 1 ∂ b = 1 \frac{\partial t_2}{\partial c}=x, \frac{\partial t_2}{\partial d}=1\\ \frac{\partial t_1}{\partial a}=x,\frac{\partial t_1}{\partial b}=1\\ ∂c∂t2=x,∂d∂t2=1∂a∂t1=x,∂b∂t1=1
依据链式法则的原则, 可以依次获得我们想要的偏导:
∂ y ∂ n = ∂ y ∂ t 4 ∂ t 4 ∂ n = t 1 2 t 2 2 ∂ y ∂ m = ∂ y ∂ t 3 ∂ t 3 ∂ m = t 1 t 2 ∂ y ∂ a = ∂ y ∂ t 3 ∂ t 3 ∂ t 1 ∂ t 1 ∂ a + ∂ y ∂ t 4 ∂ t 4 ∂ t 1 ∂ t 1 ∂ a = m t 2 x + 2 n t 1 t 2 2 x ∂ y ∂ b = ∂ y ∂ t 3 ∂ t 3 ∂ t 1 ∂ t 1 ∂ b + ∂ y ∂ t 4 ∂ t 4 ∂ t 1 ∂ t 1 ∂ b = m t 2 + 2 n t 1 t 2 2 ∂ y ∂ c = ∂ y ∂ t 3 ∂ t 3 ∂ t 2 ∂ t 2 ∂ c + ∂ y ∂ t 4 ∂ t 4 ∂ t 2 ∂ t 2 ∂ c = m t 1 x + 2 n t 2 t 1 2 x ∂ y ∂ d = ∂ y ∂ t 3 ∂ t 3 ∂ t 2 ∂ t 2 ∂ d + ∂ y ∂ t 4 ∂ t 4 ∂ t 2 ∂ t 2 ∂ d = m t 1 + 2 n t 2 t 1 2 \frac{\partial y}{\partial n}=\frac{\partial y}{\partial t_4}\frac{\partial t_4}{\partial n}=t_1^2 t_2^2\\ \frac{\partial y}{\partial m}=\frac{\partial y}{\partial t_3}\frac{\partial t_3}{\partial m}=t_1t_2\\ \frac{\partial y}{\partial a}=\frac{\partial y}{\partial t_3}\frac{\partial t_3}{\partial t_1}\frac{\partial t_1}{\partial a}+\frac{\partial y}{\partial t_4}\frac{\partial t_4}{\partial t_1}\frac{\partial t_1}{\partial a}=mt_2x+2nt_1t_2^2x\\ \frac{\partial y}{\partial b}=\frac{\partial y}{\partial t_3}\frac{\partial t_3}{\partial t_1}\frac{\partial t_1}{\partial b}+\frac{\partial y}{\partial t_4}\frac{\partial t_4}{\partial t_1}\frac{\partial t_1}{\partial b}=mt_2+2nt_1t_2^2\\ \frac{\partial y}{\partial c}=\frac{\partial y}{\partial t_3}\frac{\partial t_3}{\partial t_2}\frac{\partial t_2}{\partial c}+\frac{\partial y}{\partial t_4}\frac{\partial t_4}{\partial t_2}\frac{\partial t_2}{\partial c}=mt_1x+2nt_2t_1^2x\\ \frac{\partial y}{\partial d}=\frac{\partial y}{\partial t_3}\frac{\partial t_3}{\partial t_2}\frac{\partial t_2}{\partial d}+\frac{\partial y}{\partial t_4}\frac{\partial t_4}{\partial t_2}\frac{\partial t_2}{\partial d}=mt_1+2nt_2t_1^2\\ ∂n∂y=∂t4∂y∂n∂t4=t12t22∂m∂y=∂t3∂y∂m∂t3=t1t2∂a∂y=∂t3∂y∂t1∂t3∂a∂t1+∂t4∂y∂t1∂t4∂a∂t1=mt2x+2nt1t22x∂b∂y=∂t3∂y∂t1∂t3∂b∂t1+∂t4∂y∂t1∂t4∂b∂t1=mt2+2nt1t22∂c∂y=∂t3∂y∂t2∂t3∂c∂t2+∂t4∂y∂t2∂t4∂c∂t2=mt1x+2nt2t12x∂d∂y=∂t3∂y∂t2∂t3∂d∂t2+∂t4∂y∂t2∂t4∂d∂t2=mt1+2nt2t12
值得注意的是越是接近输入端的参数,链式求导的链条越长, 而且长的链条和短的链条有重合的部分,出于降低计算量的考虑,我们很容易想象实施链式求导的方案,那就是从输出位置开始向前逐步获取每个节点的偏导,这样可以避免每个偏导的计算都需要重复计算整个链条。所谓的反向传播就是形象的描述这一个过程,偏导就好像从输出端逐步被传播到了整个网络直到输入端。
经过了上面的这些令人不适的数学之后,我们论证了一件事, 那就是我们有办法在计算机上计算出来任意神经网络的输出关于它的全体可变参数之间的偏导,前提是我们要约定好作为神经网络最基层的算子的偏导公式(这其实是很容易的)。还有一件事就是,要做到这件事需要耗费的运算量是可控的,如果你对算法比较敏感,会发现反向传播的复杂度是多项式级的,这一点至关重要,因为它意味着增加网络的规模不会带来指数级别的运算量。也就是说,当人类的计算机的算力达到一定程度,人工智能爆发就是顺其自然的事情。