Deep Weighted Averaging Classifiers
In a new paper to appear at the ACM conference on Fairness, Accountability, and Transparency (FAT* 2019), we present a method called deep weighted averaging classifiers (DWACs), which can transform any deep learning architecture into one that is more transparent, interpretable, and robust to out-of-domain data.
Although deep learning models differ in their details, the vast majority of deep architectures for classification include a final softmax layer, which projects a hidden representation into a vector of probabilities. That is, the output is computed as:

$$p(y \mid x) = \mathrm{softmax}(W h + b)$$

where $h = g(x)$ is the output vector from the preceding layers (such as a convolutional or recurrent network), and $W$ and $b$ are the weights and bias of the final layer.
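For reference, the standard softmax output layer can be sketched in a few lines of NumPy. This is an illustrative sketch, not code from the paper; the names `softmax_head`, `W`, and `b` and the toy sizes (3 classes, 4-dimensional hidden vector) are hypothetical.

```python
import numpy as np


def softmax_head(h, W, b):
    """Softmax output layer: p(y | x) = softmax(W h + b).

    h: (d,) hidden representation g(x); W: (K, d) weights; b: (K,) biases.
    """
    logits = W @ h + b
    logits = logits - logits.max()  # shift by the max for numerical stability
    exp_logits = np.exp(logits)
    return exp_logits / exp_logits.sum()


rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))  # hypothetical: 3 classes, 4-dimensional h
b = np.zeros(3)
p = softmax_head(rng.normal(size=4), W, b)  # random stand-in for g(x)
p_zero = softmax_head(np.zeros(4), W, b)    # h = 0 and b = 0 give uniform probabilities
print(p, p.sum(), p_zero)
```

Whatever the hidden vector looks like, this layer always returns a normalized distribution over the classes, which is why it can report high confidence even for inputs unlike anything seen in training.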
The alternative that we propose is to keep g(x) the same, drop the softmax, and instead compute the probability of each label as a weighted sum of training instances, where the weights are based on the distance between instances in the low-dimensional instance-embedding space learned by the model:

$$p(y = k \mid x) = \frac{\sum_{i=1}^{N} \mathbb{I}[y_i = k]\, w(h, h_i)}{\sum_{i=1}^{N} w(h, h_i)}$$

where $h_i = g(x_i)$ is the embedding of the $i$-th of $N$ training instances, $\mathbb{I}[\cdot]$ is the indicator function, and $w(h, \cdot) \in [0, 1]$ is a weight computed using a static Gaussian kernel operating on the distance between the learned embeddings of a test instance and a training instance, indicating how similar the two instances are, according to the model.
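A minimal NumPy sketch of the weighted-averaging prediction follows. It assumes the kernel takes the form w(h, h_i) = exp(−γ‖h − h_i‖²) with a fixed bandwidth γ = 1; that exact parameterization, and the helper names `kernel_weights` and `dwac_predict`, are assumptions for illustration rather than the paper's code.

```python
import numpy as np


def kernel_weights(h, H_train, gamma=1.0):
    """w(h, h_i) = exp(-gamma * ||h - h_i||^2) for every training embedding h_i.

    gamma is an assumed fixed bandwidth (not learned); each weight lies in [0, 1].
    """
    sq_dist = np.sum((H_train - h) ** 2, axis=1)
    return np.exp(-gamma * sq_dist)


def dwac_predict(h, H_train, y_train, n_classes, gamma=1.0):
    """p(y = k | x): class-k share of the total kernel weight on training points."""
    w = kernel_weights(h, H_train, gamma)
    # Sum the weights of the training instances carrying each label k
    class_mass = np.bincount(y_train, weights=w, minlength=n_classes)
    return class_mass / class_mass.sum()


# Toy training embeddings h_i = g(x_i): two well-separated clusters
H_train = np.array([[0.0, 0.0], [0.1, 0.0], [3.0, 3.0], [3.1, 3.0]])
y_train = np.array([0, 0, 1, 1])
p = dwac_predict(np.array([0.05, 0.0]), H_train, y_train, n_classes=2)
print(p)  # almost all of the mass is on class 0
```

One caveat of this direct form: for a test point extremely far from every training point, all weights can underflow to zero and the division yields `nan`; a log-space implementation would avoid that.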
It turns out that we can train this model just as easily as the softmax model using stochastic gradient descent, with no loss in accuracy or calibration and only a slight increase in training time.
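Because every step of the weighted sum is differentiable, it can be trained like any other output layer. The sketch below only computes a loss in NumPy, and it rests on an assumption of ours (not necessarily the paper's exact procedure): during training, each instance is classified by the *other* instances in its minibatch rather than by the full training set. In practice the same computation would be written in an autodiff framework so that gradients flow back into g. The function name `dwac_minibatch_nll` and the bandwidth γ = 1 are hypothetical.

```python
import numpy as np


def dwac_minibatch_nll(H, y, n_classes, gamma=1.0, eps=1e-12):
    """Mean negative log-likelihood when each instance in a minibatch is
    classified by the other instances in the same minibatch.

    H: (B, d) embeddings g(x) for the batch; y: (B,) integer labels.
    """
    diffs = H[:, None, :] - H[None, :, :]
    W = np.exp(-gamma * np.sum(diffs ** 2, axis=-1))  # (B, B) kernel weights
    np.fill_diagonal(W, 0.0)  # an instance must not vote for its own label
    Y = np.eye(n_classes)[y]  # (B, K) one-hot labels
    class_mass = W @ Y        # (B, K) per-class sums of weights
    probs = class_mass / (class_mass.sum(axis=1, keepdims=True) + eps)
    true_probs = probs[np.arange(len(y)), y]
    return float(-np.mean(np.log(true_probs + eps)))


H = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
good = dwac_minibatch_nll(H, np.array([0, 0, 1, 1]), n_classes=2)  # labels match clusters
bad = dwac_minibatch_nll(H, np.array([0, 1, 0, 1]), n_classes=2)   # labels cut across clusters
print(good, bad)
```

Minimizing this loss pulls same-label instances together in the embedding space and pushes differently labelled ones apart, which is what makes the nearest-neighbour-style explanations below meaningful.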
There are three main advantages to this approach:
First, predictions from the model are now explicitly computed as a weighted sum of training instances, where the weight captures the model’s judgement about the degree of similarity between a pair of instances. Rather than just getting a single number from the model (the predicted probability), we can unpack any prediction in terms of the weight being placed on each training instance.
For example, a DWAC model trained on the Fashion MNIST dataset finds many training instances that are very similar to the boot shown in the figure below (weights are shown below the four most similar training images), so we have good reason to trust this prediction:
By contrast, for these unusual “trousers”, even the nearest training instances in the learned embedding space are quite far away (and hence have low weights):
This is a strong indication that we should be skeptical of the model’s prediction on this instance.
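Unpacking a prediction in this way amounts to ranking the training instances by their weight. The sketch below assumes the same Gaussian kernel with a hypothetical fixed bandwidth γ = 1; the helper `explain` and the toy embeddings are illustrative only.

```python
import numpy as np


def explain(h, H_train, y_train, k=4, gamma=1.0):
    """Return the k training instances with the largest weight w(h, h_i),
    as (index, label, weight) tuples sorted from most to least similar."""
    w = np.exp(-gamma * np.sum((H_train - h) ** 2, axis=1))  # assumed Gaussian kernel
    top = np.argsort(-w)[:k]
    return [(int(i), int(y_train[i]), float(w[i])) for i in top]


H_train = np.array([[0.0, 0.0], [0.2, 0.0], [3.0, 3.0], [3.0, 3.2]])
y_train = np.array([0, 0, 1, 1])
typical = explain(np.array([0.0, 0.0]), H_train, y_train)   # like the boot
unusual = explain(np.array([10.0, 10.0]), H_train, y_train)  # like the odd trousers
print(typical[0], unusual[0])
```

For the typical point, the top weight is close to 1; for the unusual one, even its most similar training instance carries a weight that is effectively zero, which is the signal to be skeptical.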
Second, for multiclass classification, a DWAC model of the same size as an equivalent softmax model achieves the same performance in terms of both accuracy and calibration. However, if we want to visualize the learned embeddings, we can instead use a two-dimensional output layer (rather than one whose dimensionality equals the number of classes) and still get relatively good accuracy.
This allows us to visualize how the model is embedding the data directly, without an additional dimensionality-reduction step such as PCA or t-SNE. For the Fashion MNIST data, we find that the resulting embeddings have intuitive semantics, with related types of clothing occurring close together, as shown below.
Finally, if we want to quantify how much we should trust a particular prediction, we can make use of ideas from conformal methods. The details are slightly complicated, but essentially we can choose a measure of nonconformity, which indicates how unusual a new instance would be if we tried giving it each possible label.
Using a standard measure of nonconformity (based on predicted probability) already works quite well for detecting out-of-domain data, but our method suggests a new measure of nonconformity, namely the negative unnormalized weighted sum of training instances with the same label:

$$\alpha(x, k) = -\sum_{i \,:\, y_i = k} w(h, h_i)$$

where $h_i = g(x_i)$ is the learned embedding of training instance $x_i$.
Intuitively, this corresponds to how close the nearest training points (with a particular label) are to where the test point is being embedded. If a new instance is placed far from all of the training data, this is a signal that we may not be able to trust the model on this instance.
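The difference between the two measures is easy to see on toy data. This sketch assumes the Gaussian kernel with a hypothetical fixed bandwidth γ = 1, uses 1 − p(y = k | x) as the probability-based score (one common choice among several), and the helper names are illustrative.

```python
import numpy as np


def kernel_weights(h, H_train, gamma=1.0):
    # Assumed Gaussian kernel with a fixed bandwidth gamma
    return np.exp(-gamma * np.sum((H_train - h) ** 2, axis=1))


def prob_nonconformity(h, H_train, y_train, k, gamma=1.0):
    """Probability-based score: 1 - p(y = k | x), from the normalized weights."""
    w = kernel_weights(h, H_train, gamma)
    return 1.0 - w[y_train == k].sum() / w.sum()


def dwac_nonconformity(h, H_train, y_train, k, gamma=1.0):
    """Proposed score: negative unnormalized weighted sum over label-k instances."""
    w = kernel_weights(h, H_train, gamma)
    return -w[y_train == k].sum()


H_train = np.array([[0.0, 0.0], [0.2, 0.0], [3.0, 3.0], [3.0, 3.2]])
y_train = np.array([0, 0, 1, 1])
near, far = np.array([3.0, 3.0]), np.array([10.0, 10.0])
for name, h in [("near", near), ("far", far)]:
    print(name,
          prob_nonconformity(h, H_train, y_train, k=1),
          dwac_nonconformity(h, H_train, y_train, k=1))
```

Here the far point gets a probability-based score no larger than the near point's, because after normalization nearly all of its (tiny) mass still falls on label 1. The unnormalized score ignores that normalization and correctly ranks the far point as much more unusual.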
Using this method, we can detect out-of-domain data even more reliably. For example, if we train a model on the Fashion MNIST dataset and then ask it to make predictions on the original MNIST digits, any model we try will still predict one of the fashion classes for each digit, often with high probability (because that is all the model knows). However, the credibility values from a conformal predictor tend to be quite low, and are shifted even closer to zero when using our proposed measure of nonconformity.
The two plots below show the credibility values on out-of-domain data (the MNIST digits, for models trained on Fashion MNIST) using a softmax model with a standard measure of nonconformity (top), and a DWAC model using our measure of nonconformity (bottom).
As you can see, our model is highly confident that the MNIST digits are not at all similar to the original training data, and hence we should reject the model’s predictions on these instances.
There is much more to be said about conformal methods, including the ability to upper bound the error rate of a model, and we provide more details in the paper.
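To make credibility concrete, here is a minimal sketch of a split (inductive) conformal predictor. Using the split variant with a held-out calibration set is our assumption for illustration; the paper may use a different setup. The calibration scores below are made-up numbers in the style of the proposed measure, and the function names are hypothetical.

```python
import numpy as np


def conformal_p_values(test_scores, calib_scores):
    """Split-conformal p-value for each candidate label k.

    test_scores: (K,) nonconformity of the test point if given label k.
    calib_scores: (n,) nonconformity of held-out calibration instances
    under their true labels.
    p_k = (#{calibration scores >= test score for k} + 1) / (n + 1)
    """
    calib = np.asarray(calib_scores)
    n = len(calib)
    return np.array([(np.sum(calib >= s) + 1) / (n + 1) for s in test_scores])


def credibility(p_values):
    """Credibility: the largest p-value over the candidate labels."""
    return float(np.max(p_values))


def prediction_set(p_values, epsilon):
    """Labels that cannot be rejected at significance level epsilon."""
    return [k for k, p in enumerate(p_values) if p > epsilon]


# Hypothetical DWAC-style scores (negative weighted sums) for 9 calibration points
calib = np.array([-2.0, -1.9, -1.8, -1.7, -1.5, -1.2, -1.0, -0.9, -0.5])
p_in = conformal_p_values([-1.6, -1e-8], calib)      # in-domain point, 2 labels
p_out = conformal_p_values([-1e-40, -1e-41], calib)  # far from all training data
print(p_in, credibility(p_in), prediction_set(p_in, 0.2))
print(p_out, credibility(p_out), prediction_set(p_out, 0.2))
```

The out-of-domain point receives the smallest possible p-value, 1/(n + 1), for every label, so its credibility is minimal and its prediction set at ε = 0.2 is empty, a natural cue to reject the prediction. The set construction also gives the error bound mentioned above: if calibration and test data are exchangeable, the set contains the true label with probability at least 1 − ε.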
In conclusion, DWACs offer a way to make deep models more transparent, interpretable, and robust, with no loss of accuracy or calibration. We’ve illustrated this with image classification here, but as we show in the paper, the same approach works equally well for textual and tabular data.
Reference: Dallas Card, Michael Zhang, and Noah A. Smith. Deep Weighted Averaging Classifiers. In Proceedings of FAT* (2019).