Improved Variational Inference with Inverse Autoregressive Flow
July 23, 2024
Notes on: Improved Variational Inference with Inverse Autoregressive Flow
Inverse Autoregressive Flows (IAF) scale well to high-dimensional latent spaces. The method is built on Gaussian autoregressive functions.
Remark Variational Autoencoder (VAE):
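\[\log p(x) \geq \langle \log p(x,z) - \log q(z \mid x) \rangle_{q(z \mid x)} = L(x,\theta)\]
\[L(x,\theta) =\log p(x)-\mathbb D_{KL}(q(z \mid x)\Vert p(z \mid x))\]
Since the KL divergence is non-negative, the bound is tight exactly when \(q(z \mid x) = p(z \mid x)\). With multiple latent variables, the posterior can be factorized hierarchically so that later latents are conditioned on earlier ones: \(q(z_a,z_b \mid x)=q(z_a \mid x)\,q(z_b \mid z_a,x)\)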
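A quick numerical check of the bound, as a minimal sketch under an assumed toy model (not from the paper): with prior \(p(z)=\mathcal N(0,1)\) and likelihood \(p(x \mid z)=\mathcal N(z,1)\), both \(\log p(x)\) (evidence \(\mathcal N(0,2)\)) and the true posterior \(\mathcal N(x/2, 1/2)\) are available in closed form. A Monte Carlo estimate of \(L(x,\theta)\) can then be compared with \(\log p(x)-\mathbb D_{KL}(q(z \mid x)\Vert p(z \mid x))\). All helper names are illustrative.

```python
import math
import random

# Assumed toy model (illustration only):
#   p(z) = N(0, 1),  p(x|z) = N(z, 1)  =>  p(x) = N(0, 2),  p(z|x) = N(x/2, 1/2)

def log_normal(v, mean, var):
    """Log density of the univariate Gaussian N(mean, var) at v."""
    return -0.5 * (math.log(2 * math.pi * var) + (v - mean) ** 2 / var)

def elbo_monte_carlo(x, q_mean, q_var, n_samples=100_000, seed=0):
    """Estimate <log p(x,z) - log q(z|x)> with z drawn from q(z|x) = N(q_mean, q_var)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
        total += log_joint - log_normal(z, q_mean, q_var)
    return total / n_samples

def kl_gauss(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2)) for univariate Gaussians."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

x = 1.5
q_mean, q_var = 0.5, 0.8                      # a deliberately imperfect q(z|x)
log_px = log_normal(x, 0.0, 2.0)
exact_elbo = log_px - kl_gauss(q_mean, q_var, x / 2, 0.5)
mc_elbo = elbo_monte_carlo(x, q_mean, q_var)
print(f"log p(x)={log_px:.4f}  ELBO exact={exact_elbo:.4f}  ELBO MC={mc_elbo:.4f}")
```

Setting \(q\) to the exact posterior \(\mathcal N(x/2, 1/2)\) makes every Monte Carlo term equal to \(\log p(x)\), so the bound becomes tight.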
REQUIREMENTS FOR COMPUTATIONAL TRACTABILITY
- \(q(z \mid x)\) must be computationally efficient to evaluate and differentiate
- \(q(z \mid x)\) must be computationally efficient to sample from
Remark Normalizing Flows (NF):
Start with a simple, computationally efficient distribution and apply a chain of invertible, parameterized transformations \(f_t\):
\[z_0 \sim q(z_0 \mid x), \quad z_t=f_t(z_{t-1},x) \quad \forall t=1,\dots,T\]
\[\log q(z_T \mid x)=\log q(z_0 \mid x)- \sum^T_{t=1} \log \left\vert \det \frac{d z_t}{d z_{t-1}} \right\vert\]
The original normalizing flow (planar flow) uses \(f_t(z_{t-1})=z_{t-1}+u\,h(w^\top z_{t-1} +b)\).
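To illustrate the change-of-variables bookkeeping, here is a minimal pure-Python sketch of a chain of planar flows with \(h=\tanh\), assuming a standard normal base distribution and dropping the conditioning on \(x\) for brevity. The log-determinant uses the matrix determinant lemma, \(\det(I+u\psi^\top)=1+u^\top\psi\) with \(\psi=h'(w^\top z+b)\,w\). The parameter values are arbitrary illustrations, chosen with \(w^\top u \ge -1\), which keeps a tanh planar flow invertible.

```python
import math
import random

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def planar(z, u, w, b):
    """Planar flow f(z) = z + u * tanh(w^T z + b); returns (f(z), log|det df/dz|)."""
    a = dot(w, z) + b
    h = math.tanh(a)
    dh = 1.0 - h * h                      # tanh'(a)
    z_new = [zi + ui * h for zi, ui in zip(z, u)]
    # Matrix determinant lemma: det(I + u psi^T) = 1 + u^T psi, psi = tanh'(a) * w
    det = 1.0 + dh * dot(u, w)
    return z_new, math.log(abs(det))

def log_std_normal(z):
    """Log density of N(0, I) at z."""
    return -0.5 * sum(zi * zi + math.log(2 * math.pi) for zi in z)

def flow_sample(flows, dim, rng):
    """Draw z_0 ~ N(0, I), push it through the flows, and track log q(z_T)."""
    z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    log_q = log_std_normal(z)
    for u, w, b in flows:
        z, log_det = planar(z, u, w, b)
        log_q -= log_det                  # log q(z_t) = log q(z_{t-1}) - log|det|
    return z, log_q

# Hypothetical parameters (u, w, b) for T = 2 flows in 2 dimensions
rng = random.Random(0)
flows = [([0.5, 1.0], [1.2, -0.4], 0.1),
         ([-0.3, 0.8], [0.6, 0.9], -0.2)]
z_T, log_q_T = flow_sample(flows, dim=2, rng=rng)
```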
INVERSE AUTOREGRESSIVE TRANSFORMATION
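Key property: the Jacobian of the inverse autoregressive transformation is lower triangular, so its determinant is simply the product of its diagonal entries.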
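Sampling process of \(y\), with \(\epsilon \sim \mathcal N(0,I)\): \(y_1=\mu_1+\sigma_1 \epsilon_1\) and \(y_i=\mu_i(y_{1:i-1})+\sigma_i(y_{1:i-1})\,\epsilon_i\) for \(i=2,\dots,D\). Each \(y_i\) depends on all previous \(y_{1:i-1}\), so sampling takes \(D\) sequential steps.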
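The inverse transformation is \(\epsilon_i=\frac{y_i-\mu_i (y_{1:i-1})}{\sigma_i(y_{1:i-1})}\). Given \(y\), all \(\mu_i\) and \(\sigma_i\) are known, so every \(\epsilon_i\) can be computed in parallel. Because \(\partial \epsilon_i / \partial y_j = 0\) for \(j>i\) and \(\partial \epsilon_i / \partial y_i = 1/\sigma_i\), the log-determinant is \[\log \left\vert \det \frac{d \epsilon}{d y} \right\vert = \sum_{i=1}^D -\log \sigma_i(y)\]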
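The asymmetry between the two directions can be made concrete with a minimal sketch. The autoregressive functions here are hypothetical stand-ins, not networks from the paper: \(\mu_i\) is linear in \(y_{1:i-1}\) and \(\sigma_i=\exp(\tanh(\cdot))\) keeps the scale positive. Sampling needs one step per dimension, while the inverse computes every \(\epsilon_i\) from the known \(y\) independently and obtains the log-determinant as \(-\sum_i \log\sigma_i\). Python indices start at 0, so index `i` corresponds to \(i+1\) in the notes.

```python
import math
import random

def mu_sigma_i(i, y_prev, A, B):
    """Hypothetical (mu_i, sigma_i): depends only on y_prev[0:i] via lower-triangular weights."""
    m = sum(A[i][j] * y_prev[j] for j in range(i))
    s = math.exp(math.tanh(sum(B[i][j] * y_prev[j] for j in range(i))))  # sigma_i > 0
    return m, s

def sample_sequential(eps, A, B):
    """y_i = mu_i(y_{<i}) + sigma_i(y_{<i}) * eps_i: D sequential steps."""
    y = []
    for i in range(len(eps)):
        m, s = mu_sigma_i(i, y, A, B)     # needs every previously sampled y_j
        y.append(m + s * eps[i])
    return y

def inverse_parallel(y, A, B):
    """eps_i = (y_i - mu_i) / sigma_i; each term depends only on the known y."""
    stats = [mu_sigma_i(i, y, A, B) for i in range(len(y))]  # independent, parallelizable
    eps = [(yi - m) / s for yi, (m, s) in zip(y, stats)]
    log_det = -sum(math.log(s) for _, s in stats)            # log|det d eps / d y|
    return eps, log_det

# Hypothetical strictly lower-triangular weights for D = 3
A = [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [-0.3, 0.8, 0.0]]
B = [[0.0, 0.0, 0.0], [0.2, 0.0, 0.0], [0.4, -0.6, 0.0]]
rng = random.Random(0)
eps = [rng.gauss(0.0, 1.0) for _ in range(3)]
y = sample_sequential(eps, A, B)
eps_back, log_det = inverse_parallel(y, A, B)   # eps_back reproduces eps up to rounding
```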
Architecture
- The encoder first performs a deterministic bottom-up pass over the input
- Posterior samples are then drawn in a top-down pass through the latent layers
Algorithm IAF
- Input: \(x\): data point, \(\theta\): model parameters, encoder \(f_\theta(x)\), autoregressive networks \(g_\theta^t(z,h)\) for \(t=1,\dots,T\)
- Output: posterior sample \(z\) and its log-density \(l = \log q(z \mid x)\)
- \([\mu,\sigma,h] \leftarrow f_\theta(x)\)
- \(\epsilon \sim \mathcal N(0,I)\)
- \(z \leftarrow \sigma \odot \epsilon + \mu\)
- \(l \leftarrow - \sum_i \left( \log \sigma_i + \frac{1}{2} \epsilon_i^2 + \frac{1}{2} \log (2 \pi) \right)\)
- for \(t=1,\dots,T\):
  - \([m,s]\leftarrow g_\theta^t(z,h)\)
  - \(\sigma \leftarrow \operatorname{sigmoid}(s)\)
  - \(z \leftarrow \sigma \odot z + (1-\sigma) \odot m\)
  - \(l \leftarrow l - \sum_i \log \sigma_i\)
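A runnable sketch of Algorithm IAF in pure Python, under loudly stated assumptions: the encoder \(f_\theta\) and the autoregressive networks \(g^t_\theta\) are replaced by tiny hypothetical linear maps, where a strictly lower-triangular mask makes entry \(i\) of \(m\) and \(s\) depend only on \(z_{1:i-1}\) (plus the context \(h\)). \(\epsilon\) is passed in explicitly so that the map \(\epsilon\mapsto z\) is deterministic. The sketch includes the reparameterization step \(z\leftarrow\sigma\odot\epsilon+\mu\) and accumulates \(l\) as sums over dimensions; the names `encoder`, `autoregressive`, `iaf_posterior` and the parameter dictionaries are illustrative only.

```python
import math
import random

# Hypothetical stand-ins: tiny linear maps instead of real encoder/autoregressive networks.

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def encoder(x, enc):
    """Hypothetical f_theta(x) -> (mu, sigma, h) for a scalar input x."""
    mu = [wi * x for wi in enc["w_mu"]]
    sigma = [math.exp(wi * x) for wi in enc["w_logsig"]]   # exp keeps sigma > 0
    h = [wi * x for wi in enc["w_h"]]
    return mu, sigma, h

def autoregressive(z, h, layer):
    """Hypothetical g_theta^t(z, h) -> (m, s); entry i sees only z[0:i] (and h[i] for m)."""
    M, S, D = layer["M"], layer["S"], len(z)
    m = [sum(M[i][j] * z[j] for j in range(i)) + h[i] for i in range(D)]
    s = [sum(S[i][j] * z[j] for j in range(i)) + layer["s_bias"] for i in range(D)]
    return m, s

def iaf_posterior(x, eps, enc, layers):
    """Deterministic part of Algorithm IAF: maps eps to z, returns (z, l = log q(z|x))."""
    mu, sigma, h = encoder(x, enc)
    z = [s_i * e_i + m_i for s_i, e_i, m_i in zip(sigma, eps, mu)]
    l = -sum(math.log(s_i) + 0.5 * e_i ** 2 + 0.5 * math.log(2 * math.pi)
             for s_i, e_i in zip(sigma, eps))
    for layer in layers:
        m, s = autoregressive(z, h, layer)        # computed from the previous z
        gate = [sigmoid(si) for si in s]          # sigma <- sigmoid(s)
        z = [g * zi + (1.0 - g) * mi for g, zi, mi in zip(gate, z, m)]
        l -= sum(math.log(g) for g in gate)
    return z, l

# Hypothetical parameters for D = 3 and T = 2
enc = {"w_mu": [0.5, -0.2, 0.1], "w_logsig": [-0.3, 0.1, 0.2], "w_h": [0.4, 0.0, -0.5]}
layers = [
    {"M": [[0, 0, 0], [0.7, 0, 0], [-0.4, 0.3, 0]],
     "S": [[0, 0, 0], [0.2, 0, 0], [0.5, -0.1, 0]], "s_bias": 1.5},
    {"M": [[0, 0, 0], [-0.6, 0, 0], [0.2, 0.9, 0]],
     "S": [[0, 0, 0], [-0.3, 0, 0], [0.1, 0.4, 0]], "s_bias": 1.5},
]
rng = random.Random(0)
eps = [rng.gauss(0.0, 1.0) for _ in range(3)]     # epsilon ~ N(0, I)
z, l = iaf_posterior(x=1.0, eps=eps, enc=enc, layers=layers)
```

Because each step computes \(m\) and \(s\) from the previous \(z\) before overwriting it, the Jacobian of \(\epsilon\mapsto z\) stays lower triangular, and \(l\) equals \(\log\mathcal N(\epsilon;0,I)-\log\lvert\det \partial z/\partial\epsilon\rvert\), which can be verified with finite differences.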