Loss functions
In this chapter, we'll leverage our newfound KL superpowers to tackle machine learning. We will understand a crucial aspect of it: setting up loss functions.
The main point here is that KL divergence and its friends, cross-entropy and entropy, serve as a guiding beacon: starting with a rough idea about which aspects of the data matter, they help us build a concrete estimation algorithm. Specifically, we can use the maximum entropy principle to transform our initial concept into a fully probabilistic model, and then apply maximum likelihood to derive the loss function that needs to be optimized.
Maximum entropy and maximum likelihood principles explain many machine-learning algorithms.
Good old machine learning
In the following widget, you can explore four problems. The first one, estimating the mean & variance of a bunch of numbers, is a classical statistics problem. The following ones - linear regression, $k$-means, logistic regression - are "good old machine learning" problems; it's the kind of stuff you see in an ML 101 class before diving into neural networks.
I want to convey how it all makes sense: starting with a simple visual feel for what we want (like the picture in the widget's canvas), maximum entropy gives us a concrete probabilistic model for what's happening. We can then minimize the cross-entropy (or, if you prefer, use the maximum likelihood approach) to find the best parameters.
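For instance, for logistic regression the recipe says: model the labels as Bernoulli variables whose probability depends on the input through a sigmoid, then minimize the cross-entropy between the data and the model. Here is a minimal sketch of that loss; the data is made up for illustration and I use SciPy's generic optimizer rather than anything clever:

```python
import numpy as np
from scipy.optimize import minimize

# Made-up 1D binary-classification data: P(y = 1 | x) = sigmoid(3x - 1).
rng = np.random.default_rng(0)
x = rng.normal(size=100)
y = (rng.uniform(size=100) < 1 / (1 + np.exp(-(3 * x - 1)))).astype(float)

def cross_entropy(params):
    """Average cross-entropy between the labels y and the model's Bernoulli predictions."""
    w, b = params
    p = 1 / (1 + np.exp(-(w * x + b)))  # predicted P(y = 1 | x)
    eps = 1e-12                         # avoid log(0)
    return -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))

result = minimize(cross_entropy, x0=np.zeros(2))
print(result.x)  # should land roughly near w = 3, b = -1, up to sampling noise
```

The generic optimizer is overkill here, but it makes the point: once the cross-entropy loss is written down, "training" is just optimization. Let's now walk through the first and simplest of the four problems in detail.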
Given a set of numbers $X_1, \dots, X_n$, how do we estimate their mean and variance? We've already approached this riddle from various angles. Now, let's combine our insights.
First, we transform the general idea that mean and variance are important into a concrete probabilistic model. The maximum entropy principle suggests modeling the data as independent samples drawn from the Gaussian distribution.
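Concretely: among all distributions on the real line with a given mean $\mu$ and variance $\sigma^2$, the Gaussian has the largest (differential) entropy, so the maximum-entropy model for a single data point reads

$$
p(x) = \mathcal{N}(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}}\, e^{-\frac{(x-\mu)^2}{2\sigma^2}}.
$$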
Once we have a set of possible models (all Gaussian distributions), we can select the best one using the maximum likelihood principle.
We want to find $\mu, \sigma$ that maximize the likelihood

$$
P(X_1, \dots, X_n \mid \mu, \sigma) = \prod_{i=1}^{n} \mathcal{N}(X_i \mid \mu, \sigma^2).
$$

It's typically easier to write down the logarithm of the likelihood function. As we discussed, we can call it cross-entropy minimization or log-likelihood maximization. In any case, the problem simplifies to this:

$$
\operatorname*{argmin}_{\mu, \sigma} \sum_{i=1}^{n} \left( \frac{(X_i - \mu)^2}{2\sigma^2} + \log \sigma + \frac{1}{2}\log(2\pi) \right)
$$

There are several ways to solve this optimization problem. Differentiation is likely the cleanest: If we define $\mathcal{L}(\mu, \sigma)$ to be the expression above, then

$$
\frac{\partial \mathcal{L}}{\partial \mu} = \sum_{i=1}^{n} \frac{\mu - X_i}{\sigma^2}.
$$

Setting $\frac{\partial \mathcal{L}}{\partial \mu} = 0$ leads to $\hat{\mu} = \frac{1}{n}\sum_{i=1}^{n} X_i$. Similarly,

$$
\frac{\partial \mathcal{L}}{\partial \sigma} = \sum_{i=1}^{n} \left( \frac{1}{\sigma} - \frac{(X_i - \mu)^2}{\sigma^3} \right).
$$

Setting $\frac{\partial \mathcal{L}}{\partial \sigma} = 0$ then leads to $\hat{\sigma}^2 = \frac{1}{n}\sum_{i=1}^{n} (X_i - \hat{\mu})^2$.
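If you'd rather let a computer double-check the calculus, here's a small sketch in NumPy/SciPy (the data and random seed are made up for illustration): it minimizes the negative log-likelihood from above numerically and compares the result with the closed-form estimates we just derived.

```python
import numpy as np
from scipy.optimize import minimize

# Made-up data for illustration: samples from N(3, 2^2).
rng = np.random.default_rng(1)
data = rng.normal(loc=3.0, scale=2.0, size=1000)

def neg_log_likelihood(params):
    """Sum over data points of -log N(x | mu, sigma^2)."""
    mu, log_sigma = params
    sigma = np.exp(log_sigma)  # optimize over log(sigma) so sigma stays positive
    return np.sum((data - mu) ** 2 / (2 * sigma ** 2) + np.log(sigma) + 0.5 * np.log(2 * np.pi))

result = minimize(neg_log_likelihood, x0=np.array([0.0, 0.0]))
mu_hat, sigma_hat = result.x[0], np.exp(result.x[1])

# Compare with the closed-form maximum-likelihood estimates.
print(mu_hat, data.mean())          # both ≈ the sample mean
print(sigma_hat ** 2, data.var())   # both ≈ (1/n) * sum((x - mean)^2); np.var uses 1/n by default
```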
What I want to emphasize is how our only initial assumption about the data was simply, "we have a bunch of numbers, and we care about their mean and variance." The KL divergence framework then reduced the rest of the problem to running the math autopilot.
Let's talk about our statistics riddle.
If all the above applications are a bit too easy, here's a more complicated one - setting up the loss of a particular neural-network architecture for working with images, the variational autoencoder.
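Just to give a flavor: the variational autoencoder loss drops out of the same playbook. It is a reconstruction (cross-entropy) term plus a KL divergence pulling the encoder's Gaussian toward a standard normal prior. Below is a minimal PyTorch-style sketch of that loss, assuming the inputs are scaled to $[0,1]$ and the encoder outputs per-dimension means and log-variances; the function and argument names are my own choices for illustration.

```python
import torch
import torch.nn.functional as F

def vae_loss(x, x_reconstructed, mu, log_var):
    """Negative ELBO for a VAE with a Gaussian encoder and standard-normal prior.

    x, x_reconstructed: batch of inputs and decoder outputs, values in [0, 1]
    mu, log_var: encoder's per-dimension Gaussian parameters for the latent z
    """
    # Reconstruction term: cross-entropy between the input and its reconstruction.
    reconstruction = F.binary_cross_entropy(x_reconstructed, x, reduction="sum")
    # KL( N(mu, exp(log_var)) || N(0, 1) ), summed over latent dimensions, in closed form.
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return reconstruction + kl
```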
What's next?
We are done with optimizing KL divergence and setting up loss functions! In the final chapter, we will see how entropy, cross-entropy, and KL divergence are connected to text encoding and what it tells us about large language models.