Active learning is a subfield of machine learning that probably doesn’t receive as much attention as it should. The fundamental idea behind active learning is that some instances are more informative than others, and if a learner can choose the instances it trains on, it can learn faster (that is, from fewer labels) than it would from an unbiased random sample.
This post is a high-level primer with some Python code. If you want a serious introduction to the topic, I strongly recommend you buy Burr Settles’ textbook.
how it works
Suppose we have a small collection of labeled data and a large collection of unlabeled data. We want more labeled data so that we can train a better classifier, and we have the option of paying to get some of our unlabeled data points labeled. (This is not an unrealistic scenario: examples include hiring experts to annotate data or running lab tests.) If we have a fixed budget, what’s the optimal way to spend our money?
One option is random sampling. This is a perfectly reasonable thing to do, and it generally works fine in practice. But in many cases we can outperform random sampling by learning actively. The idea is to:
- Train a probabilistic supervised model on our labeled data
- Use the model to perform inference on the unlabeled data
- Identify instances that confuse our model, and request their labels
Of course, we have to say more about what it means for a model to be “confused” by an instance. Proposing various definitions of this, and seeing how they cash out in experiments, is what active learning is fundamentally about.
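The three steps above can be sketched end to end with scikit-learn. This is a minimal sketch under several assumptions: synthetic data from `make_classification` stands in for a real pool, the hidden labels `y_pool` play the role of the paid annotator, logistic regression is the probabilistic model, and “confused” is taken to mean least-confidence uncertainty sampling (described below). Buying labels in small rounds and retraining after each round is a common way to spend the budget, but it is an assumption here; `budget_per_round` and `n_rounds` are hypothetical names.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic pool (assumption). y_pool is hidden from the learner and acts as the oracle.
X_pool, y_pool = make_classification(
    n_samples=2000, n_features=20, n_informative=5, n_classes=3, random_state=0
)

# Small initial labeled set: the first 10 examples of each class.
labeled = np.concatenate([np.flatnonzero(y_pool == c)[:10] for c in np.unique(y_pool)])
unlabeled = np.setdiff1d(np.arange(len(X_pool)), labeled)

budget_per_round = 20  # hypothetical: labels bought per round
n_rounds = 5           # hypothetical: total budget = 5 * 20 labels

for _ in range(n_rounds):
    # 1. Train a probabilistic supervised model on the labeled data.
    clf = LogisticRegression(max_iter=1000).fit(X_pool[labeled], y_pool[labeled])
    # 2. Use the model to perform inference on the unlabeled data.
    probs = clf.predict_proba(X_pool[unlabeled])
    # 3. Identify the instances that confuse the model most (least confidence) ...
    scores = 1 - probs.max(axis=1)
    query = unlabeled[np.argsort(-scores)[:budget_per_round]]
    # ... and "pay" for their labels, which here just means revealing y_pool[query].
    labeled = np.concatenate([labeled, query])
    unlabeled = np.setdiff1d(unlabeled, query)

print(len(labeled), len(unlabeled))  # 130 1870
```

Only the scoring line changes between the strategies discussed in this post; the rest of the loop stays the same.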
The simplest approach to active learning is uncertainty sampling. Despite being simple, it’s typically competitive. It comes in three forms. The first is “least confidence” sampling, in which we seek instances which minimize the probability the classifier assigns to the predicted label. In Python, it might look something like this:
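```python
import numpy as np

# Perform inference (clf, X_unlabeled and budget are defined elsewhere):
probs = clf.predict_proba(X_unlabeled)
# Score each instance by uncertainty: 1 minus the probability of the predicted label
scores = 1 - np.amax(probs, axis=1)
# Sort them, most uncertain first:
rankings = np.argsort(-scores)
# Take the points for which we have a budget:
X_active = X_unlabeled[rankings[:budget]]
```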
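To make the score concrete, here is a small worked example on three hand-made predicted distributions (the numbers are invented for illustration, not model output). The helper name `least_confidence_scores` is hypothetical.

```python
import numpy as np

def least_confidence_scores(probs):
    """Uncertainty score: 1 minus the probability of the predicted (argmax) label."""
    return 1.0 - np.max(probs, axis=1)

# Toy predicted distributions for three unlabeled instances over three classes.
probs = np.array([
    [0.90, 0.05, 0.05],  # confident
    [0.40, 0.30, 0.30],  # low top probability
    [0.50, 0.48, 0.02],  # torn between two labels
])

scores = least_confidence_scores(probs)  # approximately [0.1, 0.6, 0.5]
rankings = np.argsort(-scores)           # most uncertain first
print(rankings)  # [1 2 0]
```

Least confidence queries the second row first, because its top probability (0.40) is the lowest of the three.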
Another approach is to seek points where the difference between the probabilities of the top two predictions is smallest; effectively, we look for cases where the model is torn between two labels. This is usually called “margin” sampling, since it prefers the instances with the smallest margin between the two most probable labels.
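```python
# Negate probs so that np.partition puts the largest probability in column 0
# and the second largest in column 1:
ordered = np.partition(-probs, 1, axis=1)
# Margin = gap between the top two probabilities, negated so that
# a higher value means a smaller margin:
margins = -np.abs(ordered[:, 0] - ordered[:, 1])
# Sort them, smallest margin first:
rankings = np.argsort(-margins)
```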
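On the same kind of toy input, margin sampling can disagree with least confidence. This sketch uses a full row sort instead of `np.partition` for readability; `margin_scores` is a hypothetical helper name and the probabilities are invented.

```python
import numpy as np

def margin_scores(probs):
    """Negated gap between the two largest probabilities (higher = more uncertain)."""
    top_two = np.sort(probs, axis=1)[:, -2:]  # columns: second largest, largest
    return -(top_two[:, 1] - top_two[:, 0])

# Toy predicted distributions for three unlabeled instances over three classes.
probs = np.array([
    [0.90, 0.05, 0.05],  # gap 0.85
    [0.40, 0.30, 0.30],  # gap 0.10
    [0.50, 0.48, 0.02],  # gap 0.02
])

rankings = np.argsort(-margin_scores(probs))  # smallest margin first
print(rankings)  # [2 1 0]
```

The second row has the lowest top probability, but the third row has the smallest gap between its two best labels, so margin sampling queries it first. Sorting each row fully costs O(k log k) for k classes, versus roughly linear time for `np.partition`; that only matters when there are many classes.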
One might criticize the first two approaches on the grounds that they throw away all the information about the relative probabilities of the other labels. A third technique, which doesn’t share this shortcoming, is to maximize the entropy of the output distribution:
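```python
import numpy as np
from scipy.stats import entropy

# Entropy of each row of predicted probabilities:
scores = np.apply_along_axis(entropy, 1, probs)
# Sort them, highest entropy first:
rankings = np.argsort(-scores)
```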
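The same computation can be written without `np.apply_along_axis`, since `scipy.stats.entropy` accepts an `axis` argument. A small sketch on three invented distributions (natural-log entropy, the SciPy default):

```python
import numpy as np
from scipy.stats import entropy

# Toy predicted distributions for three unlabeled instances over three classes.
probs = np.array([
    [0.90, 0.05, 0.05],
    [0.40, 0.30, 0.30],
    [0.50, 0.48, 0.02],
])

# entropy() works column-wise by default; axis=1 gives one value per row.
scores = entropy(probs, axis=1)  # roughly [0.394, 1.089, 0.777]
rankings = np.argsort(-scores)   # highest entropy first
print(rankings)  # [1 2 0]
```

Here entropy happens to agree with least confidence, while margin sampling would query the third row first. With more than two classes the three criteria can all disagree; for example, least confidence prefers [0.5, 0.5, 0] over [0.55, 0.225, 0.225], whereas entropy prefers the second.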
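It’s not hard to prove that in binary classification all three approaches are equivalent: each score is a monotone function of the probability of the predicted label, so all three rank the unlabeled instances identically. Full implementations of these techniques (and a couple of others) can be found here.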
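The binary equivalence takes only a few lines. Write $p = \max_y P(y \mid x) \in [\tfrac12, 1]$ for the probability of the predicted label, so the other label gets $1 - p$:

$$
\begin{aligned}
\text{least confidence:}\quad & 1 - p \\
\text{margin:}\quad & p - (1 - p) = 2p - 1 \\
\text{entropy:}\quad & H(p) = -p \log p - (1 - p)\log(1 - p), \qquad \frac{dH}{dp} = \log\frac{1 - p}{p} < 0 \ \text{ for } p \in (\tfrac12, 1)
\end{aligned}
$$

Least confidence and entropy both decrease as $p$ grows, and the margin (which we minimize) increases with $p$, so every criterion queries the instances with the smallest $p$ first.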
what it looks like
We can get a better sense of what is going on by letting a linear model learn actively on MNIST. It’s well known that the MNIST instances that models tend to misclassify are not necessarily the same ones that people tend to misclassify. But if we look at the instances an active learner requests, they match our intuitions surprisingly well.
Above is a plot of the 50 most informative instances obtained after:
- Restricting the MNIST task to a binary problem of “seven or not-seven”
- Training a linear model on 2000 random samples
- Performing uncertainty sampling
Most of them are either shittily drawn sevens, or other digits that have some of the distinctive properties of a seven. Additionally, the distribution of digits is heavily skewed to favour digits that are more similar in structure to seven.
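A rough, smaller-scale version of this experiment can be sketched with scikit-learn. The assumptions: scikit-learn’s bundled 8x8 `load_digits` images stand in for MNIST (so nothing needs downloading), the linear model is logistic regression, and because that dataset has only 1,797 images, the model is trained on 500 random samples rather than 2,000. The sketch prints which digits appear among the 50 most uncertain pool instances; it does not reproduce the plot or its results.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Assumption: sklearn's 8x8 digits stand in for MNIST.
digits = load_digits()
X = digits.data / 16.0          # pixel values are 0..16; scale to [0, 1]
y_digit = digits.target
y = (y_digit == 7).astype(int)  # binary task: "seven or not-seven"

# Train a linear model on a random labeled subset; the rest is the unlabeled pool.
n_train = 500
perm = rng.permutation(len(X))
train_idx, pool_idx = perm[:n_train], perm[n_train:]
clf = LogisticRegression(max_iter=1000).fit(X[train_idx], y[train_idx])

# Uncertainty sampling on the pool (binary, so all three criteria agree).
probs = clf.predict_proba(X[pool_idx])
scores = 1 - probs.max(axis=1)
top50 = pool_idx[np.argsort(-scores)[:50]]

# Which true digits did the learner ask about?
requested, counts = np.unique(y_digit[top50], return_counts=True)
print(dict(zip(requested.tolist(), counts.tolist())))
```

To look at the requested images, each row of `X[top50]` can be reshaped to 8x8 and drawn with matplotlib’s `imshow`.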
On the one hand, it would be a mistake to fully equate the algorithm singling out a slanted “one” with a human asking whether the line across the top is an essential property of a “seven.” Statistical models of character recognition don’t operate in terms of necessary and sufficient conditions. On the other hand, it’s clear that both people and computers can glean insight from cases close to the margin for similar (if not identical) reasons.