Approximating KL Divergence (and its applications in RL)

Can we do a better KL approximation?

The What?

This post discusses Monte Carlo approximations of the KL divergence, motivated by ideas from (cite) and (cite). While these works introduce the essential concepts, I found that several mathematical steps were either implicit or not fully formalized. Here, I aim to present a clearer and more rigorous derivation.

The KL divergence between distributions $q$ and $p$ is defined as:

\begin{equation} \label{eq:kl} \mathrm{KL}(q \,\|\, p) := \sum_x q(x)\, \log \left(\frac{q(x)}{p(x)}\right) \end{equation}

Note that we can view the above as the mean of the random variable $-\log (r)$, where $r := \frac{p(x)}{q(x)}$ and $x \sim q$. So, \begin{equation} \label{eq:kl-exp} \mathrm{KL}(q \,\|\, p) = \mathbb{E}_{x \sim q} \left[ -\log (r) \right] \end{equation}

Alternatively, we can sample $x \sim p$, in which case the same quantity can be rewritten as \begin{equation} \label{eq:kl-exp-p} \mathrm{KL}(q \,\|\, p) = \mathbb{E}_{x \sim p} \left[ -\tfrac{1}{r} \log (r) \right] \end{equation}
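As a quick sanity check, here is a minimal numerical sketch, using two arbitrarily chosen small categorical distributions (an illustrative assumption on my part, not anything from the references), showing that both expectation forms recover the exact KL divergence:

```python
import numpy as np

# Two arbitrary categorical distributions over the same 3-element support.
q = np.array([0.5, 0.3, 0.2])
p = np.array([0.2, 0.5, 0.3])

r = p / q  # the density ratio r(x) = p(x) / q(x)

kl_exact = np.sum(q * np.log(q / p))               # definition of KL(q || p)
kl_under_q = np.sum(q * (-np.log(r)))              # E_{x ~ q}[-log r]
kl_under_p = np.sum(p * (-(1.0 / r) * np.log(r)))  # E_{x ~ p}[-(1/r) log r]

print(kl_exact, kl_under_q, kl_under_p)  # all three print the same value
```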

In many machine learning settings, including reinforcement learning for language models, we have access to the log-probabilities produced by the models, but the summation over the full support of $x$ is intractable, as I will explain in the next section.

The Why?

In RL-based LLM training (such as GRPO), in addition to the policy being trained, $\pi_{\theta}$, we maintain two frozen policy models: the reference policy $\pi_{\text{ref}}$ and the old (previous-iteration) policy $\pi_{\text{old}}$. The RL objective includes a KL penalty that keeps the trained policy close to the reference,

\[\mathrm{KL}\big(\pi_{\theta}(x \vert s) \,\|\, \pi_{\text{ref}}(x \vert s)\big),\]

where $s$ is the current state of the LM (this could be a prompt or context) and $x$ is the generated next token or token sequence. $x$, here, is sampled from a distribution over the token space. In other words, if the tokenizer has $V$ vocabulary entries, computing the KL divergence exactly requires summing over all $V$ entries, and this quickly becomes intractable since $V$ is large. To complicate things further, in RL we typically do multi-step rollouts, which means the space of possible sequences grows exponentially with the rollout length: on the order of $V^T$ for $T$ generated tokens.

For a multi-token rollout $x_{1:T}$, the policy factorizes autoregressively, $\pi(x_{1:T} \vert s) = \prod_{t=1}^{T} \pi(x_t \vert s, x_{<t})$, so the log-probability ratio between two policies over a full rollout is just the sum of the per-token log-ratios. In practice, the KL term is therefore estimated per token along the sampled trajectory.

The naive first estimator for KL Divergence ($k_1$)

If we cannot compute the expected value over the entire distribution space, we do the next best thing: sample $N$ observations and average over them (a Monte Carlo estimate).

So the process looks like this:

  1. Sample $x_1, x_2, \cdots, x_N \sim q$

  2. Compute $\log (r_i)$ for each sample, where $r_i := \frac{p(x_i)}{q(x_i)}$

  3. Estimate $\widehat{\mathrm{KL}} = -\frac{1}{N} \sum_i \log (r_i)$

Notice that this is an unbiased estimator: since each $x_i \sim q$, its mean matches the form in Eq. (\ref{eq:kl-exp}).

However, it has high variance. Individual samples $-\log (r_i)$ can even be negative, despite the KL divergence itself always being non-negative.
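To make this concrete, here is a minimal sketch of the $k_1$ estimator. The two unit-variance Gaussians below are an illustrative assumption of mine (chosen so the true KL has a closed form), not something prescribed by the setup above:

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

# Illustrative assumption: q = N(0, 1), p = N(0.1, 1), so KL(q || p) = 0.1**2 / 2 = 0.005.
true_kl = 0.1**2 / 2

N = 100_000
x = rng.normal(loc=0.0, scale=1.0, size=N)                 # x_i ~ q
log_r = norm.logpdf(x, loc=0.1) - norm.logpdf(x, loc=0.0)  # log r_i = log p(x_i) - log q(x_i)

k1 = -log_r
print(f"true KL          : {true_kl:.5f}")
print(f"k1 mean          : {k1.mean():.5f}  # unbiased: close to the true KL")
print(f"k1 std           : {k1.std():.5f}  # much larger than the mean -> high variance")
print(f"negative samples : {(k1 < 0).mean():.1%}  # nearly half the samples are negative")
```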

The second slightly better estimator for KLD ($k_2$)

I am unsure of the origins of this estimator, or why one would presume it to be a good estimator a priori. Nevertheless, we can use $\frac{1}{2}(\log (r))^2$ as an estimator. To be precise, we follow the same procedure as for $k_1$, except that we now compute $\frac{1}{2}(\log (r_i))^2$ for each sampled $x_i \sim q$.

Notice that this estimator is no longer unbiased, but it has lower variance.

We can show this empirically as well ((cite) already shows this; I recreated the results in the table below): $k_2$ has surprisingly low bias and variance.
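For reference, here is the kind of comparison I ran, reusing the illustrative Gaussian pair from the $k_1$ sketch above (so the exact numbers need not match (cite)'s table):

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

true_kl = 0.1**2 / 2                                       # same illustrative Gaussians as above
x = rng.normal(0.0, 1.0, size=1_000_000)                   # x_i ~ q = N(0, 1)
log_r = norm.logpdf(x, loc=0.1) - norm.logpdf(x, loc=0.0)  # p = N(0.1, 1)

estimators = {"k1": -log_r, "k2": 0.5 * log_r**2}
for name, k in estimators.items():
    print(f"{name}: mean={k.mean():.5f}  bias={k.mean() - true_kl:+.5f}  std={k.std():.5f}")
```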

Showing $k_2$ has low variance

To show this, we need to delve deeper into the Fisher Information Matrix.

Lemma 1: $\mathbb{E}_{x \sim q}[r] = 1$

Proof: This is easy to see. $\mathbb{E}_{x \sim q}[r] = \sum_{x} q(x)\, r(x) = \sum_{x} q(x)\, \frac{p(x)}{q(x)} = \sum_{x} p(x) = 1$.
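A quick numerical check of the lemma, reusing the small categorical example from earlier (still an arbitrary illustrative choice):

```python
import numpy as np

q = np.array([0.5, 0.3, 0.2])
p = np.array([0.2, 0.5, 0.3])
r = p / q

# E_{x ~ q}[r] = sum_x q(x) * p(x) / q(x) = sum_x p(x) = 1
print(np.sum(q * r))  # 1.0
```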

The best estimator for KLD ($k_3$)

This improves upon $k_1 := -\log (r)$ by using a control variate. In simple terms, we add a term that has zero mean and is negatively correlated with $k_1$, in order to reduce the variance. An easy choice for this (as suggested by (cite)) is $r - 1$; from Lemma 1, it has zero mean. Thus, for some $\lambda$, we have

\[k_3 := -\log (r) + \lambda (r-1)\]

The hard part is choosing the $\lambda$.
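As a preview, here is a sketch of $k_3$ with $\lambda = 1$; treat that value purely as an illustrative assumption for now, since how to pick $\lambda$ is exactly the open question, and the Gaussians are again the same toy example as before:

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

true_kl = 0.1**2 / 2                                       # same illustrative Gaussians as before
x = rng.normal(0.0, 1.0, size=1_000_000)                   # x_i ~ q = N(0, 1)
log_r = norm.logpdf(x, loc=0.1) - norm.logpdf(x, loc=0.0)  # p = N(0.1, 1)
r = np.exp(log_r)

lam = 1.0                            # assumed value, for illustration only
k3 = -log_r + lam * (r - 1)          # control-variate estimator; unbiased since E_q[r - 1] = 0

print(f"true KL : {true_kl:.5f}")
print(f"k3 mean : {k3.mean():.5f}  # still unbiased")
print(f"k3 std  : {k3.std():.5f}  # noticeably lower variance than k1's")
```

Note that with $\lambda = 1$ every sample $(r_i - 1) - \log (r_i)$ is non-negative (since $\log(r) \le r - 1$), which is part of why this particular choice is popular.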