Multiplicative weights update

In the chapter about machine learning, we've seen how KL divergence guides us all the way from having a rough guess about our data to a well-defined optimization problem with a concrete loss function. This is where KL divergence really shines and shows us the way.

But there's one more step missing -- how do we actually solve that final optimization problem? This is where understanding KL isn't necessarily the most important thing. In a few lucky cases (like estimating mean/variance or linear regression), we have explicit formulas for the solution and can just plug numbers into those formulas.

Most machine learning problems, however, are NP-hard ($k$-means, neural net optimization) and we try to solve them using some kind of locally-improving algorithm, typically gradient descent. This chapter is about a variant of this algorithm called multiplicative weights update.

In this chapter, we'll understand how the algorithm connects to KL divergence and discuss a few applications.

💡If there is one thing you remember from this chapter...

Multiplicative weights update is a useful algorithmic trick based on updating probabilities multiplicatively rather than additively.

Basic algorithm

Let's return to one of our riddles:

Say we want to get rich by investing in the stock market. Fortunately, there are $n$ investors willing to share their advice with us: each day $t$, they give us some advice, and at the end of the day, we learn how good that advice was: for the $i$-th expert, we'll have gain $g_i^{(t)}$.

Our general investing strategy is this: We start with the uniform distribution $p_1^{(0)}, \dots, p_n^{(0)}$, $p_i^{(0)} = 1/n$, over the experts. At the beginning of each day, we sample an expert from this distribution and follow their advice. At the end of the day, we look at the gains $g_i^{(t)}$ and update $p^{(t)}$ to $p^{(t+1)}$.

The question is: how should we update? Let's discuss three possible algorithms from the riddle statement.

Gradient descent / proportional sampling

One way is using gradient descent:

$$p_i^{(t+1)} = p_i^{(t)} + \varepsilon \cdot g_i^{(t)}$$

where $\varepsilon$ is the learning rate of the algorithm. After this update, we would of course normalize the probabilities so that they sum to one.

If we define $G_i^{(t)}$ as the total accumulated gain, i.e., $G_i^{(t)} = g_i^{(1)} + \dots + g_i^{(t)}$, then we can rewrite the gradient descent update simply as

$$p_i^{(t+1)} = p_i^{(0)} + \varepsilon \cdot G_i^{(t)}.$$

Notice that once the accumulated gains dominate the initial value $p_i^{(0)} = 1/n$, the learning rate doesn't matter due to the normalization, so we end up with the proportional sampling rule from the riddle statement:

$$p_i^{(t+1)} \propto G_i^{(t)}.$$

This proportional rule has some problems. In the widget, you can notice that if the first expert has gained \$67 and the other expert gained \$33, the proportional rule follows the top expert with only 2/3 probability, although it should ideally go for her with close to 100% probability.
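
To see the numbers, here is a tiny back-of-the-envelope check in Python; the \$67/\$33 gains are the ones from the example above, and the learning rate $0.1$ in the second part is an arbitrary illustrative choice:

```python
import math

# Two experts with accumulated gains of $67 and $33 (the example above).
gains = [67, 33]

# Proportional sampling: follow expert i with probability G_i / (sum of all gains).
proportional = [g / sum(gains) for g in gains]
print(proportional)  # [0.67, 0.33] -- the better expert only gets 2/3 of the probability

# Weighting the experts exponentially instead concentrates on the leader
# (a preview of the multiplicative weights / softmax idea discussed below).
eps = 0.1  # arbitrary learning rate, just for illustration
weights = [math.exp(eps * g) for g in gains]
print([w / sum(weights) for w in weights])  # roughly [0.97, 0.03]
```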

Follow the leader

Maybe sampling the next expert is not the right approach and we should simply pick the best expert so far? This algorithm, called follow the leader, is very natural and usually works very well. There are some edge cases, though, where it behaves poorly (see the widget).

Multiplicative weights update

The multiplicative weights update algorithm is like the gradient descent rule above, but uses multiplicative instead of additive updates:

$$p_i^{(t+1)} \propto p_i^{(t)} \cdot e^{\varepsilon \cdot g_i^{(t)}}.$$

This algorithm interpolates between gradient descent (the limit for $\varepsilon \rightarrow 0$) and follow the leader (the limit for $\varepsilon \rightarrow \infty$).

If we set $\varepsilon$ in the right way, this turns out to be an amazing algorithm combining the strengths of both approaches. The remarkable property is that if you run the multiplicative weights algorithm for $t$ steps and set $\varepsilon$ to be about $1/\sqrt{t}$, the algorithm is always almost as good as the best expert! In particular, it gains at most $O(\sqrt{t})$ dollars less than the best expert in expectation. In other words, whatever sequence of expert gains we could come up with in our widget, MWU would always end up on top.
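
To make the algorithm concrete, here is a minimal sketch of the whole experts game in Python. The gain sequence is made up purely for illustration; the point is how the multiplicative update and the choice $\varepsilon \approx 1/\sqrt{t}$ fit together:

```python
import math

def multiplicative_weights(gains, eps):
    """Run MWU on a list of daily gain vectors; return our expected total gain.

    gains[t][i] is the gain of expert i on day t.
    """
    n = len(gains[0])
    p = [1.0 / n] * n                                          # start with the uniform distribution
    expected_gain = 0.0
    for g in gains:
        expected_gain += sum(pi * gi for pi, gi in zip(p, g))  # expected gain of a sampled expert
        p = [pi * math.exp(eps * gi) for pi, gi in zip(p, g)]  # multiplicative update
        total = sum(p)
        p = [pi / total for pi in p]                           # renormalize to a probability distribution
    return expected_gain

# Made-up scenario: experts 0 and 1 alternate good days, expert 2 is steadily decent.
T = 1000
gains = [[t % 2, (t + 1) % 2, 0.6] for t in range(T)]

best_expert = max(sum(day[i] for day in gains) for i in range(3))
ours = multiplicative_weights(gains, eps=1 / math.sqrt(T))
print(f"best expert gained {best_expert:.0f}, MWU gained {ours:.1f} in expectation")
```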

Intuitions

Let me give you three intuitions for why the multiplicative weights update algorithm makes sense.

Bayes updating

You can think of our problem as "trying to find the best expert". In this sense, the algorithm's probability distribution over experts is its "guess about who the best expert is". Whenever we learn how well each expert performed, this corresponds to learning "the likelihood of each expert being the best". Bayes' rule says our guess should be updated multiplicatively by multiplying by the likelihood—which is exactly what the algorithm does.

Notice that a more precise name for the MWU algorithm might be multiplicative probability update since the weights are typically thought of as probabilities. And frankly, that name sounds like a description of Bayes' rule.

Softmax intuition

Let's think about all the information we learn in the first $t$ steps of our investing game. Each expert $i$ accumulated a total gain $G_i^{(t)} = g_i^{(1)} + \dots + g_i^{(t)}$. How should this total gain translate to the probability that we sample them in the $(t+1)$-th round?

We understand this! In the chapter on max entropy distributions, we've seen how the softmax function is the right generic way of converting numbers to probabilities. So we should have $p_i^{(t+1)} \propto e^{\lambda \cdot G_i^{(t)}}$. The constant $\lambda$ is a proportionality constant we have to decide on by some other means.

But look -- this is exactly what the multiplicative weights algorithm is doing. It updates $p_i$ by multiplying it by $e^{\varepsilon \cdot g_i^{(t)}}$ in each step, so $p_i^{(t)}$ is proportional to $e^{\varepsilon \cdot G_i^{(t)}}$. So what we call $\lambda$ in softmax is called the learning rate $\varepsilon$ in multiplicative weights.
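
If you prefer to see this identity in code, here is a tiny numerical check with made-up gains; running the multiplicative updates step by step gives exactly the softmax of the accumulated gains:

```python
import math

eps = 0.3
gains = [[1.0, 0.2], [0.5, 0.9], [0.1, 0.8]]  # made-up gains of 2 experts over 3 days

# Run the multiplicative updates day by day.
p = [0.5, 0.5]
for g in gains:
    p = [pi * math.exp(eps * gi) for pi, gi in zip(p, g)]
    s = sum(p)
    p = [pi / s for pi in p]

# Softmax of the total accumulated gains, computed directly.
G = [sum(day[i] for day in gains) for i in range(2)]
w = [math.exp(eps * Gi) for Gi in G]
softmax = [wi / sum(w) for wi in w]

print(p)        # the same distribution...
print(softmax)  # ...as the softmax of the total gains (up to floating-point error)
```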

Geometric intuition

As we've seen in the preceding chapter, it's really useful to think in terms of optimization problems with loss functions. In fact, one step of gradient descent or multiplicative weights can itself be seen as solving a small optimization problem. Let's see how.

In our optimization problem, we start with this data: the current distribution $p^{(t)}$ and the loss function $-g^{(t)}$ (loss functions are to be minimized, hence the negative sign). Notice how we try to balance two things:

  1. We want the new probability distribution $p^{(t+1)}$ to be close to the original one $p^{(t)}$, whatever "close" means.

  2. If we sample a random expert from $p^{(t+1)}$, we want them to be as good as possible for the current loss function. That is, we want to minimize $\sum_{i = 1}^n p^{(t+1)}_i \cdot (-g_i^{(t)})$.

With this in mind, here's an equivalent formulation of one step of gradient descent and of multiplicative weights:

$$p^{(t+1)} = \mathrm{argmin}_{p} \; \sum_{i=1}^n p_i \cdot (-g_i^{(t)}) + \frac{1}{2\varepsilon} \left\lVert p - p^{(t)} \right\rVert_2^2 \quad \text{(gradient descent)}$$

$$p^{(t+1)} = \mathrm{argmin}_{p} \; \sum_{i=1}^n p_i \cdot (-g_i^{(t)}) + \frac{1}{\varepsilon} D(p, p^{(t)}) \quad \text{(multiplicative weights)}$$

I'll leave it to you to check that the solution of the first optimization problem is $p_i^{(t+1)} = p_i^{(t)} + \varepsilon \cdot g_i^{(t)}$, while the solution to the second optimization problem is $p_i^{(t+1)} \propto p_i^{(t)} \cdot e^{\varepsilon \cdot g_i^{(t)}}$.
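
For the curious, here is a sketch of that check for the second problem (the first one is a one-line derivative computation). Introduce a Lagrange multiplier $\mu$ for the constraint $\sum_i p_i = 1$ and set the derivative of the objective with respect to $p_i$ to zero:

$$-g_i^{(t)} + \frac{1}{\varepsilon}\left(\log \frac{p_i}{p_i^{(t)}} + 1\right) + \mu = 0,$$

so $\log \frac{p_i}{p_i^{(t)}} = \varepsilon g_i^{(t)} - 1 - \varepsilon\mu$, i.e., $p_i^{(t+1)} \propto p_i^{(t)} \cdot e^{\varepsilon g_i^{(t)}}$; the constant $e^{-1-\varepsilon\mu}$ is the same for every $i$ and gets absorbed by the normalization.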

What I want to focus on is how both algorithms try to balance the same two constraints (stay close to original solution, make new one good on current loss), and their difference is pretty much only in how they measure distance in the space of all probability distributions.

Gradient descent corresponds to assuming we should use the standard Euclidean distance (or $\ell_2$ distance) to measure distance between distributions. This is an amazing distance, and if you're optimizing in a complicated multidimensional space, it's a natural choice!

But in the context of our getting-rich riddle, we're optimizing probabilities -- after all, even gradient descent has to be altered to keep projecting its solutions to the probability simplex to make any sense. In this space, the most meaningful way of measuring distances isn't Euclidean distance, but KL divergence. Hence, multiplicative weights.

But wait! KL divergence can't measure distance. We've emphasized how it's important that $D(p,q) \not= D(q,p)$, which seems pretty bad for measuring distance! Fortunately, we're talking about a setup where both distributions $p = p^{(t)}, q = p^{(t+1)}$ are going to end up very close to each other, and in that case, $D(p,q) \approx D(q,p)$ (more about that in the Fisher information section!).
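
If you want to see this numerically, here is a quick sanity check with two made-up nearby distributions:

```python
import math

def kl(p, q):
    """D(p, q) = sum_i p_i * log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

p = [0.30, 0.50, 0.20]
q = [0.31, 0.49, 0.20]  # a distribution very close to p

print(kl(p, q), kl(q, p))  # both around 0.00026 -- tiny and nearly equal
```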

The algorithm that generalizes both gradient descent and multiplicative weights is called mirror descent. The algorithm needs to be supplied with a distance-like function. For example, KL divergence works for mirror descent, even though it's asymmetric. It's beyond our scope to analyze what properties a function needs to satisfy to work inside mirror descent; the class of functions that work is called Bregman divergences. This is why KL divergence is called a divergence -- it's a particular instance of a Bregman divergence.

Applications

Let's briefly discuss some applications of the algorithm.

Machine learning applications

Multiplicative weights is a super important algorithm in machine learning. Essentially, whenever you want to recommend the next video to a user, serve them web search results, or aggregate several algorithms together, you're solving a variant of our experts problem. In reinforcement learning, this is called the multi-armed bandit problem, and it's a bit more complicated than our setup. In our setup, at the end of the day we learn how well each expert performed. But in the bandit setup, we only learn the performance of the chosen expert. For example, if you recommend a certain video, you get feedback on whether the user liked it, but you don't get any feedback about all the other videos you considered recommending.

This lack of knowledge about other experts makes this a problem about balancing exploration and exploitation. The follow-the-leader strategy that worked pretty well above (except for certain special cases) is based on exploiting the knowledge of who the best expert is. But the strategy can't work by itself because you also have to explore to "find" the best expert. Fortunately, variants of the multiplicative weights algorithm still work pretty well for the bandit problem - they can very nicely balance the tradeoff between exploring (we have a distribution over all experts) and exploiting (good experts are going to dominate it).
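
To give a flavor of how this works, here is a sketch of one round of such a variant, in the spirit of the EXP3 algorithm; the uniform-exploration mixing and the importance-weighted gain estimate are the standard ingredients, though exact constants differ between presentations:

```python
import math
import random

def exp3_step(weights, gamma, observe_gain):
    """One round of an EXP3-style update under bandit feedback.

    weights: current (unnormalized) expert weights
    gamma:   exploration rate in (0, 1]
    observe_gain: callback returning the gain in [0, 1] of the one expert we follow
    """
    n = len(weights)
    total = sum(weights)
    # Mix the multiplicative-weights distribution with a bit of uniform exploration.
    probs = [(1 - gamma) * w / total + gamma / n for w in weights]
    i = random.choices(range(n), weights=probs)[0]
    g = observe_gain(i)                        # we only learn the gain of the chosen expert
    g_hat = g / probs[i]                       # importance-weighted estimate of expert i's gain
    weights[i] *= math.exp(gamma * g_hat / n)  # multiplicative update using the estimate
    return weights
```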

Algorithmic applications

Here's my favorite algorithmic application of multiplicative weights that also shows how multiplicative weights connect to the idea of Bayesian updating.

Consider this problem: We're given a sorted array with $n$ numbers and want to find whether it contains $42$. What's the fastest algorithm?

That's easy -- just do binary search. It finishes in $O(\log n)$ steps. The binary search algorithm always looks up the number in the middle of the current list and examines its value. Based on that, it can discard either the left or the right half of the array and continue with the other half.

Here's a more complicated problem: there are again $n$ numbers $a_1, \dots, a_n$ and we want to find whether they contain some $x$. But this time, whenever we compare $a_j$ with $x$, the comparison fails with probability $1/3$. That is, if $a_j < x$ and we compare the two, with probability $1/3$ the comparison tells us that $a_j > x$. If $a_j = x$, let's assume for simplicity that we reliably learn this.

This is called the noisy binary search problem and has many applications—sometimes, comparing two objects is more complicated than comparing two numbers and may involve doing some kind of experiment which could fail.

How do we solve this noisy problem? Here's a simple algorithm that finishes in expected $O(\log n)$ steps (assuming $x$ is in the input sequence). Every item in the array starts with weight $1/n$, and we'll keep the weights summing to $1$. In each step, we compare $x$ with the element $a_i$ that is "in the middle" with respect to the weights: i.e., we choose $i$ such that $w_1 + \dots + w_{i-1} < 1/2$ and $w_{i+1} + \dots + w_n < 1/2$, where $w_j$ denotes the current weight of the $j$-th item.

If $x < a_i$, we got some evidence that $x$ lies to the left, so we multiply all the weights to the left by two (and normalize the weights afterwards to sum to $1$). If $x > a_i$, we behave analogously, and if $x = a_i$, we finish.
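
Here is a minimal sketch of this algorithm in Python; the array, the target, and the step cap are made up for illustration, but the algorithm itself follows the description above:

```python
import random

def noisy_binary_search(a, x, flip_prob=1/3, max_steps=200):
    """Weighted-median search in the sorted list a when comparisons lie with probability flip_prob.

    Assumes x is present in a and that equality checks are reliable.
    """
    n = len(a)
    w = [1.0 / n] * n                          # uniform prior over positions
    for step in range(1, max_steps + 1):
        # Find the element "in the middle" with respect to the weights.
        prefix, i = 0.0, 0
        while prefix + w[i] < 0.5:
            prefix += w[i]
            i += 1
        if a[i] == x:
            return i, step
        # Noisy comparison: the honest answer is flipped with probability flip_prob.
        says_left = (x < a[i]) if random.random() > flip_prob else not (x < a[i])
        if says_left:
            for j in range(i):                 # evidence that x is to the left: double those weights
                w[j] *= 2
        else:
            for j in range(i + 1, n):          # evidence that x is to the right
                w[j] *= 2
        total = sum(w)
        w = [wj / total for wj in w]           # renormalize to sum to 1
    return None, max_steps

a = list(range(1, 16))                         # a small example array (15 elements, as in the widget below)
print(noisy_binary_search(a, x=11))            # finds index 10 after a few (random) steps
```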

Let's see this algorithm in action!

[Interactive widget: Noisy Binary Search Visualization. An array of 15 elements, each starting with weight 1/15 ≈ 6.7%; the step counter starts at 0.]

You can see how this algorithm implements the general idea behind multiplicative weights: Each element of the array is kind of like an expert. You can also see how this ties nicely with our Bayesian intuition about the multiplicative weights algorithm: You can interpret our starting weights as the uniform prior for which element of the array is $x$ (recall we assume that $x$ is present) and each step of the algorithm is like a Bayesian update that updates these probabilities.

Let's also revisit our original investment problem and see how different algorithms perform in various scenarios:

[Interactive widget: Multiplicative Weights Update in Action. Select algorithms and choose a scenario to compare how they perform.]

The important idea

The focus of our minicourse was on how KL divergence and related techniques can be used to model the world around us. I think multiplicative weights are a good example of a different application -- KL can guide us to come up with the right algorithm to approach a problem!