变分推理

背景介绍

这篇是一个阅读VI相关的资料([1],[2],[3],[4]) 的一个阅读笔记,首先根据[1]简单的介绍基本概念,然后根据[2],[3],[4]进行内容的丰富和增加。希望每天都可以更新一丢丢,渐渐搞清楚这个概念。

我们假设我们的观测变量为 $x=x_{1:n}$, $z=z_{1:m}$ 是隐变量以及额外的一些参数$\alpha$。假设我们的模型是联合概率分布 $p(x,z)$,我们的后验概率分布为:

\[p(z|x,\alpha) = \frac{p(z,x|\alpha)}{\int_z p(z,x|\alpha)}\]

后验概率将模型和数据结合了起来,它会用在很多下游的任务中。(计算后验也是变分推理能解决的一个实例)。但是,这里的分母很难计算,我们不难想象假设 $z$是离散变量,我们需要枚举所有的组合。举个例子,假设我们需要为一段$n$个字的文本进行标注,每个文字的标注为$z_i$,它的取值有 $K$ 个,那么这里我们的所有组合就有 $K^n$ 种。(虽然如果文字长度不长并且标注的取值很少的话还是可以穷举的)当 $n$ 变得很大的时候,这个问题就变得很难搞了(intractable)。如果不是离散的,很多模型对应的分母的积分都很难求解。因此,近似后验推理(approximate posterior inference) 就成了贝叶斯统计中一个很核心的问题。

它的基本思想是,为隐变量挑选一个分布族(a family of distributions),它的参数叫做变分参数(variational parameters), $q(z_{1:m}|v)$。接着我们需要找到一组参数使得隐变量概率分布 $q^*(z)$ 尽可能的接近我们需要求解的后验概率分布 $p(z|x)$。然后用 $q$ 来代替后验概率分布进行一些其他任务。(通常真正的后验概率分布其实不在这个挑选的变分分布族里面。)

Kullback-Leibler Divergence

KL Divergence 来自于信息论,它的定义如下:

\[\begin{align} KL(q||p) &= E_q \left[\log\frac{q(z)}{p(z|x)} \right] \quad(1) \\ \\ &= E_q[\log{q(z)}] - E_q[\log p(z|x)] \\ \\ &= E_q[\log q(z)] - E_q[\log p(x,z)] + \log p(x) \\ \\ &= \log p(x) - \{E_q[\log p(x,z)] - E_q[\log q(z)] \} \quad(2) \end{align}\]

接着,根据 Jensen 不等式,我们知道如果函数 $f$ 是凹的(concave),那么它满足:$f(E[X]) \ge E[f(X)]$ 这里,我们知道$\log$ 函数是凹的,所以有:

\[\begin{align} \log p(x) &= \log \int_z p(x, z) \\ \\ &= \log \int_z p(x, z) \frac {q(z)}{q(z)} \\ \\ &= \log \left( E_q \left[ \frac{p(x, z)}{q(z)} \right] \right) \\ \\ &\ge E_q[ \log p(x,z)] - E_q[\log q(z)] \end{align}\]

其中,$\log p(x)$ 被称为证据(evidence),右端被称为证据下界(evidence lower bound, ELBO)。 因此如果想减小KL散度使得两个分部尽可能接近,我们就需要增大证据下界ELBO。

Mean Field

平均场 (mean field)变分推理中,我们假设变分家族可以因式分解成:

\[q(z_1, \dots, z_m) = \prod_{j=1}^m q(z_j)\]

即每个变量都是独立的(条件独立于参数 $v_j$ )。(但很多时候隐变量其实是互相依赖的)

接着我们利用这个因式分解来优化证据下界 ELBO,可以用一些优化方法比如坐标上升,迭代的优化每个变分分布(保持其它的不变)。

变分推理步骤

在我们选择了我们的分布族之后,我们将其平均场的表达式带入到证据下界中:

