What everyone needs to know about interpretability in machine learning
Note: this post was written for a general audience, and assumes only passing familiarity with machine learning.
For anyone who’s been paying attention, it should be apparent that statistical machine learning systems are being widely deployed for automated decision-making in all kinds of areas these days, including criminal justice, medicine, education, employment, policing, and so on. (For a great overview of the hazards of applying machine learning in these domains, check out Cathy O’Neil’s book Weapons of Math Destruction.)
Particularly with the recently enacted General Data Protection Regulation (GDPR), the new European regulation on data and privacy, there is growing interest in having systems that are interpretable, meaning that we can make some sense of why they make the predictions they do. To borrow an example from Been Kim, if a computer tells you that you need surgery, you’re probably going to ask for some sort of explanation.
There is now essentially a whole subfield of research devoted to interpretability in machine learning, so there’s no chance of covering all of that here. However, given how much confusion there seems to be, I thought it would be useful to outline a few essential ideas that everyone should know about this area.
1. Machine learning systems make predictions based on a set of input features (i.e. a bunch of numbers).
This point is in some sense so obvious that it’s rarely discussed, but it is actually quite important. Although machine learning is being used for all kinds of data, including images, videos, and natural language, the first step of any such system is to convert the input into a set of numbers. In the case of a medical image, for example, the image would be converted into a set of numbers representing pixel intensities, one for each pixel.
For images, this will be a very high-fidelity representation, as this is in some sense how the data is natively acquired. For other domains, however, the mapping may involve some loss of information. For example, in working with text data, the text is typically processed in a way that involves some sort of simplification, such as dropping rare words. The words themselves would be converted to numbers, perhaps by using an index which maps each unique word type to a unique numeric representation (or other more complex variations on this).
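To make this concrete, here is a minimal Python sketch of one such variation, a bag-of-words count vector. The example sentences, the `min_count` cut-off, and the function names are all invented for illustration; real systems differ in many details, but the basic move of mapping words to indices and discarding rare ones is the same.

```python
from collections import Counter

def build_vocab(documents, min_count=2):
    """Give each word seen at least min_count times its own integer index; rarer words are dropped."""
    counts = Counter(word for doc in documents for word in doc.lower().split())
    kept = sorted(word for word, count in counts.items() if count >= min_count)
    return {word: index for index, word in enumerate(kept)}

def to_counts(document, vocab):
    """Represent a document as a list of word counts; words outside the vocabulary are silently lost."""
    vector = [0] * len(vocab)
    for word in document.lower().split():
        if word in vocab:
            vector[vocab[word]] += 1
    return vector

# Hypothetical toy "dataset" of three short notes.
docs = [
    "the patient reports chest pain",
    "the patient reports no pain",
    "the patient mentions a rare syndrome",
]
vocab = build_vocab(docs)
print(vocab)                      # {'pain': 0, 'patient': 1, 'reports': 2, 'the': 3}
print(to_counts(docs[2], vocab))  # [0, 1, 0, 1]
```

Note that the third note ends up represented only by “the” and “patient”: the mention of a rare syndrome simply disappears from the input, and no model trained on these vectors can ever take it into account.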
At this point you are probably thinking “but, I am large! I contain multitudes! I cannot be represented by a set of numbers!” and in many ways you would be correct. For example, the much-discussed ProPublica story on the use of risk assessments in criminal justice focused on a particular system that took as input the answers to 137 questions. Although this still produces a vast space of possible inputs, it seems trivially obvious that these questions will fail to capture all the nuance of each individual. Any individuals who provide the exact same set of answers would be treated as functionally identical, even though there would almost certainly be differences between them that were not captured by the questions.
In some cases, these details may not matter all that much. In medical diagnosis, for example, a doctor may not need to know everything about you in order to diagnose your illness; perhaps a simple echocardiogram will do. In other areas it is much less clear how much information is needed. Credit scoring is almost certainly effective to some degree, given the financial incentives involved, but making predictions is always hard (yes, especially about the future), and it is difficult to know whether there is additional information out there that could have been helpful.
Although modern machine learning systems are often described as being good at representation learning, meaning that less work is required in terms of hand-crafting features (intelligent ways of combining the inputs), they still operate on a set of numbers, and anything that is not included in the input data will never be considered. We have many ways of determining whether a particular feature can safely be ignored, but we are not particularly good at knowing whether we are missing some critical piece of information.
There is much more to be explored here, particularly the question of who has the power to define how people will be represented, but the main thing to keep in mind is that representing inputs to a machine learning system (or “model” or “algorithm”) always involves choices. If you’re thinking about interpretability, the first questions to ask are: what data does the system take as input, how was this decided, and what might be missing?
2. Machine learning discovers correlations in data (but does not understand causality).
This point is a bit more subtle than the first one, but the key idea is embedded in the old phrase that “correlation does not imply causation.” Although there is some exciting work happening in causal machine learning these days, the vast majority of applied systems ignore causality completely and focus only on correlations.
What does this mean? Basically, if I see a pattern in the input data that always (or frequently) occurs with a particular outcome (e.g. cancerous or benign), then I will quite sensibly predict that outcome when I see that input pattern. The whole idea of supervised learning is to automatically discover patterns in large amounts of data that allow us to map from a given input to a predicted label or outcome. This does not, however, imply that there is any direct causal connection between the outcomes and the patterns that have been found.
For example, think of the famous “marshmallow experiment”. Give a child a marshmallow and tell them that if they don’t eat it right away, they can have two later. In addition to generating supposedly delightful videos of children trying to fight their instincts, this experiment produced the finding that children who were able to delay gratification went on to have better life outcomes in various ways (on average). One interpretation is that eating the marshmallow somehow made things worse for those who could not resist doing so, but a much more reasonable interpretation is that there is some latent property being measured, such as self-control, and that this explains both the marshmallow eating and the later life outcomes.
In the same way, machine learning might learn a correlation between a certain pattern of behaviour (such as cycling to work) and later outcomes (such as not dying). This might then be used to predict future outcomes, but it should not be interpreted as one thing causing another. Rather, the system is only as robust as the correlation itself (which may or may not be the result of a causal effect).
Because machine learning systems work by discovering patterns in the data they were trained on, they tend not to be robust to changes in the distribution of input data. When a system has been trained on one dataset (such as data from WEIRD people: Western, Educated, Industrialized, Rich, and Democratic) and is then applied to another, things can fail arbitrarily badly. This phenomenon shows up in cases such as the failure of Google Flu Trends, or computer vision systems that only work for some skin tones. Where things get really complicated is when a system itself starts interacting with the environment, such that people start modifying their own behaviour in response to the system’s output. In such a case, all bets are truly off.
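As a deliberately crude illustration, the Python sketch below “learns” by remembering, for each exact input pattern in some hypothetical training data, which label occurred most often alongside it. The feature names and data are invented, and real supervised learning methods are far more sophisticated, but they share the property that nothing in the procedure asks why a pattern and a label go together.

```python
from collections import Counter, defaultdict

def fit_majority(examples):
    """Toy learner: map each exact input pattern to the label seen most often with it."""
    by_pattern = defaultdict(Counter)
    for features, label in examples:
        by_pattern[tuple(features)][label] += 1
    return {pattern: counts.most_common(1)[0][0] for pattern, counts in by_pattern.items()}

# Hypothetical training data: (irregular_border, large_size) -> diagnosis.
training = [
    ((1, 1), "cancerous"), ((1, 1), "cancerous"), ((1, 1), "benign"),
    ((0, 0), "benign"), ((0, 0), "benign"),
    ((0, 1), "benign"), ((0, 1), "cancerous"), ((0, 1), "benign"),
]
model = fit_majority(training)
print(model[(1, 1)])  # cancerous
print(model[(0, 1)])  # benign
```

Nothing here distinguishes a feature that causes the outcome from one that merely tends to show up alongside it. This toy also has no answer at all for the unseen pattern (1, 0); real models fill such gaps by generalizing, a point discussed further below.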
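The same logic can be simulated in a few lines of Python. The sketch below is a made-up causal model, not a description of the actual study: an unobserved trait (labelled `self_control`) drives both whether a child waits and how their life goes later, while waiting itself has no effect at all. The noise levels and sample size are arbitrary assumptions.

```python
import random

def simulate(n, force_wait=None, seed=0):
    """Made-up causal model: a latent trait drives both waiting and later outcomes;
    waiting itself has no effect on the outcome."""
    rng = random.Random(seed)
    waited, outcomes = [], []
    for _ in range(n):
        self_control = rng.gauss(0, 1)              # latent trait, never observed
        waits = self_control + rng.gauss(0, 1) > 0  # the behaviour we do observe
        if force_wait is not None:
            waits = force_wait                      # intervention: set behaviour directly
        outcome = self_control + rng.gauss(0, 1)    # depends only on the latent trait
        waited.append(waits)
        outcomes.append(outcome)
    return waited, outcomes

def mean_outcome(waited, outcomes, group):
    """Average outcome among children whose waiting status equals `group`."""
    values = [o for w, o in zip(waited, outcomes) if w == group]
    return sum(values) / len(values)

waited, outcomes = simulate(10_000)
gap = mean_outcome(waited, outcomes, True) - mean_outcome(waited, outcomes, False)
print(f"waiters minus eaters: {gap:.2f}")

_, forced = simulate(10_000, force_wait=True)
print(f"average outcome if everyone is made to wait: {sum(forced) / len(forced):.2f}")
```

The observed gap between waiters and eaters is large and positive, so a model trained on this data would happily use waiting as a predictor. Yet forcing everyone to wait leaves the average outcome near zero, exactly where it was without any intervention, because in this model waiting causes nothing.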
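Here is a toy version of that failure, again a sketch with invented numbers. A one-feature threshold rule is fitted to a training population in which the outcome flips at x = 0.5, then applied to a deployment population in which the same feature relates to the outcome differently (flipping at x = 0.8). Nothing about the model changes between the two settings, but its accuracy does.

```python
def fit_threshold(xs, labels):
    """Pick the cut-off t minimising training errors for the rule 'predict True if x > t'."""
    def errors(t):
        return sum((x > t) != y for x, y in zip(xs, labels))
    return min(sorted(set(xs)), key=errors)

def accuracy(t, xs, labels):
    """Fraction of examples where the rule 'x > t' matches the label."""
    return sum((x > t) == y for x, y in zip(xs, labels)) / len(xs)

# Hypothetical feature values 0.00, 0.01, ..., 0.99, shared by both populations.
xs = [i / 100 for i in range(100)]
train_labels = [x > 0.5 for x in xs]   # training population: outcome flips at 0.5
deploy_labels = [x > 0.8 for x in xs]  # deployment population: outcome flips at 0.8

t = fit_threshold(xs, train_labels)
print(t)                               # 0.5
print(accuracy(t, xs, train_labels))   # 1.0
print(accuracy(t, xs, deploy_labels))  # 0.7
```

In this sketch the shift is in the relationship between feature and outcome; shifts in which inputs are common in the first place can cause similar trouble.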
3. Some models are special, but interpretability is not the norm.
Given the above two points, how does supervised learning work? Well, the details of how we obtain a system can be complicated, but the end result is almost always the same. We start with some dataset used for training, and we end up with a system that takes a set of numbers as input and returns a different number as a prediction, based on correlations discovered in the training data.
For example, consider the question of how much money a house will sell for. We can dream up any number of numbers we might want to use to describe a house, such as number of bedrooms, square footage, the average price of a house in the neighbourhood, and so on, and learn a model that will map from these numbers to a dollar amount. A natural way to approach this is to get a large dataset of houses that have sold in the past, including how much they sold for, and then to train a model on this data to predict sale price from these features.
If we use a very simple model, we might well find some nice simple patterns, such as larger houses sell for more money. But perhaps this is only true up to a point; or perhaps it depends on the age of the house, because older houses tend to be bigger. As we start to consider more and more complicated possibilities, we will get a richer model, but also one that is harder for us to make sense of in any intuitive way. In the extreme, we basically end up with something that might as well be a black box, in that we simply input all the feature values, and get the predicted price in return. Note that it is not actually a black box, in that we can directly inspect the computations involved, but for complicated models, this may not be very meaningful. The main point is that in many ways the defining property of such systems is an explicit mapping from any possible point in the input space to a predicted output.
Now, it turns out that mappings allowed by certain classes of models can be summarized in a way which can easily be understood, interpreted, and simulated by humans. The classic example here is linear models, where the output is a weighted sum of the input values. For such a model, we can summarize how the output will change in response to a change in one input feature using a single number (the weight on that feature). This means that it is much easier for humans to feel like they understand how the model is making predictions.
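As a minimal sketch of why linear models are so easy to summarize, here is a hand-written linear model for the house-price example. The features and weights are invented for illustration, not fitted to any real data. The weight on `square_feet` tells you directly how much the prediction moves per extra square foot, no matter what the other inputs are.

```python
# Hypothetical weights, invented for illustration (not fitted to real data).
WEIGHTS = {"bedrooms": 15_000.0, "square_feet": 120.0, "age_years": -800.0}
INTERCEPT = 50_000.0

def predict_price(house):
    """Linear model: the prediction is a weighted sum of the input features plus an intercept."""
    return INTERCEPT + sum(WEIGHTS[name] * value for name, value in house.items())

house = {"bedrooms": 3, "square_feet": 1_500, "age_years": 20}
bigger = dict(house, square_feet=1_600)  # same house with 100 extra square feet

print(predict_price(house))                          # 50000 + 3*15000 + 1500*120 - 20*800 = 259000.0
print(predict_price(bigger) - predict_price(house))  # 100 * 120 = 12000.0
```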
An even simpler case is a decision list, which takes the form of a series of yes/no questions. Although such a model has a simple interpretation that is easy for us as humans to process and describe in words, it is the same as any other model in that it takes a set of numbers (a set of 1s and 0s for yes and no), and returns another number as a prediction. For any machine learning system, once we have defined the input space (e.g., the set of questions), we can give it any possible set of numbers in that space, and it will return a prediction, even if it hasn’t seen that input pattern before. In some sense, the “magic” of machine learning is the ability to generalize from patterns that have been discovered to new combinations of inputs that have not previously been seen. However, this will always be a somewhat limited ability, potentially fragile, and dependent on various assumptions.
Models like linear models and decision lists are special, in that this mapping from inputs to outputs can be described in a compact way that humans can easily mentally simulate, but we should not necessarily expect this to be the norm. Especially when dealing with deep learning models and very large input spaces (such as images or language), the whole explanation for why a model made a particular prediction is that you gave it a particular set of inputs, and it performed a long series of mathematical operations (mostly linear algebra, interleaved with simple nonlinear functions) that produced the result. Give it a different set of inputs and you will get a different prediction.
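A decision list is just as easy to write down. The sketch below uses three invented yes/no credit questions; the question names, rule order, and risk labels are all hypothetical. Enumerating every combination of answers shows that the model returns a prediction for every point in its input space, including combinations it may never have been trained on, and that any two people who give the same answers get the same prediction.

```python
from itertools import product

QUESTIONS = ["prior_default", "stable_income", "owns_home"]

# Hypothetical rules, checked in order: (question, answer that triggers the rule, prediction).
RULES = [
    ("prior_default", 1, "high risk"),
    ("stable_income", 0, "medium risk"),
    ("owns_home", 0, "medium risk"),
]
DEFAULT = "low risk"  # used when no rule fires

def decision_list(answers):
    """Return the prediction of the first rule whose condition matches the answers."""
    for question, trigger, prediction in RULES:
        if answers[question] == trigger:
            return prediction
    return DEFAULT

# Every possible set of yes/no answers (1 = yes, 0 = no) gets a prediction.
table = {combo: decision_list(dict(zip(QUESTIONS, combo)))
         for combo in product([0, 1], repeat=len(QUESTIONS))}
for combo, prediction in table.items():
    print(combo, prediction)
```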
Although there is lots of work on how we might be able to provide easy-to-understand explanations which approximate the truth, this necessarily entails some degree of inaccuracy. Realistically, once a system has been trained, there should be no expectation that it can necessarily be simplified to something that has a compact explanation, at least not without some loss of fidelity to what the system is actually doing.
Is the situation hopeless, then? Why, no! That brings me to my last point:
4. The real question is: how was the system created?
Given that the mapping of a machine learning system from inputs to outputs seems so potentially arbitrary, one should rightfully ask, why is the model configured to make those specific predictions for those specific inputs? That is an excellent question, and the answer is all in how it was trained.
In particular, there are three essential elements that can be assayed. First, there is the question of what data was used to train the model. As we know, machine learning will discover patterns in the data it was trained on, but what data was originally used for this purpose? If it was a dataset of images of faces, whose faces? If it was credit histories, who was in the database? This has important implications for potential mismatches between the training data and the settings where the model is deployed, as described above, but it also raises the question of whether there may be particular biases baked into the dataset which unfairly reflect past inequalities.
Second, there is the question of how the data was represented (see the first point above). In some cases this will go hand in hand with the question of what training data was used, as lots of datasets have effectively been pre-coded into a particular representation. Nevertheless, we should not assume that this is the full picture, or the only way that such data could have been represented.
Finally, there is the question of what type of model was chosen and how it was trained. There is a huge range of possibilities here, but the details need not concern us. The point is not necessarily to ask whether the “right” choice was made, but rather to acknowledge and investigate what else might have been possible. For certain types of models, a given choice of model, dataset, and representation is guaranteed to produce the same result every time. For deep learning, by contrast, there is some inherent randomness involved, such that there is always the possibility of ending up with a different set of model weights, even due to factors such as how the parameters are initialized.
Of course, given a sufficiently large dataset, these nitty-gritty details about training are unlikely to make much difference. But as data grows increasingly high-dimensional (think of all the different attributes of people that might be relevant to any particular prediction; in the extreme, think of the billions of base pairs in our DNA!), what might have seemed like lots of data can suddenly feel quite sparse in the tails of the distribution. Particularly for unusual cases, it is quite possible that different choices about dataset, representation, and modeling would have resulted in a model that made a vastly different prediction.
To some extent, choices can be justified (typically because one model worked better than another on some held-out data), but this is not the same as being able to claim that one has discovered the “true” model in any meaningful sense. Nor does accuracy necessarily tell the full story: even models that obtain relatively high accuracy on held-out data can still be biased in problematic ways.
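Here is a deliberately tiny stand-in for that effect, sketched in Python under made-up assumptions: a two-weight linear model trained by gradient descent from a random starting point, on hypothetical data in which the two features always move together. Gradient descent only pins down what the data constrains (the sum of the weights); whatever the random start put into the other direction stays there. This is not a deep network, just the simplest setting where the seed visibly matters.

```python
import random

def train(seed, steps=2000, lr=0.05):
    """Fit y = w1*x1 + w2*x2 (approximately) by full-batch gradient descent from a random start."""
    rng = random.Random(seed)
    w1, w2 = rng.uniform(-1, 1), rng.uniform(-1, 1)  # random initialisation
    # Hypothetical training data in which the two features are always equal.
    data = [((x, x), 2.0 * x) for x in (0.5, 1.0, 1.5, 2.0)]
    for _ in range(steps):
        g1 = g2 = 0.0
        for (x1, x2), y in data:
            err = w1 * x1 + w2 * x2 - y
            g1 += 2 * err * x1 / len(data)  # gradient of mean squared error w.r.t. w1
            g2 += 2 * err * x2 / len(data)  # gradient w.r.t. w2
        w1, w2 = w1 - lr * g1, w2 - lr * g2
    return w1, w2

for seed in (0, 1):
    w1, w2 = train(seed)
    print(f"seed {seed}: weights ({w1:.2f}, {w2:.2f}); "
          f"typical input (1, 1) -> {w1 + w2:.2f}; unusual input (1, 0) -> {w1:.2f}")
```

Both runs fit the training data equally well and agree on typical inputs, but they disagree about the unusual input (1, 0), which the training data never constrains. A real deep network is vastly more complicated, but similar seed-dependence can show up there, especially for inputs unlike anything in the training set.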
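A bit of back-of-the-envelope arithmetic shows how quickly this happens. The sketch below assumes, purely for simplicity, that every feature is a yes/no answer: with d such features there are 2^d possible input patterns, and a dataset can contain at most one distinct pattern per row. The function name and the one-million-row dataset size are illustrative assumptions.

```python
def max_coverage(num_binary_features, num_examples):
    """Largest possible fraction of distinct yes/no input patterns that a dataset
    with num_examples rows could contain (each row covers at most one pattern)."""
    return min(1.0, num_examples / 2 ** num_binary_features)

for d in (10, 20, 40, 137):
    print(d, max_coverage(d, 1_000_000))
```

With a million rows, 20 yes/no features leave room to cover most possible patterns, but with 40 features the data covers less than one in a million of them, and with 137 (the number of questions in the risk-assessment system mentioned earlier, if each were a simple yes/no) the fraction is vanishingly small.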
Unfortunately, it may not always be possible to provide a complete understanding, because such an understanding is partially a question about why we ended up with a particular system, and that comes down to the modeling choices that were made and the training data that was used. Particularly in cases where the data is sensitive or private, it does not seem feasible to allow everyone to inspect everyone else’s data in order to understand their own prediction. This seems especially true, given that the whole idea behind the GDPR was to give people more control over their data!
There are no easy answers here, but there have been many interesting proposals, with more coming all the time. For example, two recent papers, on datasheets for datasets and on algorithmic impact assessments, suggest ways in which we can attempt to protect ourselves against unintended biases or misuse of data.
Ultimately, transparency of some kind seems critical, but exactly what form that will take is still far from certain. The technology defines what is possible, but it comes down to regulation like the GDPR to determine what is permitted and what is required. Now, if only it were easy to interpret the regulation…
Key takeaways:
- Some models are easy for humans to interpret, but this is the exception more than the rule; in general, we should not assume that all models can be represented in a way that is easy for humans to understand, at least not without some loss of fidelity.
- In many ways, a more important question than how a model works is why we ended up with that particular model; ultimately, this will always be the result of the training data that was used, how that data was represented, and the modeling decisions that were made.
- When applying machine learning in social domains, it is especially important to think about the training data being used, and to ask if it may be limited or biased in some way; it is much easier to rule out some feature as irrelevant than it is to know if some critical feature may be missing.
- Finally, remember that the vast majority of supervised machine learning models work by discovering correlations in the data; without further evidence, this should not be interpreted as implying any kind of causal connection between inputs and outputs.