The story of tonight is that I was thinking a little bit about the KL divergence and went down a rabbit hole. We generally learn about the KL divergence as a kind of distance between two distributions. Except it isn't actually a distance, and you have to be careful about swapping the order of its arguments (that's what it felt like to me, at least). That's a little unsatisfying, especially since the KL shows up literally everywhere. Why do we choose one order and not the other? Why do we pick this particular form for the KL divergence, and are there other reasonable things we could write down?

At a high level, I'm going to discuss f-divergences, why they're useful, and how we can express them in a variational form. We'll see that Fenchel duality allows us to magically go from an elegant but not very useful expression to a variational expression that we can optimize against. This will turn out to generalize many common divergences between probability distributions.

f-divergences

An f-divergence is a generalized way to compare two distributions $P$ and $Q$. Explicitly,

$$\mathcal{D}_f(P\Vert Q) = \mathbb{E}_{Q}\left[f(r(x))\right]$$

Here, $f$ is some convex function we haven't specified yet, and $r = dP/dQ$ is the Radon-Nikodym derivative of $P$ with respect to $Q$. Intuitively, $r(x)$ gives us a way to relate the two probability distributions. Mathematically, it's a function that satisfies

$$P(E) = \int_E r\, dQ$$

for all measurable sets $E$. Such an $r$ exists when $P$ is absolutely continuous with respect to $Q$, meaning every set with $Q$-probability zero also has $P$-probability zero. If both distributions have densities (say $p$ and $q$) with respect to the Lebesgue measure, then $r(x) = p(x) / q(x)$, which is called the likelihood ratio.
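For distributions on a finite set, the Radon-Nikodym derivative is just the ratio of probability mass functions. Below is a minimal numpy sketch that assumes $q(x) > 0$ everywhere, which is what absolute continuity requires here. The helper `likelihood_ratio` is a hypothetical name. The sketch checks the defining property $P(E) = \sum_{x\in E} r(x)\,q(x)$ on every event $E$:

```python
from itertools import product

import numpy as np


def likelihood_ratio(p: np.ndarray, q: np.ndarray) -> np.ndarray:
    """Hypothetical helper: r(x) = p(x) / q(x) on a finite space (assumes q > 0)."""
    return p / q


p = np.array([0.1, 0.2, 0.3, 0.4])
q = np.array([0.25, 0.25, 0.25, 0.25])
r = likelihood_ratio(p, q)

# Check P(E) = sum_{x in E} r(x) q(x) for every event E (all 2^4 subsets).
max_gap = 0.0
for mask in product([False, True], repeat=len(p)):
    E = np.array(mask)
    max_gap = max(max_gap, abs(p[E].sum() - (r * q)[E].sum()))
print(max_gap)  # zero up to floating-point error
```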

Immediately, there are some interesting cases we should consider.

$f(u) = u\log u$. Then we have

$$\mathcal{D}_f(P\Vert Q) = \mathbb{E}_{Q}\left[r\log r\right] =\mathbb{E}_{Q}\left[(p/q)\log (p/q)\right] = \mathbb{E}_{P}\left[\log(p/q)\right]$$

where we just note that $\mathbb{E}_Q[r\,g] = \mathbb{E}_P[g]$ for any function $g$ (here $g = \log r$). In other words, we just recovered the usual forward KL divergence $\mathrm{KL}(P\Vert Q)$.
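Here is a small numerical check of this case, assuming discrete distributions with $q > 0$ everywhere. The helpers `f_divergence` and `u_log_u` are hypothetical names. The first evaluates $\sum_x q(x)\, f(p(x)/q(x))$. The second uses the convention $0 \log 0 = 0$. For comparison, the sketch also computes the KL in the other order:

```python
import numpy as np


def f_divergence(p, q, f):
    """Hypothetical helper: D_f(P || Q) = E_Q[f(p/q)] for discrete p, q (assumes q > 0)."""
    return float(np.sum(q * f(p / q)))


def u_log_u(u):
    # u log u with the convention 0 log 0 = 0.
    safe = np.where(u > 0, u, 1.0)
    return np.where(u > 0, u * np.log(safe), 0.0)


p = np.array([0.1, 0.2, 0.3, 0.4])
q = np.array([0.25, 0.25, 0.25, 0.25])

d_f = f_divergence(p, q, u_log_u)
kl_pq = float(np.sum(p * np.log(p / q)))  # forward KL, E_P[log(p/q)]
kl_qp = float(np.sum(q * np.log(q / p)))  # the other order, for comparison
print(d_f, kl_pq, kl_qp)  # the first two agree; the third differs
```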

$f(u) = |u-1|/2$. Then we have

$$\mathcal{D}_f(P\Vert Q) = \mathbb{E}_{Q}\left[\frac{|(p/q)-1|}{2}\right] =\frac{1}{2}\int|p-q|\,dx$$

which is the total variation distance.
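The same kind of check works for total variation. This is a sketch on made-up discrete distributions with strictly positive masses. Unlike the KL, swapping the arguments gives the same value here, because $\frac{1}{2}\int|p-q|\,dx$ is symmetric in $p$ and $q$:

```python
import numpy as np


def f_divergence(p, q, f):
    """Hypothetical helper: D_f(P || Q) = E_Q[f(p/q)] for discrete p, q (assumes q > 0)."""
    return float(np.sum(q * f(p / q)))


def tv_generator(u):
    # f(u) = |u - 1| / 2
    return np.abs(u - 1.0) / 2.0


p = np.array([0.1, 0.2, 0.3, 0.4])
q = np.array([0.25, 0.25, 0.25, 0.25])

d_f = f_divergence(p, q, tv_generator)
d_f_swapped = f_divergence(q, p, tv_generator)
tv = 0.5 * float(np.sum(np.abs(p - q)))
print(d_f, d_f_swapped, tv)  # all three agree
```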

A neat thing is that we can plausibly construct $f$ to emphasize properties that we care about. For instance, $u\log u$ grows faster than linearly (by a log factor), so large $u$ is heavily penalized. Since $u$ is the likelihood ratio $p/q$, this means we heavily penalize places where $p$ has much more mass than $q$. In places where $q$ has some mass but $p$ has none, however, the likelihood ratio is $0$, and there $u\log u\rightarrow 0$. This is consistent with our understanding of $\mathrm{KL}(P\Vert Q)$. If $Q$ is our model of $P$, the divergence cares a lot when $Q$ misses mass that $P$ has, but not so much when $Q$ covers extra regions that $P$ doesn't.
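To see this numerically, here is a sketch with two made-up candidate distributions $Q$ for the same $P$. One nearly misses a mode of $P$. The other covers both modes but spends mass where $P$ has none. The numbers are illustrative only. The per-point terms $q\,f(p/q)$ show where the divergence comes from:

```python
import numpy as np


def u_log_u(u):
    # u log u with the convention 0 log 0 = 0.
    safe = np.where(u > 0, u, 1.0)
    return np.where(u > 0, u * np.log(safe), 0.0)


def kl_terms(p, q):
    """Hypothetical helper: per-point terms q(x) f(p(x)/q(x)) of KL(P || Q) (assumes q > 0)."""
    return q * u_log_u(p / q)


p = np.array([0.5, 0.5, 0.0])           # P has two modes and no mass on the third point
q_miss = np.array([0.98, 0.01, 0.01])   # nearly misses P's second mode
q_extra = np.array([0.35, 0.35, 0.30])  # covers both modes, plus extra mass elsewhere

kl_miss = float(kl_terms(p, q_miss).sum())
kl_extra = float(kl_terms(p, q_extra).sum())
print(kl_terms(p, q_miss), kl_miss)    # the second point dominates
print(kl_terms(p, q_extra), kl_extra)  # the third point contributes exactly 0
```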

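More generally, beyond convexity, the main requirement on $f$ is that $f(1) = 0$. This says that wherever the likelihood ratio is $1$ (i.e., $p(x) = q(x)$), there should be no contribution to the divergence between the probability distributions. Together with convexity, this also makes every f-divergence nonnegative. By Jensen's inequality, $\mathcal{D}_f(P\Vert Q) = \mathbb{E}_Q[f(r)] \geq f(\mathbb{E}_Q[r]) = f(1) = 0$, since $\mathbb{E}_Q[r] = 1$. This gives us a nice generalization of probability divergences, since we can just provide a new $f$ that behaves nicely and we get a new divergence.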
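As a quick sanity check, the sketch below tries several standard generators, each convex with $f(1)=0$. It confirms that the resulting divergences are nonnegative on random strictly positive distributions and vanish when $P = Q$. The random distributions and the generator table are choices made for illustration. Note that $f(u) = -\log u$ gives $\mathrm{KL}(Q\Vert P)$, the other argument order, as just another f-divergence:

```python
import numpy as np

# A few standard generators, each convex with f(1) = 0.
GENERATORS = {
    "KL(P||Q)": lambda u: u * np.log(u),
    "KL(Q||P)": lambda u: -np.log(u),
    "total variation": lambda u: np.abs(u - 1.0) / 2.0,
    "squared Hellinger (without the 1/2)": lambda u: (np.sqrt(u) - 1.0) ** 2,
    "Pearson chi^2": lambda u: (u - 1.0) ** 2,
}


def f_divergence(p, q, f):
    """Hypothetical helper: D_f(P || Q) = E_Q[f(p/q)] for strictly positive discrete p, q."""
    return float(np.sum(q * f(p / q)))


rng = np.random.default_rng(0)
min_value = np.inf
max_self_value = 0.0
for _ in range(1000):
    # Random strictly positive distributions on 5 points.
    p = rng.dirichlet(np.ones(5))
    q = rng.dirichlet(np.ones(5))
    for f in GENERATORS.values():
        min_value = min(min_value, f_divergence(p, q, f))
        max_self_value = max(max_self_value, abs(f_divergence(p, p, f)))
print(min_value, max_self_value)
```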

Dual problem

What if we don't have the likelihood ratio available to us, but would still like to estimate the divergence? The key point is that we can use Fenchel duality to express a general f-divergence as a difference of expectations, which we can estimate using standard Monte Carlo methods. The main trick is to substitute

$$f(u) = \sup_t \{t u - f^*(t)\}$$

which holds for convex, lower semicontinuous $f$ (this is the Fenchel-Moreau theorem). Here,

$$f^*(t) = \sup_u \{t u - f(u)\}$$

is the convex conjugate (also called the Fenchel conjugate) of $f$.
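As a worked example, here is the conjugate of the KL generator $f(u) = u\log u$ on $u > 0$, computed by setting the derivative of the objective to zero:

$$\begin{aligned}
\frac{\partial}{\partial u}\left(tu - u\log u\right) &= t - \log u - 1 = 0 \quad\Rightarrow\quad u^\star = e^{t-1}, \\
f^*(t) &= t\,u^\star - u^\star\log u^\star = t\,e^{t-1} - (t-1)\,e^{t-1} = e^{t-1}.
\end{aligned}$$

The objective is concave in $u$ (its second derivative is $-1/u$), so the stationary point is the maximum. In the other direction, $\sup_t\{tu - e^{t-1}\}$ is attained at $t = 1 + \log u = f'(u)$ and equals $u(1 + \log u) - u = u\log u$, recovering $f$ as the duality promises.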

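Substituting $f(r) = \sup_t\{t\,r - f^*(t)\}$ into the definition of $\mathcal{D}_f$ gives

$$\begin{aligned}
\mathcal{D}_f(P\Vert Q) &= \mathbb{E}_Q\left[ \sup_t \{ t\,r - f^*(t) \}\right] \\
&= \sup_T\left\{ \mathbb{E}_Q\left[T\,r - f^*(T)\right] \right\} \\
&= \sup_T\left\{ \mathbb{E}_P[T] - \mathbb{E}_Q[f^*(T)] \right\}
\end{aligned}$$

There are some subtleties. We swapped a supremum and an expectation, which we need to be a bit careful about. We also switched from a constant $t$ to a function $T(x)$ to reflect that swap: the inner supremum is over constants $t$ for each particular $x$, so once the supremum moves outside, its argument needs to depend on $x$. Under mild regularity conditions, the swap is exact when $T$ ranges over all measurable functions, because we can choose $T(x)$ to be the pointwise maximizer (for differentiable $f$, $T^\star(x) = f'(r(x))$). Restricting $T$ to a smaller class can only give a lower bound. The main trick is between the second and third lines, where we've gotten rid of the likelihood ratio by noticing that $\mathbb{E}_Q[r\,g] = \mathbb{E}_P[g]$ for any function $g$ (here $g = T$).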
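For the KL case, the pointwise maximizer of $t\,r - f^*(t)$ is $t = f'(r) = 1 + \log r$. Plugging $T^\star(x) = 1 + \log r(x)$ into the dual should therefore recover the KL exactly, and no other $T$ should do better. Below is a small discrete check. It assumes $f^*(t) = e^{t-1}$ (the conjugate of $u\log u$) and strictly positive $p, q$. The function name `dual_objective` is hypothetical:

```python
import numpy as np


def dual_objective(T, p, q):
    """Hypothetical helper: E_P[T] - E_Q[f*(T)] with f*(t) = exp(t - 1), the conjugate of u log u."""
    return float(np.sum(p * T) - np.sum(q * np.exp(T - 1.0)))


p = np.array([0.1, 0.2, 0.3, 0.4])
q = np.array([0.25, 0.25, 0.25, 0.25])
kl = float(np.sum(p * np.log(p / q)))

# Pointwise optimum T*(x) = f'(r(x)) = 1 + log r(x).
T_star = 1.0 + np.log(p / q)
print(kl, dual_objective(T_star, p, q))  # these agree

# Random perturbations of T* never exceed the KL (the dual is a lower bound).
rng = np.random.default_rng(0)
best_perturbed = max(
    dual_objective(T_star + rng.normal(scale=0.5, size=p.shape), p, q)
    for _ in range(1000)
)
print(best_perturbed)
```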

I found this quite neat! For many simple convex functions, we can compute the conjugate by hand by taking derivatives. What we now have is an expression for any f-divergence entirely in terms of expectations. However, there is that nasty supremum out in front... and I don't think we can magic that away. In practice, though, we approximate $T$ using a rich function class (a neural network) and use standard gradient-based optimization to maximize an objective built from sample averages, i.e.,

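$$L(\theta) := \frac{1}{n}\sum_{i=1}^n T_{\theta}(x_i) - \frac{1}{m}\sum_{j=1}^m f^*(T_{\theta}(y_j)), \qquad \mathbb{E}\left[L(\theta)\right] \leq \mathcal{D}_f(P\Vert Q)$$

where $x_1,\dots,x_n \sim P$ and $y_1,\dots,y_m \sim Q$. For every fixed $\theta$, the expected objective is a lower bound on the divergence, and maximizing over $\theta$ tightens it.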
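Here is a minimal sketch of this recipe for the KL case, with two simplifications. First, a two-parameter linear critic $T_\theta(x) = \theta_0 + \theta_1 x$ stands in for the neural network. Second, the gradients are derived by hand instead of using autodiff. The linear family is chosen because, for $P = \mathcal{N}(\mu, 1)$ and $Q = \mathcal{N}(0, 1)$, the optimal critic $1 + \log r(x) = 1 + \mu x - \mu^2/2$ is itself linear. That lets us compare against the exact value $\mathrm{KL}(P\Vert Q) = \mu^2/2$. The sample sizes, step size, and iteration count are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(0)
mu = 1.0
n = m = 20_000
x = rng.normal(mu, 1.0, size=n)   # samples from P = N(mu, 1)
y = rng.normal(0.0, 1.0, size=m)  # samples from Q = N(0, 1)


def objective(theta):
    """Sample-average dual L(theta) for f(u) = u log u, where f*(t) = exp(t - 1)."""
    t_x = theta[0] + theta[1] * x
    t_y = theta[0] + theta[1] * y
    return float(t_x.mean() - np.exp(t_y - 1.0).mean())


def gradient(theta):
    """Hand-derived gradient of objective(theta) for the linear critic."""
    w = np.exp(theta[0] + theta[1] * y - 1.0)  # f*'(T(y)) = exp(T(y) - 1)
    return np.array([1.0 - w.mean(), x.mean() - (w * y).mean()])


theta = np.zeros(2)
for _ in range(3000):
    theta += 0.05 * gradient(theta)  # gradient *ascent*: we maximize L(theta)

estimate = objective(theta)
print(theta, estimate, mu**2 / 2)  # theta should be near [1 - mu^2/2, mu]
```

Because $f^*(t) = e^{t-1}$ is convex and $T_\theta$ is linear in $\theta$, this particular objective is concave in $\theta$, so plain gradient ascent is well behaved. With a neural-network critic, you would lose that guarantee.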

Why is this formulation useful?

This way of thinking generalizes a lot of standard concepts and tools in deep learning. For example, if you pick your divergence to be the Jensen-Shannon divergence (up to a scale factor and an additive constant) and use a sigmoid parameterization for $T$, you get the standard GAN discriminator objective. More broadly, I think different domains suffer from different flavors of distributional mismatch, and this framework offers an interesting way to construct objectives that might be better suited to address those mismatches. For instance, we're probably much more concerned about distributional coverage in generative biology than in, say, image generation. But also, it's just very neat to see how Fenchel duality allows us to turn an integral that includes a likelihood ratio into something that can be computed entirely from samples!
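Here is a numerical check of the GAN connection on made-up discrete distributions. It assumes the generator $f(u) = u\log u - (u+1)\log(u+1)$, whose conjugate is $f^*(t) = -\log(1 - e^t)$ for $t < 0$. Writing $T = \log D$ with a discriminator $D = \sigma(v) \in (0,1)$ turns the dual into the familiar $\mathbb{E}_P[\log D] + \mathbb{E}_Q[\log(1 - D)]$. This $f$ has $f(1) = -\log 4$ rather than $0$, which is why the divergence comes out as $2\,\mathrm{JS}(P\Vert Q) - \log 4$ instead of JS itself:

```python
import numpy as np


def gan_f(u):
    # Generator whose f-divergence equals 2 * JS(P || Q) - log 4.
    return u * np.log(u) - (u + 1.0) * np.log(u + 1.0)


def kl(p, q):
    return float(np.sum(p * np.log(p / q)))


def js(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


p = np.array([0.1, 0.2, 0.3, 0.4])
q = np.array([0.4, 0.3, 0.2, 0.1])

d_f = float(np.sum(q * gan_f(p / q)))
print(d_f, 2 * js(p, q) - np.log(4))  # these agree

# Dual form with T = log D and f*(T) = -log(1 - D): E_P[log D] + E_Q[log(1 - D)].
# The optimal discriminator is D*(x) = p(x) / (p(x) + q(x)), i.e. T* = f'(r).
d_star = p / (p + q)
gan_value = float(np.sum(p * np.log(d_star)) + np.sum(q * np.log(1.0 - d_star)))
print(gan_value)  # matches d_f
```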