\[\begin{align} \mathcal{L}(q) &= \int \prod_i q_i \left \{ \ln p(X,Z) - \sum_i \ln q_i \right \} \mathrm{dZ} \\ \\ &= \int q_j \left \{ \int \ln p(X,Z) \prod_{i \neq j} q_i \mathrm{dZ_i} \right \} \mathrm{dZ_j} - \int q_j \ln q_j \mathrm{dZ_j} + const \\ &= \int q_j \ln \widetilde{p}(X,Z_j) \mathrm{dZ_j} - \int q_j \ln q_j \mathrm{dZ_j} + const \end{align}\]

其中,

\[\ln \widetilde{p}(X,Z_j) \mathrm{dZ_j} = \mathbb{E}_{i \neq j} \left[ \ln p(X,Z) \right] + const\]

这里的 $\mathbb{E}_{i \neq j} \left[ \dots \right] $ 表示的对于所有的变量 $z_i (i \neq j)$ 在变分分布 $q$ 上的期望。

假设我们让 ${ q_{i \neq j}}$ 保持不变,并且我们需要最大化证据下界 $\mathcal{L}(q)$。不难看出 $\mathcal{L}(q)$ 其实就是一个负的在 $q_j{Z_j}$ 和 $\widetilde{p}(X,Z_j)$之间的KL散度,因此最大化下界就变成了最小化KL散度,而这个最小的情况显然发生在$q_j{Z_j} = \widetilde{p}(X,Z_j)$,这两个分布相同的情况下。(其实我们也可以直接对 $Z_j$求导)这样我们就获得了最优解 $$q^*_j(Z_j)$的表达式为:

\[\ln q^*_j(Z_j) = \mathbb{E}_{i \neq j} [\ln p(X,Z)] + const\]

这个式子表明因子 $\ln q_j $的最优解只需要考虑在所有隐变量和观测变量上考虑他们的联合分布,接着根据其它因子 ${q_i} (i \neq j)$对这个联合分布取期望。这里的常量可以理解为是正则项,归一化 $q^*_j(Z_j)$,对等号两边取exp,得到:

\[q^*_j(Z_j) = \frac{\mathbb{E}_{i \neq j}[\ln p(X,Z)]}{\int \exp(\mathbb{E}_{i \neq j} [\ln p(X,Z)]) \mathrm{dZ_j}}\]

在实际应用中,我们会初始化好所有的 $q_i(Z_i)$ 然后迭代更新每个因子。(收敛是可以保证的,因为下界关于每个因子是凸(convex)的, 注意是局部最优解)。另外,也有把上面的表达式写成这个样子的:

\[q^*(z_j) \propto \exp{ E_{-j}[\log p(z_j, Z_{-j}, x)]}\]

并且我们也可以表示成:

\[q^*(z_j) \propto \exp{ E_{-j}[\log p(z_j|Z_{-j},x)]}\]

因为分母不依赖 $z_j$ 所以加上 $\log$ 和期望运算后,也会变成常数。

因此,我们就得到了每个分量的最优分布。有没有觉得它和我们之前提到的吉布斯采样有异曲同工之妙。。

指数分布族

假设每个条件分布都满足下面的指数家族形式:

\[p(z_j|z_{-j},x) = h(z_j) \exp \{\eta(z_{-j},x)^\mathrm{ T }t(z_j) -a(\eta(z_{-j},x))\}\]

这里多提一句,我们在一些书本上更常看到的指数家族的形式是下面这样的,本质上是一样的哈: \(p(x|\theta) = h(x) \exp(\eta(\theta)T(x) -A(\theta))\)

$\theta$ 是参数。

这种形式条件分布可以表示很多复杂的模型 (暂时不翻译了)

  • Bayesian mixtures of exponential families with conjugate priors
  • Switching Kalman filters
  • Hierachical HMMs
  • Mixed-membership models of exponential families
  • Factorial mixtures/HMMs of exponential faimilies
  • Bayesian linear regression

如果我们要对这种类型的分布求它的近似分布,步骤如下:

  • 对条件分布取对数:
