Can we do a better KL approximation?
This post discusses Monte Carlo approximations of the KL divergence, motivated by ideas from (cite) and (cite). While these works introduce the essential concepts, I found that several mathematical steps were either implicit or not fully formalized. Here, I aim to present a clearer and more rigorous derivation.
The KL divergence between distributions $q$ and $p$ is defined as:
\begin{equation} \label{eq:kl} \mathrm{KL}(q \,\|\, p) := \sum_x q(x)\, \log \left(\frac{q(x)}{p(x)}\right) \end{equation}
Note that we can view the above as the mean of the random variable $-\log (r)$, where $r:= \frac{p(x)}{q(x)}$ and $x\sim q$. So, \begin{equation} \label{eq:kl-exp} \mathrm{KL}(q \,\|\, p) = \mathbb{E}_{x \sim q} \left[ -\log (r) \right] \end{equation}
Alternatively, we can sample $x \sim p$ instead, in which case the importance weight $\frac{q(x)}{p(x)} = \frac{1}{r}$ appears and the above can be rewritten as \begin{equation} \label{eq:kl-exp-p} \mathrm{KL}(q \,\|\, p) = \mathbb{E}_{x \sim p} \left[ -\tfrac{1}{r}\log (r) \right] \end{equation}
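To make these three equivalent forms concrete, here is a minimal numpy check on a toy discrete distribution (the numbers are arbitrary, chosen only for illustration):

```python
import numpy as np

# Toy discrete distributions (arbitrary numbers, for illustration only).
q = np.array([0.5, 0.3, 0.2])
p = np.array([0.2, 0.5, 0.3])
r = p / q  # r(x) = p(x) / q(x)

kl_direct  = np.sum(q * np.log(q / p))             # the definition
kl_under_q = np.sum(q * (-np.log(r)))              # E_{x~q}[ -log r ]
kl_under_p = np.sum(p * (-(1.0 / r) * np.log(r)))  # E_{x~p}[ -(1/r) log r ]

print(kl_direct, kl_under_q, kl_under_p)           # all three agree
```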
In many machine learning settings, including reinforcement learning for language models, we have access to the log-probabilities produced by the models, but the summation over the full support of $x$ is intractable, as I will explain in the next section.
In RL-based LLM training (such as GRPO), we train a policy $\pi_\theta$ while keeping a frozen reference policy $\pi_{\text{ref}}$ (and, for the importance-sampling ratio, the old policy $\pi_{\theta_{\text{old}}}$ from the previous update). The RL objective includes the KL term
\[\mathrm{KL}\big(\pi_\theta(x \vert s) \,\|\, \pi_{\text{ref}}(x \vert s)\big),\]where $s$ is the current state of the LM (this could be a prompt or context) and $x$ is the generated next token or token sequence. Here, $x$ is sampled from a distribution over the token space. In other words, if the tokenizer has $V$ vocabulary entries, computing the KL divergence exactly requires summing over all $V$ entries, and this quickly becomes intractable since $V$ is large. To complicate things further, in RL we typically do multi-step rollouts, so the number of possible sequences grows as $V^T$ for a rollout of length $T$.
For a multi-token rollout $x = (x_1, \dots, x_T)$, the sequence probability factorizes autoregressively, $\pi(x \vert s) = \prod_{t=1}^{T} \pi(x_t \vert s, x_{<t})$, so the log-probability of a sampled sequence is just the sum of the per-token log-probabilities the model already exposes. What we cannot do is enumerate all $V^T$ possible sequences, which is why we turn to sampling.
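As a small illustration (the per-token log-probabilities below are made-up stand-ins for what the two models would actually return), the sequence-level log ratio is just a sum of per-token log ratios:

```python
import numpy as np

# Hypothetical per-token log-probabilities for one sampled rollout of T = 4 tokens
# (made-up numbers; in practice these are gathered from the models' output logits).
logp_theta = np.array([-1.9, -0.9, -1.1, -0.6])  # log pi_theta(x_t | s, x_<t)
logp_ref   = np.array([-2.1, -0.7, -1.3, -0.4])  # log pi_ref(x_t | s, x_<t)

# Autoregressive factorization: log pi(x | s) = sum_t log pi(x_t | s, x_<t),
# so the log ratio for the whole sequence is the sum of per-token log ratios.
log_r_sequence  = logp_ref.sum() - logp_theta.sum()
log_r_per_token = logp_ref - logp_theta

print(log_r_sequence, log_r_per_token.sum())  # identical
```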
If we cannot compute the expected value over the entire distribution, we do the next best thing: sample $N$ observations and compute their mean (a Monte Carlo estimate).
So the process looks like this:
1. Sample $x_1, x_2, \cdots, x_N \sim q$.
2. Compute $k_1(x_i) := -\log (r_i)$ for each sample, where $r_i := \frac{p(x_i)}{q(x_i)}$.
3. Estimate $\hat{\mathrm{KL}} = \frac{1}{N} \sum_i k_1(x_i) = -\frac{1}{N} \sum_i \log (r_i)$.
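Here is a minimal sketch of this procedure in Python. The setup (two unit-variance Gaussians) is my own toy choice, not something from the cited posts; it is convenient because the true KL is known in closed form.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup (my choice): q = N(0, 1), p = N(mu, 1), so KL(q || p) = mu^2 / 2.
mu = 0.5
true_kl = mu**2 / 2

def log_q(x):
    return -0.5 * x**2 - 0.5 * np.log(2 * np.pi)

def log_p(x):
    return -0.5 * (x - mu)**2 - 0.5 * np.log(2 * np.pi)

N = 100_000
x = rng.normal(0.0, 1.0, size=N)   # step 1: x_1, ..., x_N ~ q
log_r = log_p(x) - log_q(x)        # step 2: log r_i = log p(x_i) - log q(x_i)
k1 = -log_r                        # one k_1 value per sample

print(k1.mean(), true_kl)          # step 3: the sample mean is close to the true KL
print(k1.std())                    # ...but the per-sample spread is large
```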
Notice that this is an unbiased estimator: since each $x_i \sim q$, the expectation of the sample mean is exactly the expression in Eq. (\ref{eq:kl-exp}).
However, it has high variance. In fact, individual terms $-\log(r_i)$ can be negative even though the KL divergence is always nonnegative, which is a hint that single samples carry a lot of noise.
I am unsure about the origins of this estimator, or why one would expect it to work well a priori. Nevertheless, we can use $k_2 := \frac{1}{2}(\log (r))^2$ as an estimator. To be precise, we follow the same procedure as we did for $k_1$, except now we compute $k_2(x_i) = \frac{1}{2}(\log (r_i))^2$ for each sampled $x_i \sim q$.
Notice that this estimator is no longer unbiased, since $\mathbb{E}_{x \sim q}\left[\tfrac{1}{2}(\log r)^2\right] \neq \mathrm{KL}(q \,\|\, p)$ in general, but it has lower variance; every sample is nonnegative, and the bias turns out to be small when $p$ and $q$ are close.
We can show this empirically as well ((cite) already shows this, by the way – I recreated the results in the table below): $k_2$ has a surprisingly low bias and low variance.
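Here is a sketch of how such a comparison can be run, reusing the two-Gaussian toy setup from above (my own choice of settings, so no claim that this reproduces the cited numbers exactly):

```python
import numpy as np

rng = np.random.default_rng(0)

# Same toy setup: q = N(0, 1), p = N(mu, 1), true KL = mu^2 / 2.
mu = 0.5
true_kl = mu**2 / 2

x = rng.normal(0.0, 1.0, size=1_000_000)       # samples from q
log_r = (-0.5 * (x - mu)**2) - (-0.5 * x**2)   # log p(x) - log q(x); constants cancel

k1 = -log_r
k2 = 0.5 * log_r**2

for name, k in [("k1", k1), ("k2", k2)]:
    print(f"{name}: bias = {k.mean() - true_kl:+.4f}, std = {k.std():.4f}")
```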
We need to delve deeper into the Fisher Information Matrix
Lemma 1: $\mathbb{E}_{x \sim q}[r] = 1$.
Proof: This is easy to see. $\mathbb{E}_{x \sim q}[r] = \sum_{x} q(x)\, \frac{p(x)}{q(x)} = \sum_{x} p(x) = 1$. $\square$
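A quick numerical sanity check of the lemma, reusing the toy discrete distributions from earlier (not part of the proof, just a check):

```python
import numpy as np

rng = np.random.default_rng(0)

# Same toy discrete distributions as before.
q = np.array([0.5, 0.3, 0.2])
p = np.array([0.2, 0.5, 0.3])
r = p / q

print(np.sum(q * r))                         # exact: sum_x q(x) * p(x)/q(x) = 1
idx = rng.choice(len(q), size=100_000, p=q)  # Monte Carlo: sample indices from q
print(r[idx].mean())                         # close to 1
```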
This improves upon $k_1 = - \log (r)$ by using what is called a control variate. In simple terms, we add a variable that is zero mean and negatively correlated with $k_1$, scaled by a coefficient, in order to reduce the variance. An easy choice for this variable (as suggested by (cite)) is $r-1$; by Lemma 1, it is zero mean. Thus, for some $\lambda$, we have
\[k_3 := -\log (r) + \lambda (r-1)\]
The hard part is choosing the $\lambda$.
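As a sketch on the same toy problem, the estimator stays unbiased for any $\lambda$ (by Lemma 1), while its variance depends on the choice; the values of $\lambda$ below are only illustrative. Note also that $\lambda = 1$ gives $k_3 = (r-1) - \log(r)$, which is pointwise nonnegative since $\log(r) \le r - 1$ for all $r > 0$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Same toy setup: q = N(0, 1), p = N(mu, 1), true KL = mu^2 / 2.
mu = 0.5
true_kl = mu**2 / 2

x = rng.normal(0.0, 1.0, size=1_000_000)
log_r = mu * x - mu**2 / 2        # log p(x) - log q(x) for this toy setup
r = np.exp(log_r)

def k3(lam):
    # Control-variate estimator: k1 plus lam * (r - 1), which is zero mean by Lemma 1.
    return -log_r + lam * (r - 1.0)

for lam in [0.0, 0.5, 1.0]:       # lam = 0 recovers k1
    k = k3(lam)
    print(f"lambda = {lam}: bias = {k.mean() - true_kl:+.4f}, std = {k.std():.4f}")
```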