Problem scenario


The prior distribution of the latent variable and the conditional generative distribution are assumed known.

The questions of interest in this setting are the three tasks discussed below: estimating the parameters $\theta$ (by ML or MAP), inferring the posterior $p_{\theta}(z\mid x)$, and estimating the marginal likelihood $p_{\theta}(x)$.

Preliminaries

evidence lower bound (variational lower bound)

Inference can be understood as computing the posterior distribution $P(Z\mid X)$:

$$P(Z\mid X)=\frac{P(X,Z)}{\int_z P(X,Z=z)\,dz}$$

The denominator (the normalizing constant) is hard to compute, so the exact posterior is intractable. Two families of methods are commonly used to obtain an approximate posterior:

  • Sampling methods: e.g. MCMC, which approximates the posterior by sampling from a Markov chain. It is computationally expensive, and its accuracy depends on the samples.
  • Variational methods: approximate the posterior with a simple family of distributions, turning inference into an optimization problem.

KL divergence:

$$D_{KL}(q\,\|\,p)=E_{x\sim q}\left[\log\frac{q(x)}{p(x)}\right]=\sum_{x}q(x)\log\frac{q(x)}{p(x)}$$

KL divergence measures the discrepancy between two distributions; it is non-negative, and the smaller it is, the closer the two distributions are.
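
As a quick numerical check of the definition (a minimal sketch using two made-up discrete distributions, not from the paper):

```python
import numpy as np

# Two made-up discrete distributions over the same 3-point support.
q = np.array([0.1, 0.4, 0.5])
p = np.array([0.3, 0.3, 0.4])

# D_KL(q || p) = sum_x q(x) * log(q(x) / p(x)); non-negative, zero iff q == p.
kl_qp = np.sum(q * np.log(q / p))
kl_pq = np.sum(p * np.log(p / q))
print(kl_qp, kl_pq)   # both >= 0; generally unequal, since KL is not symmetric
```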

Variational methods fit the posterior $p(z\mid x)$ with a simple distribution $q(z)$; for example, $q(z)$ can be chosen from the Gaussian family. Inference then becomes an optimization problem:

$$\lambda^*=\arg\min_{\lambda}D_{KL}(q(z)\,\|\,p(z\mid x)) \tag{1}$$

Note: $\lambda$ denotes the parameters of $q$; if $q$ comes from the Gaussian family, $\lambda$ can be the mean and standard deviation.

$$\begin{aligned} \log p(x)&=\log\frac{p(x,z)}{q(z)}-\log\frac{p(z\mid x)}{q(z)}\\ E_{q(z)}[\log p(x)]&=E_{q(z)}\Big[\log\frac{p(x,z)}{q(z)}\Big]-E_{q(z)}\Big[\log\frac{p(z\mid x)}{q(z)}\Big]\\ \log p(x)&=D_{KL}\big(q(z)\,\|\,p(z\mid x)\big)+E_{q(z)}\Big[\log\frac{p(x,z)}{q(z)}\Big] \end{aligned}$$

Here $\log p(x)$ is called the evidence, and $E_{q(z)}\big[\log\frac{p(x,z)}{q(z)}\big]$ is called the evidence lower bound (ELBO), i.e. a lower bound on the evidence.

Equation (1) is then equivalent to:

$$\lambda^*=\arg\max_{\lambda}\text{ELBO}=\arg\max_{\lambda}E_{q(z)}\Big[\log\frac{p(x,z)}{q(z)}\Big]\tag{2}$$

With (2) we can apply standard optimization techniques to obtain $\lambda^*$.

Black-box variational inference

For example, black-box variational inference differentiates the ELBO directly:

$$\begin{aligned} \nabla_{\lambda}\text{ELBO}=&\nabla_{\lambda}E_{z\sim q(z)}[\log p(x,z)-\log q(z)]\\ =&\nabla_{\lambda}\int q(z)[\log p(x,z)-\log q(z)]\,dz\\ =&\int \log p(x,z)\nabla_{\lambda}q(z)-q(z)\nabla_{\lambda}\log q(z)-\log q(z)\nabla_{\lambda}q(z)\,dz\\ &\text{substituting }\nabla_{\lambda}\log q(z)=\frac{\nabla_{\lambda}q(z)}{q(z)}\text{ into the above}\\ =&\int q(z)\log p(x,z)\nabla_{\lambda}\log q(z)-\nabla_{\lambda}q(z)-q(z)\log q(z)\nabla_{\lambda}\log q(z)\,dz\\ =&\int q(z)\log p(x,z)\nabla_{\lambda}\log q(z)-q(z)\log q(z)\nabla_{\lambda}\log q(z)\,dz \quad\Big(\text{since }\textstyle\int\nabla_{\lambda}q(z)\,dz=\nabla_{\lambda}\int q(z)\,dz=0\Big)\\ =&E_{z\sim q(z)}\big[(\log p(x,z)-\log q(z))\nabla_{\lambda}\log q(z)\big] \end{aligned}$$

That is,

$$\nabla_{\lambda}\text{ELBO}=E_{z\sim q(z)}\big[(\log p(x,z)-\log q(z))\nabla_{\lambda}\log q(z)\big]\tag{3}$$

Using a sample-based estimate, with $z_i$ drawn from $q(z)$, the expression above becomes:

$$\nabla_{\lambda}\text{ELBO}\approx\frac{1}{N}\sum_{i=1}^{N}\big(\log p(x,z_i)-\log q(z_i)\big)\nabla_{\lambda}\log q(z_i)$$
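
As an illustration of this estimator (a minimal sketch, not from the paper: it assumes a toy model $p(z)=\mathcal{N}(0,1)$, $p(x\mid z)=\mathcal{N}(z,1)$ and a Gaussian $q_{\lambda}(z)=\mathcal{N}(\mu,\sigma^{2})$ with $\lambda=(\mu,\log\sigma)$):

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal(x, mean, std):
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((x - mean) / std) ** 2

def elbo_grad_estimate(x, mu, log_sigma, num_samples=1000):
    """Score-function (black-box VI) estimate of grad_lambda ELBO, i.e. eq. (3)."""
    sigma = np.exp(log_sigma)
    z = rng.normal(mu, sigma, size=num_samples)                 # z_i ~ q_lambda(z)
    log_p_xz = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)  # log p(z) + log p(x|z)
    log_q = log_normal(z, mu, sigma)
    grad_mu = (z - mu) / sigma ** 2                             # d/d mu        of log q(z)
    grad_log_sigma = ((z - mu) / sigma) ** 2 - 1.0              # d/d log(sigma) of log q(z)
    w = log_p_xz - log_q
    return np.mean(w * grad_mu), np.mean(w * grad_log_sigma)

print(elbo_grad_estimate(x=2.0, mu=0.0, log_sigma=0.0))
```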

The EM algorithm (Expectation-Maximization): classical Bayesian inference

Consider the discrete case, where we need to estimate the parameters $\theta$ of some distribution $p$ by maximum likelihood (ML):

$$\theta^*=\arg\max_{\theta}\sum_{i=1}^{N}\log p(x_i)$$

With latent variables $z_i$ and any distribution $Q_i(z_i)$ over them, Jensen's inequality gives a lower bound on the log-likelihood:

$$\begin{aligned} l(\theta)=\sum_i\log p(x_i\mid\theta)&=\sum_i\log\sum_{z_i}p(x_i,z_i;\theta)\\ &=\sum_i\log\sum_{z_i}Q_i(z_i)\frac{p(x_i,z_i;\theta)}{Q_i(z_i)}\\ &\geq\sum_i\sum_{z_i}Q_i(z_i)\log\frac{p(x_i,z_i;\theta)}{Q_i(z_i)} \end{aligned}$$

Meanwhile, the log-likelihood can also be written as:

$$\log p(x)=D_{KL}\big(q(z)\,\|\,p(z\mid x)\big)+E_{q(z)}\Big[\log\frac{p(x,z)}{q(z)}\Big]$$

Randomly initialize $\theta$, then repeat until convergence:

$$\begin{aligned} &\text{(E-step) For each } i,\text{ set } Q_i(z_i)=p(z_i\mid x_i;\theta)\\ &\text{(M-step) Set } \theta=\arg\max_{\theta}\sum_i\sum_{z_i}Q_i(z_i)\log\frac{p(x_i,z_i;\theta)}{Q_i(z_i)} \end{aligned}$$

The E-step uses the true posterior for the current choice of $\theta$ (obtained via Bayes' rule):

$$p(z_i\mid x_i;\theta)=\frac{p(x_i\mid z_i;\theta)\,p(z_i;\theta)}{\sum_{z_i}p(x_i\mid z_i;\theta)\,p(z_i;\theta)}$$

The M-step computes the optimal $\theta$ for the current choice of $Q_i$.

The EM algorithm optimizes $\theta$ and the posterior jointly, so it naturally answers the three questions in the Problem scenario; however, the denominator in its posterior-inference step (the E-step) can be very hard to compute.
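
To make the E/M updates concrete, here is a minimal sketch of EM for a two-component 1-D Gaussian mixture with fixed unit variances; the data and initialization are made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic 1-D data drawn from two Gaussians (hypothetical example).
x = np.concatenate([rng.normal(-2.0, 1.0, 200), rng.normal(3.0, 1.0, 300)])

# Parameters theta = (mixing weights pi_k, means mu_k); variances fixed to 1 for simplicity.
pi = np.array([0.5, 0.5])
mu = np.array([-1.0, 1.0])

def gaussian(x, mean, std=1.0):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (np.sqrt(2 * np.pi) * std)

for _ in range(50):
    # E-step: Q_i(z_i = k) = p(z_i = k | x_i; theta) via Bayes' rule.
    resp = pi[None, :] * gaussian(x[:, None], mu[None, :])
    resp /= resp.sum(axis=1, keepdims=True)
    # M-step: maximize sum_i sum_k Q_i(k) log p(x_i, z_i = k; theta) over theta.
    Nk = resp.sum(axis=0)
    pi = Nk / len(x)
    mu = (resp * x[:, None]).sum(axis=0) / Nk

print(pi, mu)   # should end up roughly (0.4, 0.6) and (-2, 3)
```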

The motivation of the VAE paper is:

  1. Intractability: the case where the integral of the marginal likelihood $p_\theta(x)=\int p_\theta(z)p_\theta(x\mid z)\,dz$ is intractable (so we cannot evaluate or differentiate the marginal likelihood), where the true posterior density $p_\theta(z\mid x)=\frac{p_\theta(x\mid z)p_\theta(z)}{p_\theta(x)}$ is intractable (so the EM algorithm cannot be used), and where the required integrals for any reasonable mean-field VB algorithm are also intractable. These intractabilities are quite common and appear in cases of moderately complicated likelihood functions $p_\theta(x\mid z)$, e.g. a neural network with a nonlinear hidden layer.

  2. A large dataset, for which batch optimization is too costly.

Overall

$p_{\theta}(z)$: the true prior distribution; its parametric family is known (with $\theta$ undetermined) or the distribution itself is known.

$p_{\theta}(x\mid z)$: the probabilistic decoder (generative model); its parametric family is known.

$q_{\phi}(z\mid x)$: the probabilistic encoder (recognition model), an approximation to the **intractable true posterior** $p_{\theta}(z\mid x)$.

Three tasks:

  1. Estimate $\theta$ via ML or MAP.
  2. Infer the posterior $p_{\theta}(z\mid x)$ for a given $\theta$.
  3. Estimate the marginal likelihood $p_{\theta}(x)$.

The paper proposes a method for learning the recognition model parameters $\phi$ jointly with the generative model parameters $\theta$.


Objective function

The marginal likelihood is composed of a sum over the marginal likelihoods of individual datapoints, $\log p_{\theta}\left(x^{(1)},\cdots,x^{(N)}\right)=\sum_{i=1}^{N}\log p_{\theta}\left(x^{(i)}\right)$, which can each be rewritten as:

$$\log p_{\theta}\left(x^{(i)}\right)=D_{KL}\left(q_{\phi}\left(z\mid x^{(i)}\right)\,\|\,p_{\theta}\left(z\mid x^{(i)}\right)\right)+\mathcal{L}\left(\theta,\phi;x^{(i)}\right)\tag{4}$$

The first RHS term is the KL divergence of the approximate from the true posterior. Since the KL divergence is non-negative, the second RHS term $\mathcal{L}\left(\theta,\phi;x^{(i)}\right)$ is called the (variational) lower bound on the marginal likelihood of datapoint $i$, and can be written as:

$$\log p_{\theta}\left(x^{(i)}\right)\geq\mathcal{L}\left(\theta,\phi;x^{(i)}\right)=E_{q_{\phi}(z\mid x^{(i)})}\left[-\log q_{\phi}(z\mid x^{(i)})+\log p_{\theta}(x^{(i)},z)\right]$$

which can also be written as:

$$\mathcal{L}\left(\theta,\phi;x^{(i)}\right)=-D_{KL}\left(q_{\phi}\left(z\mid x^{(i)}\right)\,\|\,p_{\theta}(z)\right)+E_{q_{\phi}\left(z\mid x^{(i)}\right)}\left[\log p_{\theta}\left(x^{(i)}\mid z\right)\right]\tag{5}$$

Maximizing $\mathcal{L}$ pushes $D_{KL}$ to be as small as possible while making $\log p_{\theta}(x^{(i)})$ as large as possible, so the task becomes:

$$(\theta^*,\phi^*)=\arg\max_{\theta,\phi}\sum_{i\in[N]}\mathcal{L}(\theta,\phi;x^{(i)})$$

To interpret (5): the first part, the KL divergence, is a regularization term that keeps the posterior of $z$ close to its prior; the second part is a likelihood (reconstruction) term, i.e. a cross-entropy.

For example, take $p_{\theta}(x\mid z)=\frac{1}{(\sqrt{2\pi}\,\sigma_{\theta})^{n}}\exp\!\left(-\frac{\|x-\mu_{\theta}\|^{2}}{2\sigma_{\theta}^{2}}\right)$, an $n$-dimensional isotropic Gaussian, where $\mu_{\theta}=\mathbb{E}_{x\sim p_{\theta}(x\mid Z=z)}[x]$ is the conditional mean produced by the decoder and $\sigma_{\theta}^{2}=\mathrm{Var}_{x\sim p_{\theta}(x\mid Z=z)}[x]$ is the conditional variance (taken as fixed).

The second part then becomes $E_{q_{\phi}(z\mid x^{(i)})}\!\left[-n\log\!\left(\sqrt{2\pi}\,\sigma_{\theta}\right)-\frac{\|x^{(i)}-\mu_{\theta}\|^{2}}{2\sigma_{\theta}^{2}}\right]$, and estimating it with $L$ samples $z^{l}\sim q_{\phi}(z\mid x^{(i)})$:

$$\begin{aligned} E_{q_{\phi}\left(z\mid x^{(i)}\right)}\left[\log p_{\theta}\left(x^{(i)}\mid z\right)\right]\approx&\frac{1}{L}\sum_{l\in[L]}\left(-n\log\!\left(\sqrt{2\pi}\,\sigma_{\theta}^{l}\right)-\frac{\|x^{(i)}-\mu_{\theta}^{l}\|^{2}}{2(\sigma_{\theta}^{l})^{2}}\right)\\ &\mu_{\theta}^{l}=\hat{x}^{l},\ \sigma_{\theta}^{l}\text{ fixed for all }l\\ =&\,\text{const}-C\sum_{l\in[L]}\|x^{(i)}-\hat{x}^{l}\|^{2} \end{aligned}$$

so maximizing this reconstruction term amounts to minimizing the expected squared reconstruction error.
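
A quick numerical check of this reduction (a sketch with made-up vectors and a fixed $\sigma_{\theta}$, not from the paper):

```python
import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
n, sigma = 4, 1.0                  # dimensionality and fixed decoder std (assumptions)
x = rng.normal(size=n)             # a data point x^(i)
x_hat = rng.normal(size=n)         # decoder mean mu_theta(z), i.e. the reconstruction

# log N(x; x_hat, sigma^2 I) computed by a library ...
log_p = multivariate_normal.logpdf(x, mean=x_hat, cov=sigma ** 2 * np.eye(n))

# ... matches "constant minus a scaled squared reconstruction error".
const = -n * np.log(np.sqrt(2 * np.pi) * sigma)
assert np.isclose(log_p, const - np.sum((x - x_hat) ** 2) / (2 * sigma ** 2))
```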

The SGVB estimator

We need an estimator of $\nabla_{\theta,\phi}\mathcal{L}(\theta,\phi;x^{(i)})$. The paper argues that the usual (naive) Monte Carlo gradient estimator has too high variance: $\nabla_{\phi}\mathbb{E}_{q_{\phi}(z)}\left[f(z)\right]=\mathbb{E}_{q_{\phi}(z)}\left[f(z)\nabla_{\phi}\log q_{\phi}(z)\right]$, where $f$ is assumed not to depend on $\phi$.

generic Stochastic Gradient Variational Bayes estimator

i.e. the reparameterized Monte Carlo gradient estimator.

There are two issues concerning $q_{\phi}(z\mid x)$:

  1. For implementation: how can gradient descent be applied through the probabilistic encoder?
  2. Can we fix a distribution family for $q_{\phi}(z\mid x)$ in advance?

For issue 2, the KL divergence in (5) tells us that $q_{\phi}(z\mid x)$ should be close to $p_{\theta}(z)$, so it is natural to choose the known distribution family of $p_{\theta}$.

For a chosen approximate posterior $q_{\phi}(z\mid x)$, we can reparameterize the random variable $\widetilde{z}\sim q_{\phi}(z\mid x)$ using a differentiable transformation $g_{\phi}(\epsilon,x)$ of an (auxiliary) noise variable $\epsilon$:

$$\widetilde{z}=g_{\phi}(\epsilon,x)\quad\text{with}\quad\epsilon\sim p(\epsilon)$$

We can now form Monte Carlo estimates of expectations of some function $f(z)$ w.r.t. $q_{\phi}(z\mid x)$ as follows:

$$\mathbb{E}_{q_{\phi}(z\mid x^{(i)})}\left[f(z)\right]=\mathbb{E}_{p(\epsilon)}\left[f\left(g_{\phi}\left(\epsilon,x^{(i)}\right)\right)\right]\simeq\frac{1}{L}\sum_{l=1}^{L}f\left(g_{\phi}\left(\epsilon^{(l)},x^{(i)}\right)\right)\quad\text{where }\epsilon^{(l)}\sim p(\epsilon)$$

We apply this technique to the variational lower bound, yielding our generic Stochastic Gradient Variational Bayes (SGVB) estimator $\widetilde{\mathcal{L}}^{A}(\theta,\phi;x^{(i)})\simeq\mathcal{L}(\theta,\phi;x^{(i)})$:

$$\begin{aligned} \widetilde{\mathcal{L}}^{A}(\theta,\phi;x^{(i)})&=\frac{1}{L}\sum_{l=1}^{L}\log p_{\theta}(x^{(i)},z^{(i,l)})-\log q_{\phi}(z^{(i,l)}\mid x^{(i)})\\ \text{where}\quad z^{(i,l)}&=g_{\phi}(\epsilon^{(i,l)},x^{(i)})\quad\text{and}\quad\epsilon^{(l)}\sim p(\epsilon) \end{aligned}\tag{6}$$

Setting aside the $p_{\theta}$ part, consider whether the expression above is differentiable (i.e. whether gradient descent can be used). In other words, once $x^{(i)},\phi,\epsilon^{(l)}$ are given, $z^{(i,l)}$ is uniquely determined, so gradients can be taken:

We can make the encoder for $q_{\phi}(z\mid x)$ deterministic (e.g. an MLP that outputs the distribution parameters), and $g_{\phi}$ is also deterministic (it is a mapping), so the randomness of $z$ is transferred entirely to $\epsilon$.
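
As a sketch of this in code (assuming a diagonal Gaussian $q_{\phi}(z\mid x)=\mathcal{N}(\mu_{\phi}(x),\sigma_{\phi}^{2}(x)I)$ and a made-up one-hidden-layer MLP encoder; all layer sizes are illustrative):

```python
import torch
import torch.nn as nn

class GaussianEncoder(nn.Module):
    """q_phi(z|x): a deterministic network that outputs the parameters (mu, log sigma^2)."""
    def __init__(self, x_dim, z_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(x_dim, hidden), nn.Tanh())
        self.mu = nn.Linear(hidden, z_dim)
        self.logvar = nn.Linear(hidden, z_dim)

    def forward(self, x):
        h = self.net(x)
        return self.mu(h), self.logvar(h)

def reparameterize(mu, logvar):
    # z = g_phi(eps, x) = mu + sigma * eps with eps ~ N(0, I):
    # the randomness lives entirely in eps, so gradients flow through mu and sigma.
    eps = torch.randn_like(mu)
    return mu + torch.exp(0.5 * logvar) * eps
```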

second version of the SGVB estimator

The KL divergence $D_{KL}\left(q_{\phi}\left(z\mid x^{(i)}\right)\,\|\,p_{\theta}(z)\right)$ of eq. (5) can be integrated analytically:

Since the distribution family (or the distribution itself) of $p_{\theta}(z)$ is known and $q_{\phi}(z\mid x)$ is chosen from the same family, take for example $p_{\theta}(z)=\mathcal{N}(z;\mathbf{0},\mathbf{I})$ and $q_{\phi}(z\mid x)=\mathcal{N}(z;\mu,\sigma^{2}\mathbf{I})$, and let $J$ denote the dimensionality of $z$:

$$-D_{KL}\left(q_{\phi}(z\mid x^{(i)})\,\|\,p_{\theta}(z)\right)=\int q_{\phi}(z)\left(\log p_{\theta}(z)-\log q_{\phi}(z)\right)dz=\frac{1}{2}\sum_{j=1}^{J}\left(1+\log(\sigma_{j}^{2})-\mu_{j}^{2}-\sigma_{j}^{2}\right)$$

The above term can be regarded as a regularizer on $\phi$,

such that only the expected reconstruction error $\mathbb{E}_{q_{\phi}\left(z\mid x^{(i)}\right)}\left[\log p_{\theta}\left(x^{(i)}\mid z\right)\right]$ requires estimation by sampling.
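
In code, this analytic term depends only on the encoder outputs (a minimal sketch, assuming the diagonal Gaussian encoder above that outputs `mu` and `logvar`):

```python
import torch

def neg_kl_diag_gaussian(mu, logvar):
    # -D_KL( N(mu, diag(sigma^2)) || N(0, I) )
    # = 0.5 * sum_j (1 + log sigma_j^2 - mu_j^2 - sigma_j^2)
    return 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
```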

This yields a second version of the SGVB estimator $\widetilde{\mathcal{L}}^{B}\left(\theta,\phi;x^{(i)}\right)\simeq\mathcal{L}\left(\theta,\phi;x^{(i)}\right)$, which typically has less variance than the generic estimator:

$$\begin{aligned} \widetilde{\mathcal{L}}^{B}\left(\theta,\phi;x^{(i)}\right)&=-D_{KL}\left(q_{\phi}\left(z\mid x^{(i)}\right)\,\|\,p_{\theta}(z)\right)+\frac{1}{L}\sum_{l=1}^{L}\log p_{\theta}\left(x^{(i)}\mid z^{(i,l)}\right)\\ \text{where}\quad z^{(i,l)}&=g_{\phi}\left(\epsilon^{(i,l)},x^{(i)}\right)\quad\text{and}\quad\epsilon^{(l)}\sim p(\epsilon) \end{aligned}\tag{7}$$

Given multiple datapoints from a dataset $X$ with $N$ datapoints, we can construct an estimator of the marginal likelihood lower bound of the full dataset, based on minibatches:

$$\mathcal{L}(\theta,\phi;X)\simeq\widetilde{\mathcal{L}}^{M}\left(\theta,\phi;X^{M}\right)=\frac{N}{M}\sum_{i=1}^{M}\widetilde{\mathcal{L}}\left(\theta,\phi;x^{(i)}\right)\tag{8}$$

where the minibatch $X^{M}=\left\{x^{(i)}\right\}_{i=1}^{M}$ is a randomly drawn sample of $M$ datapoints from the full dataset $X$ with $N$ datapoints.

AEVB algorithm

Algorithm 1: Minibatch version of the Auto-Encoding VB (AEVB) algorithm. Either of the two SGVB estimators can be used. We use settings $M=100$ and $L=1$ in experiments.

  1. $\theta,\phi\leftarrow$ Initialize parameters
  2. repeat
    1. $X^{M}\leftarrow$ Random minibatch of $M$ datapoints (drawn from the full dataset)
    2. $\epsilon\leftarrow$ Random samples from the noise distribution $p(\epsilon)$
    3. $g\leftarrow\nabla_{\theta,\phi}\widetilde{\mathcal{L}}^{M}(\theta,\phi;X^{M},\epsilon)$ (gradients of minibatch estimator (8))
    4. $\theta,\phi\leftarrow$ Update parameters using gradients $g$ (e.g. SGD or Adagrad)
  3. until convergence of parameters $(\theta,\phi)$
  4. return $\theta,\phi$
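
Below is a minimal sketch of Algorithm 1 using estimator (7) with $M=100$ and $L=1$. The Bernoulli decoder, the 784-dimensional toy data, the layer sizes, and the learning rate are illustrative assumptions, not prescribed by the pseudocode above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

class VAE(nn.Module):
    def __init__(self, x_dim=784, z_dim=20, hidden=400):
        super().__init__()
        # probabilistic encoder q_phi(z|x): outputs mu and log(sigma^2)
        self.enc = nn.Sequential(nn.Linear(x_dim, hidden), nn.Tanh())
        self.enc_mu = nn.Linear(hidden, z_dim)
        self.enc_logvar = nn.Linear(hidden, z_dim)
        # probabilistic decoder p_theta(x|z): Bernoulli over pixels (illustrative choice)
        self.dec = nn.Sequential(nn.Linear(z_dim, hidden), nn.Tanh(), nn.Linear(hidden, x_dim))

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        eps = torch.randn_like(mu)                  # eps ~ p(eps), with L = 1 sample
        z = mu + torch.exp(0.5 * logvar) * eps      # z = g_phi(eps, x)
        return self.dec(z), mu, logvar

def elbo_B(x, x_logits, mu, logvar):
    # Estimator (7), summed over the minibatch:
    # -KL(q_phi(z|x) || p_theta(z)) + log p_theta(x|z)
    recon = -F.binary_cross_entropy_with_logits(x_logits, x, reduction="sum")
    neg_kl = 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + neg_kl

# Toy random binary data standing in for a real dataset (e.g. binarized MNIST).
data = (torch.rand(1000, 784) > 0.5).float()
loader = DataLoader(TensorDataset(data), batch_size=100, shuffle=True)  # minibatches X^M, M = 100

model = VAE()
opt = torch.optim.Adagrad(model.parameters(), lr=1e-2)   # SGD or Adagrad, as in the pseudocode

for epoch in range(10):                                  # "repeat until convergence", truncated here
    for (x,) in loader:
        x_logits, mu, logvar = model(x)
        # The full-dataset estimator (8) scales by N/M; since that is a constant factor,
        # it only rescales the learning rate, so we optimize the minibatch sum directly.
        loss = -elbo_B(x, x_logits, mu, logvar)          # ascend the lower bound
        opt.zero_grad()
        loss.backward()
        opt.step()
```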

Reviewing the three tasks

  1. Estimate $\theta$ via ML or MAP.
  2. Infer the posterior $p_{\theta}(z\mid x)$ for a given $\theta$.
  3. Estimate the marginal likelihood $p_{\theta}(x)$.

The encoder $q_{\phi}(z\mid x)$ is the approximation that addresses Task 2.

On Task 1:

For the ML problem (maximum likelihood parameter estimation): the left-hand side of (4) is exactly the ML log-likelihood, so maximizing the (variational) lower bound solves the ML problem.

For the MAP problem (Bayesian / posterior parameter estimation): the MAP objective is $\sum_{i=1}^{N}\log p_{\theta}(x^{(i)})+\log\delta(\theta)$, where $\delta(\theta)$ is the prior over $\theta$; adding this prior term to the objective function therefore solves the MAP problem.
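
For instance (a sketch, assuming an isotropic Gaussian prior $\delta(\theta)=\mathcal{N}(0,\tau^{2}I)$ on the generative parameters, which is one common choice rather than anything prescribed by the paper), the extra term is just a weight penalty added to the minibatch objective:

```python
def log_prior(parameters, tau=1.0):
    # log delta(theta) for an isotropic Gaussian prior N(0, tau^2 I), up to an additive
    # constant: -||theta||^2 / (2 * tau^2). Adding this to the ELBO gives a MAP objective.
    return -sum(p.pow(2).sum() for p in parameters) / (2 * tau ** 2)

# e.g., reusing the VAE sketch above:
# map_objective = elbo_B(x, x_logits, mu, logvar) + log_prior(model.dec.parameters())
```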

For Task 3, the paper's appendix estimates $p_{\theta}(x^{(i)})$ by combining an MCMC-style estimate of the latent-variable posterior with the $p_{\theta}(x^{(i)}\mid z^{l})$ learned by AEVB.