\[\log p(z_j | z_{-j},x) = \log h(z_j) + \eta(z_{-j},x)^\mathrm{ T }t(z_j) - a(\eta)(z_{-j},x))\]
  • 算相对于 $q(z_{-j})$ 的期望
\[E[\log p(z_j|z_{-j})] = \log h(z_j) + E[\eta(z_{-j},x)]^\mathrm{T}t(z_j) - E[a(\eta(z_{-j},x))]\]
  • 注意最后一项不依赖于$q_j$,因此:
\[q*(z_j) \propto h(z_j)\exp \{E[\eta(z_{-j},x)]^\mathrm{T} t(z_j) \}\]

正则项是 $a(E[\eta(z_{-j},x)])$

可以看到最优的 $q(z_j)$ 和条件分布 $p(z_j|z_{-j},x)$ 是相同的指数分布族。

坐标上升算法

针对前面的指数分布族形式的最优分布,给它的每个隐藏变量一个变分参数 $v_j$:

\[q(z_{1:m} | v) = \prod_{j=1}^m q(z_j|v_j)\]

迭代的用坐标上升法来设置每个参数$v_j$为给定了所有其他变量和观测序列后这个参数的期望值:

\[v^*_j = E[\eta(z_{-j},x)]\]

举个例子-单变量高斯

我们的目标是想要得到在给定了数据集$\mathcal{D} = {x_1, \dots, x_N}$ (在高斯分布中独立采样的样本), 均值 $\mu$ 和 精度 $\tau$ (precision, 方差的倒数)的后验分布 $p(\mu,\tau | \mathcal{D})$。

首先我们的似然函数是:

\[p(\mathcal{D}| \mu, \tau) = \left(\frac{\tau}{2\pi} \right)^\frac{N}{2} \exp \left \{ -\frac{\tau}{2} \sum_{n=1}^N(x_n-\mu)^2 \right \}\]

我们现在为 $\mu$ 和 $\tau$ 引入他们的共轭先验分布:

\[p(\mu | \tau) = \mathcal{N} \left(\mu | \mu_0, (\lambda_0 \tau_0)^{-1} \right)\] \[p(\tau) = Gam(\tau|a_0, b_0)\]

这两个先验分布一起组成了高斯-伽马共轭先验分布。(这个简单的问题是可以直接找到它的后验分布的,里只是作为简单的例子)。我们来看一下它的因式分解后的变分近似分布:

\[q(\mu, \tau) = q_\mu(\mu)q_\tau(\tau)\]

根据前面提到的表达式:

\[\ln q^*_j(Z_j) = \mathbb{E}_{i \neq j} [\ln p(X,Z)] + const\]

对于 $q_mu(\mu)$ 我们有:

\[\begin{align} \ln q^*_\mu(\mu) &= \mathbb{E}_\tau[\ln {p(\mathcal{D}| \mu, \tau)} + \ln{p(\mu|\tau)}] + const \\\\ &=-\frac{\mathbb{E[\tau]}}{2} \left\{\lambda_0(\mu-\mu_0)^2 + \sum_{n=1}^N(x_n-\mu)^2 \right\} + const \end{align}\]

注意,因为是只包含 $\mu$ 的式子,我们可以把不包含$\mu$的项,比如 $\ln p(\tau)$ 给去掉。 将上面的平方项展开,我们发现 $q_\mu(\mu)$ 的分布符合高斯分布 $\mathcal{N}(\mu |\mu_N, \lambda_N^{-1})$,其中均值和精度分别为:

\[\begin{align} \mu_N &=\frac{\lambda_0\mu_0 + N\bar{x}}{\lambda_0 + N} \\\\ \lambda_N &= (\lambda_0 + N) \mathbb{E[\tau]} \end{align}\]

同理,我们得到 $q_\tau(\tau)$ 的最优解:

