In this post we show how this regularization term is derived from a probabilistic perspective. Enjoy!
Recall the assumptions that we used to derive least-squares regression:
In overfitting, our model fits the noise or random error instead of the relationship between the descriptors and the target value. When this happens, the model parameters or weights are usually excessively complex. Viewed as a vector in weight space, the weight vector tends to have a very high norm; in other words, the value of each element is very large.
To prevent this from happening, we shall introduce another assumption: a prior assumption that the model parameters \(w\) are distributed according to a multivariate Gaussian with mean 0 and covariance matrix \(\Sigma\). With this prior, we can keep the values of \(w\) small, or sufficiently close to 0.
We can in fact take this assumption even further by supposing that the variables of this multivariate normal are independent of each other. The covariance matrix then decomposes into an identity matrix multiplied by a scalar \(\alpha\), namely \(\Sigma = \alpha I\). We can think of \(\alpha\) as the constant that controls the width of the Gaussian curve.
We have previously shown that minimizing \(\mathcal{E}(w)\) is equivalent to maximizing the likelihood of \(w\), which is the same as maximizing the probability of generating the data given the model parameters, \(p(\mathcal{D}|w)\). By the conditional probability rule, we can express the following equalities:
\[\begin{aligned} p(w|\mathcal{D})*p(\mathcal{D}) & = p(\mathcal{D}|w)*p(w)\newline p(w|\mathcal{D}) & = \frac{\overbrace{p(\mathcal{D}|w)}^{\text{likelihood term}}*p(w)}{p(\mathcal{D})} \end{aligned}\]From the above derivation, we can see that the posterior on the left-hand side is proportional to the likelihood term multiplied by the prior. One should ask: why should we maximize the posterior instead of the likelihood?
Notice that by explicitly maximizing the posterior, we have the advantage of incorporating the prior \(p(w)\) into our estimation. This way, our prior assumption that \(w \sim \mathcal{N}\big(0, \alpha I \big)\) can be put to use. Lastly, since the denominator \(p(\mathcal{D})\) only acts as a constant in this maximization, we can simply ignore it.
Now we recast our goal from finding the \(w\) that maximizes the likelihood to finding the \(w\) that maximizes the posterior. We refer to this as the maximum a posteriori estimate of \(w\), denoted \(w_{\text{MAP}}\).
Having cleared up all this intuition, we can now start deriving our L2 loss function! See how \(w_{\text{MAP}}\) can be further decomposed as follows:
\[\begin{aligned} w_{\text{MAP}} & = \operatorname*{argmax}_{w} p(\mathcal{D}|w)*p(w) \newline & = \operatorname*{argmax}_{w} \prod_i^n p(y_{i}|x_{i},w) * p(w \mid \mu = 0, \Sigma = \alpha I) \newline & \text{by independence rule:}\newline & = \operatorname*{argmax}_{w} \prod_i^n \frac{1}{\sigma \sqrt{2\pi}} \exp \big\{ -\frac{(f_w (x_i)-y_i)^2}{2\sigma^2} \big\} * \frac{1}{(2\pi)^{\frac{d}{2}} \mid\alpha I \mid^{\frac{1}{2}}} \exp\big\{ -\frac{1}{2}(w-\mu)^T (\alpha I)^{-1} (w-\mu) \big\}\newline & \text{simplified further by the fact that } \mu = 0\newline & = \operatorname*{argmax}_{w} \prod_i^n \frac{1}{\sigma \sqrt{2\pi}} \exp \big\{ -\frac{(f_w (x_i)-y_i)^2}{2\sigma^2} \big\} * \frac{1}{(2\pi)^{\frac{d}{2}} \mid\alpha I \mid^{\frac{1}{2}}} \exp\big\{ -\frac{1}{2}w^T (\alpha I)^{-1} w \big\}\newline & \textbf{note: }d\text{ is the dimension of the weight vector } w \end{aligned}\]At this point, we can use two useful properties of the identity matrix: \(|\alpha I| = \alpha^d\) and \((\alpha I)^{-1} = \frac{1}{\alpha} I\).
As we already know, maximizing the posterior is equivalent to minimizing its negative log. Therefore, we can restate \(w_{\text{MAP}}\) as follows:
\[\begin{aligned} w_{\text{MAP}} & = \operatorname*{argmin}_{w} \sum_i^n \Big[ -\log \frac{1}{\sigma \sqrt{2\pi}} + \frac{(f_w (x_i)-y_i)^2}{2\sigma^2} \Big] - \log \frac{1}{(2\pi\alpha)^{\frac{d}{2}}} + \frac{1}{2\alpha} w^T w \newline & \text{... the two log terms are constants in } w \text{ that can be ignored:}\newline & = \operatorname*{argmin}_{w} \sum_i^n \frac{(f_w (x_i)-y_i)^2}{2\sigma^2} + \frac{1}{2\alpha} w^T w \end{aligned}\]Note: we cannot simply drop the remaining denominators \(2\sigma^2\) and \(2\alpha\), since they weight the two terms relative to each other; removing them would change the minimizer of the function.
We can additionally perform a simple mathematical manipulation by multiplying both terms by the constant \(\sigma^2\): \(\operatorname*{argmin}_{w} \sum_i^n \frac{1}{2} (f_w (x_i)-y_i)^2 + \frac{\sigma^2}{2\alpha} w^T w\).
Finally, we can introduce a constant \(\lambda\), where \(\lambda = \frac{\sigma^2}{2\alpha}\), and then voilà! We arrive at the equation we showed at the beginning of this post:
\[\begin{aligned} \mathcal{E}(w) = \frac{1}{2} \sum^n_{i=1} \Big( f_w(x_i) - y_i \Big)^2 + \lambda w^T w \end{aligned}\]We have seen how the L2 regularization term is derived, and what the magic scalar \(\lambda\) is. We saw that \(\sigma^2\) controls how fat the tail of the error distribution is, while \(\alpha\) controls the shape (or width) of the Gaussian prior distribution over the weights \(w\). Therefore, \(\lambda\) implicitly controls both quantities. Namely, as \(\lambda\) gets larger, we allow the model to tolerate more error (by having a fatter-tailed error distribution), and at the same time we narrow the distribution of the weights \(w\) closer to 0. This way, we avoid weights that are too large, and also prevent outliers from penalizing the loss function too much.
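To make this concrete, here is a minimal numerical sketch of \(\mathcal{E}(w)\) and its minimizer, using numpy and made-up toy data (the data and dimensions below are assumptions for illustration, not from the post). Setting the gradient \(\Phi^T(\Phi w - y) + 2\lambda w\) to zero gives the closed form \(w_{\text{MAP}} = (\Phi^T\Phi + 2\lambda I)^{-1}\Phi^T y\):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy synthetic data: y = w_true . phi(x) + gaussian noise (assumed values)
n, d = 50, 3
Phi = rng.normal(size=(n, d))            # feature matrix, rows are phi(x_i)
w_true = np.array([1.0, -2.0, 0.5])
y = Phi @ w_true + rng.normal(scale=0.1, size=n)

def l2_loss(w, Phi, y, lam):
    """E(w) = 1/2 * sum_i (phi(x_i).w - y_i)^2 + lam * w.w"""
    resid = Phi @ w - y
    return 0.5 * resid @ resid + lam * w @ w

def ridge_map(Phi, y, lam):
    """Minimizer of E(w): solve (Phi^T Phi + 2*lam*I) w = Phi^T y."""
    d = Phi.shape[1]
    return np.linalg.solve(Phi.T @ Phi + 2 * lam * np.eye(d), Phi.T @ y)

lam = 1.0
w_map = ridge_map(Phi, y, lam)

# The closed-form solution beats small random perturbations of itself
for _ in range(5):
    w_alt = w_map + rng.normal(scale=0.01, size=d)
    assert l2_loss(w_map, Phi, y, lam) <= l2_loss(w_alt, Phi, y, lam)

# With lam -> 0 we recover the ordinary least-squares solution
w_ls, *_ = np.linalg.lstsq(Phi, y, rcond=None)
assert np.allclose(ridge_map(Phi, y, 0.0), w_ls)
```

Note how larger \(\lambda\) shrinks the weight norm, exactly as the prior discussion suggests.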
That is all for now. I hope you enjoy it and see you in the next post.
You should be pretty familiar with the mean-squared-error function that the linear regression setup tries to minimize. Namely: \[ \begin{aligned} \mathcal{E}(w) = \frac{1}{2} \sum^n_{i=1} \big( f_w(x_i) - y_i \big)^2 \end{aligned} \]
Where \(y_i\) is the true target value and \(f_w (x_i)\) is a linear function parameterized by the weights \(w\) we estimate from the data. Namely, \(f_w (x_i) = w * \phi (x_i)\), where \(\phi (x)\) is the feature map \(\phi (x): \mathcal{R}^D \rightarrow \mathcal{R}^m\) (if you’re not familiar with feature maps, then simply regard \(\phi (x)\) as the regular \(x\)).
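As a small illustration of this setup, here is a sketch with a polynomial feature map, one common concrete choice for \(\phi\) (the post leaves \(\phi\) abstract, so the map and the weight values below are assumptions):

```python
import numpy as np

def phi(x, m=4):
    """Polynomial feature map phi: R -> R^m, e.g. [1, x, x^2, x^3] for m=4."""
    return np.array([x**k for k in range(m)])

def f(w, x):
    """f_w(x) = w . phi(x): nonlinear in x, but linear in the parameters w."""
    return w @ phi(x, m=len(w))

w = np.array([1.0, 0.0, 2.0])            # hypothetical weights
assert list(phi(2.0, 3)) == [1.0, 2.0, 4.0]
assert f(w, 3.0) == 1.0 + 0.0 * 3.0 + 2.0 * 9.0   # = 19.0
```

The key point is that \(f_w\) stays linear in \(w\) even when \(\phi\) is nonlinear in \(x\), which is what keeps the least-squares machinery applicable.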
How could one justify this particular kind of error function?
We start this argument with the assumption that our data \(\{(x_1,y_1), (x_2,y_2), \dots, (x_n,y_n)\}\) are generated by some unknown true target function \(f_{\hat{w}}\), where \(\hat{w} \in \mathcal{R}^m\), plus some random noise \(e_i\) for each data point \(x_i\). We can then formulate each target value \(y_i\) as:
\[ \begin{aligned} y_i & = f_{\hat{w}}(x_i) + e_i\newline & = \hat{w} * \phi (x_i) + e_i \end{aligned} \]
The second assumption is that each \(e_i\) is i.i.d. and distributed according to a normal/Gaussian distribution with a certain variance and centered at 0:
\[ e_i \sim \mathcal{N}\big(0, \sigma^2 \big) \]
By this assumption, each true target value \(y_i\) is also normally distributed:
\[ y_i \sim \mathcal{N}\big(f_{\hat{w}}(x_i), \sigma^2 \big) \]
Note: \(\sigma \in (0, \infty)\) and this sigma is shared among all \(i = 1, 2, \dots, n\).
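The assumptions above can be simulated directly. Here is a toy sketch (the true weights, noise scale, and identity feature map \(\phi(x) = x\) are all assumed values for illustration):

```python
import numpy as np

rng = np.random.default_rng(1)

# Assumed toy setup: true weights w_hat, identity feature map phi(x) = x
w_hat = np.array([2.0, -1.0])
sigma = 0.5                                    # shared noise scale for all i

n = 10_000
X = rng.normal(size=(n, 2))
e = rng.normal(loc=0.0, scale=sigma, size=n)   # e_i ~ N(0, sigma^2), i.i.d.
y = X @ w_hat + e                              # y_i = w_hat . phi(x_i) + e_i

# Conditional on x_i, y_i is N(w_hat . phi(x_i), sigma^2): the residuals
# around the true function should be centered at 0 with std sigma
resid = y - X @ w_hat
assert abs(resid.mean()) < 0.02
assert abs(resid.std() - sigma) < 0.02
```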
One fundamental notion in statistics and machine learning that we will introduce here is the likelihood function and the maximum likelihood estimate. The likelihood function is a function of the parameters \(w\) that measures the likelihood of the parameters given data \(D\). It turns out that the likelihood equals the probability that the data \(D\) is generated given parameters \(w\). Formally, the likelihood function \(\mathcal{L}: \mathcal{R}^m \rightarrow \mathcal{R}\) maps a parameter (vector) to a real number. In our case of linear regression, the likelihood function of \(w\) is defined as follows:
\[ \begin{aligned} \mathcal{L}(w \mid D) & = P(D \mid w)\newline & = P(y_1,y_2,\dots,y_n \mid x_1,x_2,\dots,x_n, w) \end{aligned} \]
Note: in a regression problem, the \(x\) are observed variables, and therefore we have \(x\) as a condition in addition to \(w\).
The next important notion is maximum likelihood estimation. It is simply defined as finding the set of parameters that gives the maximum likelihood given the data. We can denote the maximum likelihood estimate of \(w\) as follows:
\[\begin{aligned} \hat{w} & = \operatorname*{argmax}_{w \in \mathcal{R}^m} \mathcal{L}(w \mid D)\newline & = \operatorname*{argmax}_{w \in \mathcal{R}^m} P(D \mid w)\newline & = \operatorname*{argmax}_{w \in \mathcal{R}^m} P(y_1,y_2,\dots,y_n \mid x_1,x_2,\dots,x_n, w)\newline & = \operatorname*{argmax}_{w} \prod_{i=1}^n P(y_i \mid x_i, w) & \text{(by conditional independence of $y_i$)} \end{aligned}\]Now, recall our assumption that \(y_i \sim \mathcal{N}\big(f_{\hat{w}}(x_i), \sigma^2 \big)\). By this assumption, we can plug the Gaussian distribution formula into our maximum likelihood estimate equation. Namely:
\[\begin{aligned} \hat{w} & = \operatorname*{argmax}_{w} \prod_{i=1}^n \frac{1}{\sqrt{2 \pi \sigma^2}} e^{- \frac{( y_i - f_w (x_i))^2}{2 \sigma^2}} \end{aligned}\]Maximizing the above expression is equivalent to minimizing its negative logarithm. We do so to remove the exponential and make the equation more convenient to deal with. We should now have:
\[\begin{aligned} \hat{w} & = \operatorname*{argmin}_{w} \sum_{i=1}^n \frac{( y_i - f_w (x_i))^2}{2 \sigma^2} - \log \frac{1}{\sqrt{2 \pi \sigma^2}} \end{aligned}\]Because we are minimizing with respect to \(w\), the denominator \(2\sigma^2\) and the log term on the right are constants that we can ignore. We finally multiply the expression by the constant \(\frac{1}{2}\) to make taking its derivative more convenient.
Finally, we end up with this expression to minimize:
\[\begin{aligned} \hat{w} & = \operatorname*{argmin}_{w} \sum_{i=1}^n \frac{1}{2} (y_i - f_w (x_i))^2 \end{aligned}\]Voila! There we have the loss function that we are really familiar with. And so, the take-home lesson from this derivation is that finding the maximum likelihood estimate of \(w\) is equivalent to minimizing the sum-of-squares error.
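We can verify this equivalence numerically with a toy one-parameter model \(f_w(x) = w x\) (the data-generating values below are assumptions): over a grid of candidate \(w\), the value that maximizes the Gaussian log-likelihood is exactly the value that minimizes the sum of squared errors.

```python
import numpy as np

rng = np.random.default_rng(2)

# Toy 1-D model: f_w(x) = w * x, data generated with assumed w_hat and sigma
w_hat, sigma = 1.5, 0.3
x = rng.normal(size=200)
y = w_hat * x + rng.normal(scale=sigma, size=200)

ws = np.linspace(0.0, 3.0, 601)                  # candidate parameters

# Sum of squared errors for each candidate w
sse = np.array([np.sum((y - w * x) ** 2) for w in ws])
# Gaussian log-likelihood for each candidate w
loglik = np.array([
    np.sum(-0.5 * np.log(2 * np.pi * sigma**2) - (y - w * x) ** 2 / (2 * sigma**2))
    for w in ws
])

# Maximizing the likelihood picks the same w as minimizing the SSE
assert ws[np.argmax(loglik)] == ws[np.argmin(sse)]
# ... and that w lands near the true parameter
assert abs(ws[np.argmin(sse)] - w_hat) < 0.2
```

This works because the log-likelihood is a constant minus \(\text{SSE}/(2\sigma^2)\), a strictly decreasing function of the SSE.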
You should also notice that we used several assumptions in this derivation. In practice, these assumptions don’t always hold. For example, the assumption of a Gaussian error distribution makes the model very sensitive to outliers. In the next post, we shall show you why that is the case, and how a regularization term can be justified to improve the model’s robustness to outliers.
See you in the next post.
We have so far done quite some work in deriving variational inference. We arrived at the point where we split the large equation for \(\mathcal{L}\) into 3 parts and simplified it further. In this section, we will do even more tricks and mathematical manipulation. Hang on, we’re almost there!
The first thing we do is split the constant \(\mathcal{C}\) into two negative constants, \(-\mathcal{C}_1\) and \(-\mathcal{C}_2\). Our equation now looks like this:
\[ \begin{aligned} \mathcal{L} = \sum_{z_1} q(z_1) \Bigg[ E_{z_2,z_3} \Big[ \ln p(x_1,x_2,x_3,z_1,z_2,z_3) \Big] + \mathcal{C}_1 + \mathcal{C}_2 \Bigg] - \sum q(z_1) \ln q(z_1) \end{aligned} \]
Let’s shift our focus to this part: \(E_{z_2,z_3} \Big[ \ln p(X, Z)\Big] + \mathcal{C}_1\). We will argue that this part corresponds to the log of some function, \(\ln f(X,Z)\). The goal here is to substitute that part of the equation with \(\ln f(X,Z)\).
Follow our argument carefully: we know that \(E_{z_2,z_3} \Big[ \ln p(X, Z)\Big]\) is the expectation of the log of some function, and that \(\mathcal{C}_1\) is some constant. Clearly, the log of a function plus a constant is again the log of some function. But notice that \(f(X,Z)\) is not just a regular function, it is a probability function! That means it has the property of summing up to 1. How do we know that?
Well, we can first get rid of the log in front of \(f(X,Z)\) by taking the exponential of both sides of the equation:
\[ \begin{aligned} \ln f(X,Z) & = E_{z_2,z_3} \Big[ \ln p(X, Z)\Big] + \mathcal{C}_1\newline e^{\ln f(X,Z)} & = e^{E_{z_2,z_3} \big[ \ln p(X, Z)\big] + \mathcal{C}_1} \newline f(X,Z) & = e^{\mathcal{C}_1} * e^{E_{z_2,z_3} \big[ \ln p(X, Z)\big]} & \text{(by exponential rule)}\newline f(X,Z) & = \mathcal{K}*e^{E_{z_2,z_3} \big[ \ln p(X, Z)\big]} & (e^{\mathcal{C}_1} \text{ could be replaced by another constant $\mathcal{K}$)}\newline \end{aligned} \]
In order for \(f(X,Z)\) to be a probability function that sums to 1 (over \(z_1\)), we need to pick \(\mathcal{K}\) such that \[ \mathcal{K} = \frac{1}{\sum_{z_1} e^{E_{z_2,z_3} [\ln p(X, Z)]}} \]
This is always possible, since we can pick \(\mathcal{K}\) by arbitrarily splitting \(\mathcal{C}\) into \(\mathcal{C}_1\) and \(\mathcal{C}_2\). We have shown that \(f(X,Z)\) is indeed a probability function!
Having the property of \(\ln f(X,Z)\) assured, we can substitute it back into the \(\mathcal{L}\) equation. Namely:
\[\begin{aligned} \mathcal{L} & = \sum_{z_1} q(z_1) \Bigg[ \ln f(X,Z) + \mathcal{C}_2 \Bigg] - \sum_{z_1} q(z_1)\ln q(z_1)\newline & = \sum_{z_1} q(z_1) \Bigg[ \ln f(X,Z) \Bigg] - \sum_{z_1} q(z_1)\ln q(z_1) + \mathcal{C}_2 \sum_{z_1} q(z_1)\newline & = \sum_{z_1} q(z_1) \Bigg[ \ln f(X,Z) \Bigg] - \sum_{z_1} q(z_1)\ln q(z_1) + \mathcal{C}_2 & \text{(probability sums to 1)}\newline & = \sum_{z_1} q(z_1) \Bigg[ \ln \frac{ f(X,Z)}{q(z_1)} \Bigg] + \mathcal{C}_2 & \text{(by logarithm rule)}\newline & \text{The constant can be ignored as we are maximizing } \mathcal{L}\text{:}\newline & = \sum_{z_1} q(z_1) \Bigg[ \ln \frac{ f(X,Z)}{q(z_1)} \Bigg] \end{aligned}\]Having only this portion of the equation left to maximize, we should look back at the equation for the KL divergence. You should notice that this portion is actually the negative of the KL divergence between \(q(z_1)\) and \(f(X,Z)\). That is, \[ KL(q(z_1)\mid \mid f(X,Z)) = - \sum_{z_1} q(z_1) \Bigg[ \ln \frac{ f(X,Z)}{q(z_1)} \Bigg] \]
Again, we see that maximizing \(\mathcal{L}\) equals minimizing \(KL(q(z_1)\|f(X,Z))\). Remember that the minimum of the KL divergence is achieved when \(q(z_1) = f(X,Z)\). So our goal can be restated as finding the \(q(z_1)\) that equals \(f(X,Z)\). You can trace back through the equations above to see how \(f(X,Z)\) is computed.
This whole derivation was done with respect to \(q(z_1)\). Fortunately, we can do the exact same formulation for the other variables in \(Z\). Therefore, this whole variational inference problem can be solved by solving the equations \(q(z_i) = f(X,Z)\) for all \(z_i \in Z\). Namely, for this example, the equations that we need to solve are:
\[\begin{aligned} f(X,Z) = q(z_1) = \mathcal{K}_1 * e^{\sum_{z_2}\sum_{z_3} q(z_2) q(z_3) \ln p(X,Z)}\newline f(X,Z) = q(z_2) = \mathcal{K}_2 * e^{\sum_{z_1}\sum_{z_3} q(z_1) q(z_3) \ln p(X,Z)}\newline f(X,Z) = q(z_3) = \mathcal{K}_3 * e^{\sum_{z_1}\sum_{z_2} q(z_1) q(z_2) \ln p(X,Z)} \end{aligned}\]You can clearly see the pattern in the equations above! We can of course generalize this to an arbitrary number of variables in \(Z\); we would then have multiple summations over the product of all the other \(q(z_i)\).
We now know how to compute each \(q(z_i)\) individually and later use them to compute \(q(Z)\). However, notice that when we compute \(q(z_1)\), we do not yet know \(q(z_2)\) and \(q(z_3)\); the same applies for every \(q(z_i)\). We are in a ‘chicken and egg’ situation, where the variables depend on each other to be computed. On top of that, we also don’t know the value of any constant \(\mathcal{C}_i\).
One approach to address this problem is coordinate ascent inference: iteratively optimizing one variational distribution \(q(z_i)\) at a time while holding the others fixed. We can first initialize these distributions randomly, then run this iterative algorithm until it converges.
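To see the iteration pattern, here is a toy sketch of these coordinate updates on a small made-up discrete joint with two binary latent variables (the table values are arbitrary assumptions, and the normalization step plays the role of \(\mathcal{K}_i\)):

```python
import numpy as np

rng = np.random.default_rng(3)

# Assumed toy table: joint[z1, z2] = p(x_obs, z1, z2) with x already observed
joint = np.array([[0.10, 0.25],
                  [0.30, 0.05]])

def elbo(q1, q2):
    """Lower bound L for the mean-field q(z1, z2) = q(z1) q(z2)."""
    q = np.outer(q1, q2)
    return np.sum(q * (np.log(joint) - np.log(q)))

# Random initialization of the factors
q1 = rng.dirichlet(np.ones(2))
q2 = rng.dirichlet(np.ones(2))

history = [elbo(q1, q2)]
for _ in range(50):
    # q1(z1) ∝ exp( sum_z2 q2(z2) ln p(x, z1, z2) ), then normalize
    q1 = np.exp(np.log(joint) @ q2)
    q1 /= q1.sum()
    # q2(z2) ∝ exp( sum_z1 q1(z1) ln p(x, z1, z2) ), then normalize
    q2 = np.exp(q1 @ np.log(joint))
    q2 /= q2.sum()
    history.append(elbo(q1, q2))

# Each coordinate update never decreases the lower bound L ...
assert all(b >= a - 1e-12 for a, b in zip(history, history[1:]))
# ... and L stays below its ceiling, ln p(x_obs)
assert history[-1] <= np.log(joint.sum()) + 1e-12
```

Each update is exactly the fixed-point equation from the previous derivation, just normalized numerically instead of solving for \(\mathcal{K}_i\) in closed form.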
Unfortunately, this series will not discuss that algorithm in depth. But hopefully, after this series of posts you’ll be better equipped when reading other resources that cover it in more depth and perhaps in a more rigorous manner. We will end this post here. See you in another post!
The fundamental points we have learned so far are: minimizing the KL divergence is equivalent to maximizing the lower bound \(\mathcal{L}\), and it is easier and more convenient to do the latter. Now we will start the discussion that is the gist of variational inference: how do we find the \(q(Z)\) that maximizes \(\mathcal{L}\)?
Let’s suppose we have 2 sets of variables, namely \(X\) and \(Z\), where \(X = \{x_1, x_2, x_3\}\) and \(Z = \{z_1, z_2, z_3\}\). For the sake of a convenient example, we pick 3 random variables for each (the number could have been chosen arbitrarily, though). We denote the joint probability of these sets as \(p(X, Z) = p(x_1,x_2,x_3,z_1,z_2,z_3)\).
Remember that our ultimate objective is to find the conditional probability \(p(Z \mid X)\), namely: \(p(z_1, z_2, z_3 \mid x_1, x_2,x_3)\). By conditional probability rule, we can express the conditional probability as the joint over the marginal:
\[ \begin{aligned} P(Z|X) = \frac{p(z_1,z_2,z_3,x_1,x_2,x_3)} {\int_{z_1}\int_{z_2}\int_{z_3}p(z_1,z_2,z_3,x_1,x_2,x_3)\,dz_1\,dz_2\,dz_3} \end{aligned} \]
As explained several times previously, finding the marginal in the denominator is really hard and sometimes impossible. So let us restate our goal: we will try to find a \(q(z_1,z_2,z_3)\) that approximates \(p(z_1, z_2, z_3\mid x_1, x_2,x_3)\), using what we already know: the joint probability \(p(z_1, z_2, z_3, x_1, x_2,x_3)\).
Putting the pieces from the previous section back together, we can plug these random variables into the lower bound formula \(\mathcal{L}\), which we are trying to maximize. Namely:
\[ \begin{aligned} \mathcal{L} = \sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1, z_2, z_3) \ln \frac{p(x_1,x_2,x_3,z_1,z_2,z_3)}{q(z_1,z_2,z_3)} \end{aligned} \]
Note: in previous parts, we used integral notation, as we assumed that the variables were continuous. For the sake of clarity, we will use summations in this section.
Finding \(q(z_1,z_2,z_3)\) is also hard. So we tackle this problem by making a fundamental assumption: the variables in \(Z\) are independent. The joint then becomes a product over the individual variables: \(q(z_1,z_2,z_3) = q(z_1) q(z_2) q(z_3)\). This assumption is the essence of mean field variational inference.
We can then plug this back into the lower bound equation and express it as: \[ \begin{aligned} \mathcal{L} = \sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \ln \frac{p(x_1,x_2,x_3,z_1,z_2,z_3)}{q(z_1) q(z_2) q(z_3)} \end{aligned} \]
At this point, we can simplify the lower bound equation with several steps of mathematical manipulation:
\[
\begin{aligned}
& \text{The first two steps would be to apply logarithm rule:} \newline
\mathcal{L}
& = \sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \big[ \ln p(x_1,x_2,x_3,z_1,z_2,z_3) - \ln q(z_1) q(z_2) q(z_3)\big] \newline
& = \sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \big[ \ln p(x_1,x_2,x_3,z_1,z_2,z_3) - \ln q(z_1) - \ln q(z_2) - \ln q(z_3)\big]
\end{aligned}
\]
One useful trick for solving this rather complicated equation is to solve for one \(q(z_i)\) at a time, instead of solving for the entire \(q(Z)\). We do this by assuming that all \(q(z_j)\) with \(j \neq i\) are known. Let’s start with \(q(z_1)\), assuming that we know \(q(z_2)\) and \(q(z_3)\). This assumption may not necessarily be true, but we’ll see how it helps us. We can distribute the summations and break this large equation into 3 parts:
\[ \begin{aligned} \textbf{Part 1} & : \sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \ln p(X,Z)\newline \textbf{Part 2} & : -\sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \ln q(z_1)\newline \textbf{Part 3} & : -\sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \big[ \ln q(z_2) + \ln q(z_3) \big] \end{aligned} \]
We go over this large equation part by part, starting from the easiest one, part 3:
\[ \begin{aligned} & = -\sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \big[ \ln q(z_2) + \ln q(z_3) \big]\newline & = -\sum_{z_1} q(z_1) \sum_{z_2} \sum_{z_3} q(z_2) q(z_3) \big[ \ln q(z_2) + \ln q(z_3) \big]\newline & \text{by our assumption, we can replace the terms that contain } q(z_2) \text{ and } q(z_3) \text{ with a constant } \mathcal{C}:\newline & = -\sum_{z_1} q(z_1) \mathcal{C} \end{aligned} \]
Next one is part 2:
\[ \begin{aligned} & \text{We can do similar manipulation by swapping the summation:}\newline & = -\sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \ln q(z_1)\newline & = -\sum_{z_1} q(z_1) \ln q(z_1) \sum_{z_2} \sum_{z_3} q(z_2) q(z_3) \newline & \text{Notice that we have summations over probability distribution that sum to 1:}\newline & = -\sum_{z_1} q(z_1)\ln q(z_1) * 1\newline & = -\sum_{z_1} q(z_1)\ln q(z_1) \end{aligned} \]
The most complicated one would be part 1:
\[ \begin{aligned} & = \sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \ln p(x_1,x_2,x_3,z_1,z_2,z_3)\newline & = \sum_{z_1} q(z_1) \sum_{z_2} \sum_{z_3} q(z_2) q(z_3) \ln p(x_1,x_2,x_3,z_1,z_2,z_3) & & (1) \end{aligned} \]
We can swap the summations again, but we still cannot remove the joint. Now let’s take a look at the definition of the expectation under a probability distribution: \[ \begin{aligned} E[x] = & \sum x p(x) \newline \text{Equivalently, for the expectation of a function}\newline \text{of a random variable, we have:}\newline E[f(x)] = & \sum f(x) p(x) \end{aligned} \]
We see that each part of equation (1) corresponds to a part of the definition of expectation, as you can see here: \[ \begin{aligned} = & \sum_{z_1} q(z_1) \underbrace{\sum_{z_2}\sum_{z_3}}_{summations} \overbrace{q(z_2) q(z_3)}^{\text{pdf p(x)}}\overbrace{\ln p(x_1,x_2,x_3,z_1,z_2,z_3)}^{\text{some function f(x)}} \end{aligned} \]
We can thus re-express equation (1) as follows: \[ \begin{aligned} = \sum_{z_1} q(z_1) E_{z_2,z_3}\big[ \ln p(x_1,x_2,x_3,z_1,z_2,z_3) \big] \end{aligned} \]
Finally, we can assemble back all of those 3 parts into a formula of \(\mathcal{L}\): \[ \begin{aligned} \mathcal{L} & = \overbrace{\sum_{z_1} q(z_1) E_{z_2,z_3}\big[ \ln p(x_1,x_2,x_3,z_1,z_2,z_3) \big]}^{part_1} \overbrace{- \sum_{z_1} q(z_1)\ln q(z_1)}^{part_2} \overbrace{-\sum_{z_1} q(z_1) \mathcal{C}}^{part_3}\newline & \text{We can take this further by moving part 3 and part 1 together:}\newline & = \sum_{z_1} q(z_1) \Bigg[ E_{z_2,z_3} \Big[ \ln p(x_1,x_2,x_3,z_1,z_2,z_3) \Big] - \mathcal{C} \Bigg] - \sum_{z_1} q(z_1)\ln q(z_1) \end{aligned} \]
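As a sanity check, this decomposition can be verified numerically on random discrete distributions (all tables below are arbitrary made-up values; numpy handles the tensor sums):

```python
import numpy as np

rng = np.random.default_rng(4)

K = 3                                               # each z_i takes 3 values (arbitrary)
p = rng.dirichlet(np.ones(K**3)).reshape(K, K, K)   # joint over (z1, z2, z3), x held fixed
q1 = rng.dirichlet(np.ones(K))
q2 = rng.dirichlet(np.ones(K))
q3 = rng.dirichlet(np.ones(K))

# Direct form: L = sum_{z1,z2,z3} q1 q2 q3 * ln( p / (q1 q2 q3) )
q = np.einsum('i,j,k->ijk', q1, q2, q3)
L_direct = np.sum(q * (np.log(p) - np.log(q)))

# Decomposed form: part 1 + part 2 + part 3, assembled as in the post
E_z2z3 = np.einsum('j,k,ijk->i', q2, q3, np.log(p))   # E_{z2,z3}[ln p], a function of z1
C = np.sum(np.outer(q2, q3) * (np.log(q2)[:, None] + np.log(q3)[None, :]))
L_parts = np.sum(q1 * (E_z2z3 - C)) - np.sum(q1 * np.log(q1))

# Both forms of L agree for any choice of q1, q2, q3 and p
assert np.isclose(L_direct, L_parts)
```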
This is still far from done, but we will stop right here for this post. Hopefully, we can finish the derivation of the variational distribution in the next post. Stay tuned!
This note begins by discussing why variational inference is useful in many applications. We assume that the reader has already covered some basics of probabilistic graphical models. First, we start the discussion by considering the Bayesian network below:
We know that the arrows in the network indicate the conditional independences between random variables. The joint probability density function (PDF) can be factorized as the product of these local conditional distributions. Namely, \[ P(x_1,x_2,\dots,x_5) = P(x_1)P(x_3|x_1)P(x_2|x_1)P(x_4|x_2,x_3)P(x_5|x_3) \] We can see that the graphical model is a very neat tool for determining the joint PDF factorization. Unfortunately, computing a conditional probability is still not so straightforward. Consider this example of finding \(P(x_3,x_4\mid x_1,x_2,x_5)\), and remember the conditional probability rule: \[ P(Z|X)=\frac{P(X,Z)}{P(X)} \]
Assuming that \(x_3\) and \(x_4\) are continuous variables, the conditional distribution can be expressed as follows: \[P(x_3,x_4|x_1,x_2,x_5) = \frac{P(x_1,x_2,x_3,x_4,x_5)} {\int_{x_3}\int_{x_4}P(x_1,x_2,x_3,x_4,x_5)\,dx_3\,dx_4} \]
We can see that the conditional probability computation gets complicated, as we have to marginalize the joint probability over the continuous variables. This computation becomes intractable when we don’t have an analytical solution for the integrals.
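For binary variables the marginalization is a sum rather than an integral, so we can sketch the whole computation explicitly. Here is a toy version of the network above with random (made-up) conditional probability tables:

```python
import itertools
import random

random.seed(0)

def norm(pair):
    """Normalize a pair of positive numbers into a distribution."""
    s = sum(pair)
    return [v / s for v in pair]

# Random CPTs for binary x1..x5, following the factorization in the post
# (the numbers themselves are arbitrary):
p1   = norm([random.random(), random.random()])                       # P(x1)
p3_1 = {a: norm([random.random(), random.random()]) for a in (0, 1)}  # P(x3|x1)
p2_1 = {a: norm([random.random(), random.random()]) for a in (0, 1)}  # P(x2|x1)
p4_23 = {(b, c): norm([random.random(), random.random()])
         for b in (0, 1) for c in (0, 1)}                             # P(x4|x2,x3)
p5_3 = {c: norm([random.random(), random.random()]) for c in (0, 1)}  # P(x5|x3)

def joint(x1, x2, x3, x4, x5):
    """P(x1..x5) = P(x1) P(x3|x1) P(x2|x1) P(x4|x2,x3) P(x5|x3)."""
    return p1[x1] * p3_1[x1][x3] * p2_1[x1][x2] * p4_23[(x2, x3)][x4] * p5_3[x3][x5]

# The factorized joint sums to 1 over all 2^5 assignments
total = sum(joint(*xs) for xs in itertools.product((0, 1), repeat=5))
assert abs(total - 1.0) < 1e-12

# P(x3, x4 | x1, x2, x5): numerator is the joint; the denominator is the
# marginal, obtained by summing (the discrete analogue of integrating) out x3, x4
x1, x2, x5 = 1, 0, 1
marg = sum(joint(x1, x2, x3, x4, x5) for x3 in (0, 1) for x4 in (0, 1))
cond = {(x3, x4): joint(x1, x2, x3, x4, x5) / marg
        for x3 in (0, 1) for x4 in (0, 1)}
assert abs(sum(cond.values()) - 1.0) < 1e-12
```

With binary variables this enumeration is trivial; the trouble starts when the sums become high-dimensional integrals with no closed form, which is exactly the motivation for variational inference.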
One might ask: would it be possible to somehow get the conditional probability just from the joint probability, completely bypassing the step that involves intractable marginalization? The answer is: yes. In fact, that is the whole point of variational inference.
As a matter of fact, variational inference is not the only way to address this inference problem. We briefly sum up some widely known solutions in this table:
| Metropolis–Hastings | Variational Inference | Laplace Approximation |
|---|---|---|
| Asymptotically exact solution via MCMC sampling | Deterministic solution | (also) Deterministic solution |
| More accurate | Fairly good approximation | Much less accurate |
| Takes longer to compute | Takes less time to compute | Fast to compute |
Recall again the conditional probability rule above, and let’s break it down a bit from the inference point of view. The conditional \(P(Z\mid X)\), which we strive to find, is surely unknown. The denominator \(P(X)\) is also unknown (since an analytical solution for the integrals is unavailable). Meanwhile, the numerator, the joint probability \(P(X,Z)\), is easily known by factorizing along the conditional independences.
Next, we define \(q(Z)\) as a probability distribution that approximates \(P(Z\mid X)\). We then use a property of the Kullback–Leibler divergence, or KL divergence for short, which measures how much one distribution diverges from another (the next section discusses how this property is derived). The property is as follows:
\[ \begin{equation} \label{eq:klproperty} \begin{aligned} \ln p(X) & = \mathcal{L}(q) + KL(q||p) & & (1) \end{aligned} \end{equation} \]
In general, variational inference strives to approximate \(p(Z\mid X)\) by finding the \(q(Z)\) that minimizes \(KL(q\|p)\). The value of the left-hand side of the equation is fixed, since the outcome of the random variable \(X\) is given. However, the terms on the right-hand side change as we adjust \(q(Z)\). Therefore, minimizing \(KL(q\|p)\) is equivalent to maximizing \(\mathcal{L}(q)\) (in the next section, we will briefly show how \(\mathcal{L}\) acts as a lower bound on \(\ln p(X)\)). We will break down the terms on the right-hand side to show why it is particularly more convenient to maximize \(\mathcal{L}\) instead of minimizing \(KL(q\|p)\) directly.
\[KL(q||p) = -\sum q(Z)\log \frac{p(Z|X)}{q(Z)} \] \[ \mathcal{L} = \sum q(Z)\log \frac{p(X,Z)}{q(Z)}\]
\(KL(q\|p)\) contains the probability distribution \(p(Z\mid X)\), which we try to avoid, while the lower bound \(\mathcal{L}(q)\) contains the known probability distribution \(p(X,Z)\). Thus, it is obvious that \(\mathcal{L}\) is preferable to deal with. In other words, variational inference finds the \(q(Z)\) that maximizes the lower bound \(\mathcal{L}(q)\). Namely, it finds \(\operatorname*{argmax}_{q(Z)} \sum q(Z)\log \frac{p(X,Z)}{q(Z)}\). Finding \(q(Z)\) itself is not trivial and will be the subject of later sections.
Before diving deeper into variational inference, it is useful to understand how we can arrive at the KL divergence property stated above. KL divergence is built on top of the entropy measure of a probability distribution (a quantity first introduced in information theory). Intuitively, entropy provides a measure of how much information we gain by knowing a certain probability distribution. This measure is given by the following equality: \[ \begin{equation} \label{eq:entropy} \begin{aligned} \mathcal{H}=-\sum_i P(x_i)\log P(x_i) \end{aligned} \end{equation} \] The equivalent measure for a continuous random variable is given by replacing the sum with an integral, namely: \[\mathcal{H}= -\int P(x)\log P(x)\,dx\]
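The discrete entropy formula is a one-liner; here is a minimal sketch (in nats, i.e. natural log):

```python
import math

def entropy(p):
    """H = -sum_i p_i * log(p_i), in nats; terms with p_i == 0 contribute 0."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

# A uniform distribution over k outcomes has entropy log(k)
assert math.isclose(entropy([0.25] * 4), math.log(4))
# A deterministic outcome carries no information
assert entropy([1.0, 0.0, 0.0]) == 0.0
# A skewed coin has lower entropy than a fair one
assert entropy([0.9, 0.1]) < entropy([0.5, 0.5])
```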
The KL divergence gives a quantity that indicates how similar two probability distributions are to each other. The smaller the value of \(KL(p\|q)\), the more similar distribution \(p\) is to distribution \(q\). KL divergence defines the difference between the distributions as the distance between their entropies, well.. almost. The most straightforward expression of this intuition that one could think of would be:
\[ \begin{aligned} KL(p||q) & = \mathcal{H}_q - \mathcal{H}_p \newline & = (-\int q(x)\log q(x)) - (-\int p(x)\log p(x)) \end{aligned} \]
However, it should be emphasized that KL divergence measures the relative distance between the two distributions, which means it is asymmetrical. In a way, the distance from \(q\) to \(p\) is not the same as the distance from \(p\) to \(q\). To incorporate this intuition, we modify the above expression as follows: \[ KL(p||q) = \big[-\int p(x)\log q(x)\big] - \big[-\int p(x)\log p(x)\big] \] Notice that measuring the KL divergence with respect to \(p\) replaces \(q(x)\) with \(p(x)\) in the first entropy term, making this function asymmetrical. This is, in fact, the actual equation of the KL divergence.
While it is easy to explain the idea of KL divergence just by looking at the above formula, people usually provide a simplified version of it. We will show how one can simplify the equation above:
\[ \begin{aligned} KL(p||q) & = \big[-\int p(x)\log q(x)\big] - \big[-\int p(x)\log p(x)\big] \newline & = \int p(x)\log p(x) - \int p(x)\log q(x) \newline & = \int p(x) \Big[ \log p(x) - \log q(x) \Big] \newline & = \int p(x) \Big[ \log \frac{p(x)}{q(x)}\Big] \newline & = -\int p(x) \Big[ \log \frac{q(x)}{p(x)}\Big] & \text{(equivalent form by logarithm rule)} \end{aligned} \]
Knowing how the KL divergence formula is derived, we can now explain the properties that it holds. The following are some useful properties of KL divergence:
This last section is the gist of this first post of the series. We will now discuss how to take all the properties above and derive equation (1). So let’s revisit the variational inference objective: finding a \(q(Z)\) that approximates \(p(Z|X)\). We now know that the KL divergence is used to measure the quality of the approximation: the smaller its value, the better the approximation. Using the KL divergence formula, we can derive the following expression:
\[ \begin{aligned} KL(q(Z)||p(Z|X)) {} & = -\int q(Z) \Big[ \log \frac{p(Z|X)}{q(Z)}\Big] \newline & = -\int q(Z) \Big[ \log \frac{p(X,Z)}{p(X)} \frac{1}{q(Z)}\Big] & \text{(by conditional prob. rule)} \newline & = -\int q(Z) \Big[ \log \frac{p(X,Z)}{q(Z)} \frac{1}{p(X)}\Big] & \text{(swapping the denominator)} \newline & = -\int q(Z) \Big[ \log \frac{p(X,Z)}{q(Z)} + \log \frac{1}{p(X)}\Big] & \text{(by logarithm rule)} \newline & = -\int q(Z) \Big[ \log \frac{p(X,Z)}{q(Z)} - \log p(X)\Big] \newline & = \Big[ -\int q(Z) \log \frac{p(X,Z)}{q(Z)}\Big] + \Big[\log p(X) \int q(Z)\Big] & \text{(by distributing the integral)} \newline & = \Big[ -\int q(Z) \log \frac{p(X,Z)}{q(Z)}\Big] + \Big[\log p(X) \Big] & \text{(probability sum to 1)} \end{aligned} \]
We finally arrive at this equality:
\[ \begin{aligned} \log p(X) & = KL(q(Z)||p(Z|X)) + \int q(Z) \log \frac{p(X,Z)}{q(Z)} \newline \log p(X) & = KL(q(Z)||p(Z|X)) + \mathcal{L} \end{aligned} \]
Where \(\mathcal{L}\) is a function defined as: \(\mathcal{L} = \int q(Z) \log \frac{p(X,Z)}{q(Z)}\).
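This identity holds for any choice of \(q\), and we can check it numerically on a tiny made-up discrete example (four latent states; all the numbers below are arbitrary):

```python
import math
import random

random.seed(5)

# Toy discrete model: table of p(X = x_obs, Z = z) over 4 latent states
# (the numbers are arbitrary, for illustration only)
p_joint = [0.05, 0.20, 0.10, 0.15]            # p(x_obs, z); need not sum to 1
p_x = sum(p_joint)                             # evidence p(x_obs)
p_post = [pj / p_x for pj in p_joint]          # posterior p(z | x_obs)

# An arbitrary normalized approximating distribution q(Z)
q = [random.random() for _ in range(4)]
s = sum(q)
q = [v / s for v in q]

L = sum(qi * math.log(pj / qi) for qi, pj in zip(q, p_joint))     # lower bound
KL = -sum(qi * math.log(pz / qi) for qi, pz in zip(q, p_post))    # KL(q || p(Z|X))

# The identity log p(X) = KL(q||p) + L holds for any choice of q
assert math.isclose(math.log(p_x), KL + L)
# KL >= 0, so L really is a lower bound on log p(X)
assert KL >= 0 and L <= math.log(p_x)
```

Note that only the joint \(p(X,Z)\) is needed to evaluate \(\mathcal{L}\); the posterior appears here purely to verify the identity.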
At this point, one should ask: how is this equality useful? Well, to begin with, we already know from the derivation in the previous section that the value of the KL divergence is always \(\geq 0\). We also know that the value of \(p(X)\) is within \([0,1]\), and therefore \(\log p(X)\) is always negative. We can then tell that \(\mathcal{L}\) is always negative.
\(\mathcal{L}\) can be interpreted as a lower bound on \(\log p(X)\) (it is often called the evidence lower bound). One important thing to notice is that the random variable \(X\) is given, and therefore the value of \(p(X)\) is fixed. Because of that, you can see that as the value of \(\mathcal{L}\) gets larger and closer to \(\log p(X)\), the value of the KL divergence grows smaller (getting closer to 0).
We can essentially control the KL divergence by perturbing \(\mathcal{L}\)! Let’s emphasize this point again: by making the lower bound \(\mathcal{L}\) less negative (larger), we are reducing the KL divergence. And reducing the KL divergence means that the quality of the estimation of \(p(Z\mid X)\) by \(q(Z)\) is getting better.
That’s it for this part! I hope at this point you have a glimpse of what variational inference does and how the KL divergence can intuitively be used to achieve the objective. In later parts we will briefly discuss the rationale of the KL divergence lower bound, and how the variational inference approach finds \(q(Z)\).