这一小节的主要目的:为什么要使用Variational Inference,Inference到底有什么用。机器学习,我们可以从频率角度和贝叶斯角度两个角度来看,其中频率角度可以被解释为优化问题,贝叶斯角度可以被解释为积分问题。
为什么说频率派角度的分析是一个优化问题呢?从回归和SVM两个例子上进行分析。数据集描述为: D = { ( x i , y i ) } i = 1 N , x i ∈ R p , y i ∈ R 1 D = \{ (x_i,y_i) \}_{i=1}^N,x_i \in \mathbb{R}^p,y_i \in \mathbb{R}^1 D={(xi,yi)}i=1N,xi∈Rp,yi∈R1。
从贝叶斯的角度来说,这就是一个积分问题,为什么呢?从Bayes公式的表达看:
P ( θ ∣ x ) = P ( x ∣ θ ) P ( θ ) P ( x ) (12.1.8) P(\theta|x) = \frac{P(x|\theta)P(\theta)}{P(x)}\tag{12.1.8} P(θ∣x)=P(x)P(x∣θ)P(θ)(12.1.8)
其中, P ( θ ∣ x ) P(\theta|x) P(θ∣x)称为后验公式, P ( x ∣ θ ) P(x|\theta) P(x∣θ)称为似然函数, P ( θ ) P(\theta) P(θ)称为先验分布, P ( x ) P(x) P(x)为已知的概率分布,并且 P ( x ) = ∫ θ P ( x ∣ θ ) P ( θ ) d θ P(x) = \int_{\theta}P(x|\theta)P(\theta)d\theta P(x)=∫θP(x∣θ)P(θ)dθ。贝叶斯角度分为 推 断 ( I n f e r e n c e ) \color{red}推断(Inference) 推断(Inference)和 决 策 \color{red}决策 决策。
数据
有以下数据:
记 z z z为隐变量和参数的集合。接着变换概率 p ( x ) p(x) p(x)的形式然后引入分布 q ( z ) q(z) q(z):
l o g p ( x ) = l o g p ( x , z ) − l o g p ( z ∣ x ) = l o g p ( x , z ) q ( z ) − l o g p ( z ∣ x ) q ( z ) (12.2.1) \color{blue}log\; p(x)=log\; p(x,z)-log\; p(z|x)=log\; \frac{p(x,z)}{q(z)}-log\; \frac{p(z|x)}{q(z)}\tag{12.2.1} logp(x)=logp(x,z)−logp(z∣x)=logq(z)p(x,z)−logq(z)p(z∣x)(12.2.1)
公式简化
对公式(12.2.1)进行简化,式子两边同时对 q ( z ) q(z) q(z)求积分(期望):
左 边 = ∫ z q ( z ) ⋅ l o g p ( x ∣ θ ) d z = l o g p ( x ∣ θ ) ∫ z q ( z ) d z = l o g p ( x ∣ θ ) (12.2.2) 左边=\int _{z}q(z)\cdot log\; p(x |\theta )\mathrm{d}z=log\; p(x|\theta )\int _{z}q(z )\mathrm{d}z=log\; p(x|\theta )\tag{12.2.2} 左边=∫zq(z)⋅logp(x∣θ)dz=logp(x∣θ)∫zq(z)dz=logp(x∣θ)(12.2.2)
右 边 = ∫ z q ( z ) l o g p ( x , z ∣ θ ) q ( z ) d z ⏟ E L B O ( e v i d e n c e l o w e r b o u n d ) − ∫ z q ( z ) l o g p ( z ∣ x , θ ) q ( z ) d z ⏟ K L ( q ( z ) ∣ ∣ p ( z ∣ x , θ ) ) = L ( q ) ⏟ 变 分 + K L ( q ∣ ∣ p ) ⏟ ≥ 0 (12.2.3) 右边=\underset{ELBO(evidence\; lower\; bound)}{\underbrace{\int _{z}q(z)log\; \frac{p(x,z|\theta )}{q(z)}\mathrm{d}z}}\underset{KL(q(z)||p(z|x,\theta ))}{\underbrace{-\int _{z}q(z)log\; \frac{p(z|x,\theta )}{q(z)}\mathrm{d}z}}=\underset{变分}{\underbrace{L(q)}} + \underset{\geq 0}{\underbrace{KL(q||p)}}\tag{12.2.3} 右边=ELBO(evidencelowerbound) ∫zq(z)logq(z)p(x,z∣θ)dzKL(q(z)∣∣p(z∣x,θ)) −∫zq(z)logq(z)p(z∣x,θ)dz=变分 L(q)+≥0 KL(q∣∣p)(12.2.3)
Evidence Lower Bound (ELBO)是变分, L ( q ) L(q) L(q)和 K L ( q ∣ ∣ p ) KL(q||p) KL(q∣∣p)被记为:
{ L ( q ) = ∫ z q ( z ) log p ( x , z ∣ θ ) q ( z ) d z K L ( q ∣ ∣ p ) = − ∫ z q ( z ) log p ( z ∣ x ) q ( z ) d z \color{blue}\{ \begin{array}{ll}L(q)&=\int_z q(z)\log\ \frac{p(x,z|\theta)}{q(z)}dz\\ KL(q||p)&= - \int_z q(z)\log\ \frac{p(z|x)}{q(z)}dz \end{array} {L(q)KL(q∣∣p)=∫zq(z)log q(z)p(x,z∣θ)dz=−∫zq(z)log q(z)p(z∣x)dz
p ( x ) p(x) p(x)是个定值,我们的目的是寻找一个使得 q ( z ) q(z) q(z)与 p ( z ∣ x , θ ) p(z|x,\theta) p(z∣x,θ)更接近,也就是使 K L ( q ∣ ∣ p ) KL(q||p) KL(q∣∣p)越小越好,也就是要使 L ( q ) L(q) L(q)越大越好:
q ~ ( z ) = arg max q ( z ) L ( q ) ⇒ q ~ ( z ) ≈ p ( z ∣ x ) (12.2.4) \color{blue}\tilde{q}(z)=\arg\underset{q(z)}{\max}\; L(q)\Rightarrow \tilde{q}(z)\approx p(z|x)\tag{12.2.4} q~(z)=argq(z)maxL(q)⇒q~(z)≈p(z∣x)(12.2.4)
- L ( q ) \color{red}L(q) L(q)并非普通的函数,而是以函数 q q q为自变量的函数,这就是 泛 函 \color{red}泛函 泛函。泛函可以看成是函数概念的推广,而变分方法是处理泛函的数学领域,和处理函数的普通微积分相对。
- 变 分 法 最 终 寻 求 的 是 极 值 函 数 : 它 们 使 得 泛 函 取 得 极 大 或 极 小 值 。 \color{red}变分法最终寻求的是极值函数:它们使得泛函取得极大或极小值。 变分法最终寻求的是极值函数:它们使得泛函取得极大或极小值。
模型求解
平均场理论:把多维变量的不同维度分为 M M M组,组与组之间是相互独立的:
q ( z ) = ∏ i = 1 M q i ( z i ) (12.2.5) \color{red}q(z)=\prod_{i=1}^{M}q_{i}(z_{i})\tag{12.2.5} q(z)=i=1∏Mqi(zi)(12.2.5)
在这种分解的思想中,我们每次只考虑第 j \color{blue}j j个分布,那么令 q i ( 1 , 2 , ⋯ , j − 1 , j + 1 , ⋯ , M ) \color{blue}q_i(1,2,\cdots,j-1,j+1,\cdots,M) qi(1,2,⋯,j−1,j+1,⋯,M)个分布 f i x e d \color{blue}fixed fixed。将 L ( q ) L(q) L(q)写作两部分:
L ( q ) = ∫ z q ( z ) l o g p ( x , z ) d z ⏟ ① − ∫ z q ( z ) l o g q ( z ) d z ⏟ ② (12.2.6) L(q)=\underset{①}{\underbrace{\int _{z}q(z)log\; p(x,z)\mathrm{d}z}}-\underset{②}{\underbrace{\int _{z}q(z)log\; q(z)\mathrm{d}z}}\tag{12.2.6} L(q)=① ∫zq(z)logp(x,z)dz−② ∫zq(z)logq(z)dz(12.2.6)
平均场理论解释
平均场理论:把多维变量的不同维度分为 M M M组,组与组之间是相互独立的:
q ( z ) = ∏ i = 1 M q i ( z i ) (12.3.2) \color{red}q(z)=\prod_{i=1}^{M}q_{i}(z_{i})\tag{12.3.2} q(z)=i=1∏Mqi(zi)(12.3.2)
注: z i 表 示 的 不 是 一 个 数 , 而 是 一 个 数 据 维 度 的 集 合 , {\color{red} z_i表示的不是一个数,而是一个数据维度的集合,} zi表示的不是一个数,而是一个数据维度的集合, 它 表 示 的 不 是 一 个 维 度 , 而 是 一 个 类 似 的 最 大 团 , 也 就 是 多 个 维 度 凑 在 一 起 。 {\color{red} 它表示的不是一个维度,而是一个类似的最大团,也就是多个维度凑在一起。} 它表示的不是一个维度,而是一个类似的最大团,也就是多个维度凑在一起。
数学符号规范化(仔细与上一节进行对比)
E L B O \color{blue}ELBO ELBO(第十讲:公式(10.5.6)):
E q ( z ) [ log p θ ( x ( i ) , z ) q ( z ) ] = E q ( z ) [ log p θ ( x ( i ) , z ) ] + H ( q ( z ) ) (12.3.4) \mathbf{E}_{q(z)}\left[ \log \frac{p_{\theta}(x^{(i)},z)}{q(z)} \right] = \mathbf{E}_{q(z)}\left[ \log p_{\theta}(x^{(i)},z) \right]+ H(q(z))\tag{12.3.4} Eq(z)[logq(z)pθ(x(i),z)]=Eq(z)[logpθ(x(i),z)]+H(q(z))(12.3.4)
K L \color{blue}KL KL(第十讲:公式(10.5.2)):
K L ( q ∣ ∣ p ) = ∫ q ( z ) ⋅ log q ( z ) p θ ( z ∣ x ( i ) ) d z (12.3.5) KL(q||p) = \int q(z)\cdot \log \frac{q(z)}{p_{\theta}(z|x^{(i)})} dz\tag{12.3.5} KL(q∣∣p)=∫q(z)⋅logpθ(z∣x(i))q(z)dz(12.3.5)
log q j ( z j ) \color{blue}\log\;q_j(z_j) logqj(zj)(本节:公式(12.2.16))
log q j ( z j ) = E ∏ i ≠ j q i ( z i ) [ log p θ ( x ( i ) , z ) ] + C = ∫ q 1 ∫ q 2 ⋯ ∫ q j − 1 ∫ q j + 1 ⋯ ∫ q M q 1 q 2 ⋯ q j − 1 q j + 1 ⋯ q M log p θ ( x ( i ) , z ) d q 1 d q 2 ⋯ d q j − 1 d q j + 1 ⋯ d q M (12.3.6) \begin{aligned} & \log q_j(z_j)\\ & = \mathbf{E}_{\prod_{i \neq j} q_i(z_i)}\left[ \log p_{\theta} (x^{(i)},z) \right] + C \\ & = \int_{q_1} \int_{q_2} \cdots \int_{q_{j-1}}\int_{q_{j+1}} \cdots \int_{q_{M}} q_1q_2\cdots q_{j-1}q_{j+1} \cdots q_M \log p_{\theta} (x^{(i)},z)dq_1dq_2 \cdots dq_{j-1}dq_{j+1} \cdots dq_{M} \\ \end{aligned}\tag{12.3.6} logqj(zj)=E∏i=jqi(zi)[logpθ(x(i),z)]+C=∫q1∫q2⋯∫qj−1∫qj+1⋯∫qMq1q2⋯qj−1qj+1⋯qMlogpθ(x(i),z)dq1dq2⋯dqj−1dqj+1⋯dqM(12.3.6)
VI算法的具体求解
根据 公 式 ( 12.2.16 ) 公式(12.2.16) 公式(12.2.16)使用迭代算法来进行求解:
q ^ 1 ( z 1 ) = ∫ q 2 ⋯ ∫ q M q 2 ⋯ q M [ log p θ ( x ( i ) , z ) ] d q 2 ⋯ d q M q ^ 2 ( z 2 ) = ∫ q ^ 1 ( z 1 ) ∫ q 3 ⋯ ∫ q M q ^ 1 q 3 ⋯ q M [ log p θ ( x ( i ) , z ) ] q ^ 1 d q 2 ⋯ d q M ⋮ q ^ M ( z M ) = ∫ q ^ 1 ⋯ ∫ q ^ M − 1 q ^ 1 ⋯ q ^ M − 1 [ log p θ ( x ( i ) , z ) ] d q ^ 1 ⋯ d q ^ M − 1 (12.3.7) \color{red}\begin{array}{ll} \hat{q}_1(z_1) = \int_{q_2} \cdots \int_{q_{M}} q_2 \cdots q_M \left[ \log p_{\theta}(x^{(i)},z) \right]dq_2 \cdots dq_{M} \\ \hat{q}_2(z_2) = \int_{\hat{q}_1(z_1)}\int_{q_3} \cdots \int_{q_{M}} \hat{q}_1q_3 \cdots q_M \left[ \log p_{\theta}(x^{(i)},z) \right]\hat{q}_1dq_2 \cdots dq_{M} \\ \vdots \\ \hat{q}_M(z_M) = \int_{\hat{q}_1} \cdots \int_{\hat{q}_{M-1}} \hat{q}_1 \cdots \hat{q}_{M-1} \left[ \log p_{\theta}(x^{(i)},z) \right]d\hat{q}_1 \cdots d\hat{q}_{M-1}\end{array}\tag{12.3.7} q^1(z1)=∫q2⋯∫qMq2⋯qM[logpθ(x(i),z)]dq2⋯dqMq^2(z2)=∫q^1(z1)∫q3⋯∫qMq^1q3⋯qM[logpθ(x(i),z)]q^1dq2⋯dqM⋮q^M(zM)=∫q^1⋯∫q^M−1q^1⋯q^M−1[logpθ(x(i),z)]dq^1⋯dq^M−1(12.3.7)
如果将 q 1 , q 2 , ⋯ , q M {q}_1,{q}_2,\cdots,{q}_M q1,q2,⋯,qM看成一个个的坐标点,那么随着计算的深入,知道的坐标点越来越多,这实际上就是一种坐标上升的方法(Coordinate Ascend)。
这是一种迭代算法,那怎么考虑迭代的停止条件呢?设置当 L ( t + 1 ) ≤ L ( t ) \color{blue}\mathcal{L}^{(t+1)} \leq \mathcal{L}^{(t)} L(t+1)≤L(t)时停止迭代。
VI算法的整体步骤
针对平均场变分分布, 坐 标 上 升 近 似 推 断 算 法 ( C A V I ) \color{green}坐标上升近似推断算法(CAVI) 坐标上升近似推断算法(CAVI)是最常见的优化方法。CAVI交替地更新每个隐变量,更新时固定其他的隐变量的变分分布参数,用来计算当前隐变量 z j z_j zj的坐标上升公式。CAVI的算法步骤如下图所示。
Mean Field Theory(平均场理论)的存在问题
对于隐变量参数 z z z和数据集 x x x。
- z ⟶ x \color{red}z \longrightarrow x z⟶x是Generative Model,也就是 p ( x ∣ z ) p(x|z) p(x∣z)和 p ( x , z ) p(x,z) p(x,z),这个过程也被我们称为 D e c o d e r \color{red}Decoder Decoder。
- x ⟶ z \color{red}x \longrightarrow z x⟶z是Inference Model,表达关系是 p ( z ∣ x ) p(z|x) p(z∣x),这个过程被我们称为 E n c o d e r \color{red}Encoder Encoder。
本节先对SGVI参数规范,然后SGVI的梯度推导。
本节继上一节的内容,介绍Variance Reduction(方差缩减) 。