\[\begin{align} \ln q^*_\tau(\tau) &= \mathbb{E_\mu} [\ln p(\mathcal{D}|\mu,\tau) + \ln p(\mu|\tau)] + \ln p(\tau) + const \\ &= (a_0-1)\ln \tau - b_0\tau + \frac{N}{2} \ln \tau -\frac{\tau}{2}\mathbb{E_\mu} \left[\sum_{n=1}^N(x_n-\mu)^2 + \lambda_0(\mu-\mu_0)^2 \right] + const \end{align}\]

因此 $q_\tau(\tau)$ 是一个伽马分布 $Gam(\tau | a_N, b_N)$, 其中:

\[\begin{align} a_N &= a_0 + \frac{N}{2} \\\\ b_N &= b_0 + \frac{1}{2} \mathbb{E_\mu}\left[\sum_{n=1}^N(x_n-\mu)^2 + \lambda_0(\mu-\mu_0)^2 \right] \end{align}\]

这里可以发现我们并没有指定$q_\mu(\mu)$ 和 $q_\tau(\tau)$ 的最优解的形式,只是恰好是得到了它们的分布分别为高斯和伽马分布 (因为共轭先验分布以及似然函数是高斯的原因,更多关于共轭先验分布详见[5],可以找到相应的常见共轭先验分布)

这样我们就得到了这两个参数的最优分布,它们是互相依赖的。一个找到最优解的方法是先随便给 $\mathbb{E[\tau]}$ 猜一个值,然后用这个值重新计算 $q_\mu(\mu)$。 给定了这个修正过的分布后,我们来计算 $\mathbb{E[\mu]}$, $\mathbb{E[\mu^2]}$,把这两个值代入 $q_\tau(\tau)$, 以此类推。

其它

通常,我们用上面的迭代方法就可以解出各个因式分解项的最优分布。对这个简单的例子,我们其实可以直接找到它的清晰解,把它们的分布代入彼此的等式。这里,我们假设 $\mu_0=a_0=b_0=\lambda_0=0$, 随着这么假设其实是不太对的,但是可以看到后验分布还是ok的。对于伽马分布,用标准结果$\mathbb{E[\tau]} = a_N/b_N$ 作为它的均值,那么结合前面的 $a_N$ 和 $b_N$ 表达式,我们得到:

$\frac{1}{\mathbb{E[\tau]}} = \mathbb{E} \left[ \frac{1}{N} \sum_{n=1}^N(x_n-\mu)^2\right] = \bar{x^2} - 2\bar{x}\mathbb{E[\mu]} + \mathbb{E[\mu^2]}$

接着用 $\mu_N$ 和 $\lambda_N$ 的表达式,我们得到了:

\[\mathbb{E[\mu]} = \bar{x}, \quad \mathbb{E}[\mu^2] = \bar{x}^2 + \frac{1}{N\mathbb{E}[\tau]}\]

将这两个项代入上面的表达式得到:

\[\begin{align} \frac{1}{\mathbb{E[\tau]}} &= \frac{N}{N-1}(\overline{x^2} - \bar{x}^2) \\ &= \frac{1}{N-1} \sum_{n=1}^N(x_n -\bar{x})^2 \end{align}\]

发现我们的精度的倒数,即方差,正好是无偏估计,从而避免了极大似然解上的偏差(bias)。

##

附录

Stochastic variational inference

实际问题中的数据量通常是很大的,但是坐标上升法通常需要遍历整个数据集,随着数据集的增加,每次迭代的代价就会增加。另一个方法是用基于梯度的方法,也就是随机变分推理(stochastic variational inference, SVI),它主要用来优化条件共轭模型的全局变分参数$\lambda$。算法流程见下图,图来自[2]。

SVI

指数分布族的小例子

略 [以后再补充]

References

[1] https://www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/variational-inference-i.pdf

[2] https://arxiv.org/abs/1601.00670

[3] Christopher M Bishop. Pattern Recognition and Machine Learning (Information Science and Statistics)[M]. Springer-Verlag New York, Inc. 2006.

[4] 李航. 统计学习方法第二版[J]. 2019.

[5] https://en.wikipedia.org/wiki/Conjugate_prior