Bayesian Teaching as Model Explanation: An MNIST Example


When folks show off their classification models at conferences and meetings, they typically present examples their models classified correctly and incorrectly. Proud authors will further remark that the examples their models misclassify seem challenging for people as well, though such comments are often post-hoc rationalizations. In either case, the examples are, to some extent, subjectively cherry-picked for presentation.

One intuitive solution is to select examples using model certainty estimates, because there are at least four sorts of examples we care about:

  1. Those classified correctly with high confidence,
  2. those classified correctly with low confidence,
  3. those classified incorrectly with low confidence, and
  4. those classified incorrectly with high confidence.

That is, in addition to knowing when our models are right and wrong, we also want to know when they are overconfident and underconfident. However, many models used in practice offer certainty estimates that are hard to interpret, if they offer them at all, and in laboratory settings, models are improved by studying the inference processes that lead to overcertainty and undercertainty in the first place.
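For concreteness, here is a minimal sketch of how one might bucket examples this way, assuming a scikit-learn-style classifier with a `predict_proba` method, integer class labels, and an arbitrary confidence threshold of 0.9; the function name and threshold are mine, not part of any library.

```python
import numpy as np

def four_example_buckets(model, X, y, threshold=0.9):
    """Split examples into the four cases above. Assumes a scikit-learn-style
    classifier with predict_proba and integer class labels (e.g. digits 0-9);
    the 0.9 confidence threshold is arbitrary and only for illustration."""
    proba = model.predict_proba(X)   # shape (n_examples, n_classes)
    pred = proba.argmax(axis=1)      # predicted class per example
    conf = proba.max(axis=1)         # model's confidence in that prediction

    correct = pred == y
    high = conf >= threshold

    return {
        "correct_high_confidence":   np.where(correct & high)[0],
        "correct_low_confidence":    np.where(correct & ~high)[0],
        "incorrect_low_confidence":  np.where(~correct & ~high)[0],
        "incorrect_high_confidence": np.where(~correct & high)[0],
    }
```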

Bayesian teaching is an alternative approach to finding and ranking these sorts of examples, one that is more interpretable, more relevant, and more generalizable.

$$ \newcommand\given[1][]{\:#1\vert\:} \newcommand\getTS[1]{\mathcal{D}_{#1}} \newcommand\ts{\getTS{}} \newcommand\tsTarget{\getTS{T}} \newcommand\allts{\mathfrak{D}} \newcommand\targetModel{\Theta^*} \newcommand\pLPrior[1]{p(\getTS{#1})} \newcommand\pLLikelihood[1]{p_L(\targetModel \given \getTS{#1})} $$

Bayesian Teaching

The goal of Bayesian teaching is to induce a target model in the learner by presenting teaching sets of data. This requires tracking two kinds of inference: the teacher’s inference, which is done on the space of possible teaching sets, and the learner’s inference, which is done on the space of possible target models.

Formally, a Bayesian teacher with a set of $N$ training data points and a teaching set size $n < N$ teaches a target model $\targetModel$ by sampling a teaching set $\tsTarget$ from $\allts$ according to

$$ \begin{align} p(\getTS{T} \given \targetModel) &= \frac{\pLPrior{T} \ \pLLikelihood{T}} {p(\targetModel)} \\ &= \frac{\pLPrior{T} \ \pLLikelihood{T}} {\sum\limits_{\ts \in \allts} \pLPrior{} \ \pLLikelihood{}}, \end{align} $$

where $\allts$ is the power set of the training data and is the space of teaching sets. $\pLLikelihood{T}$ is the probability the learner will infer the target model given a particular teaching set (i.e. the learner's posterior probability given that teaching set), and $\pLPrior{T}$ is the teacher's prior probability on the same teaching set $\tsTarget$. For computationally efficient learning, one might select priors that assign higher probabilities to smaller teaching sets.
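When the space of teaching sets is small enough to enumerate, this distribution can be computed directly. Below is a minimal sketch, assuming you supply a `learner_posterior` function that evaluates $\pLLikelihood{}$ for a given teaching set; the function and argument names are mine, not from any library.

```python
import numpy as np

def teaching_probabilities(candidate_sets, learner_posterior, prior=None):
    """Compute the teaching distribution over an explicit list of teaching sets.

    candidate_sets    : list of teaching sets (e.g. tuples of example indices)
    learner_posterior : function returning p_L(Theta* | D), the learner's
                        posterior on the target model given teaching set D
    prior             : optional function returning the teacher's prior p(D);
                        defaults to a uniform prior over candidate_sets
    """
    if prior is None:
        uniform = 1.0 / len(candidate_sets)
        prior = lambda d: uniform

    # Numerator p(D) * p_L(Theta* | D) for each candidate, normalized by the
    # sum over the whole space of teaching sets (the denominator above).
    scores = np.array([prior(d) * learner_posterior(d) for d in candidate_sets])
    return scores / scores.sum()
```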

Bayesian Teaching as Model Explanation

The idea of using Bayesian teaching for probabilistic model explanation was conceived by Scott Cheng-Hsin Yang. The intuition is that subsets of the training data that lead a model to the same (or nearly the same) inference as the model trained on all the data should be useful for understanding the fitted model.

In my opinion, there are several important benefits of this approach versus using certainty estimates.

  1. Model certainty can be just as mysterious as model predictions, and such numbers are not informative if we do not understand what they represent. Teaching probabilities, however, are more intuitive in that they are computed with an explicit objective defined on the actual parameter values, not just on the individual data point.
  2. Since all useful models are trained using some sort of update rule, we always have access to the learner’s inference process. This means we should be able to hack a teaching solution together, even when the learning model is not probabilistic. How well this works for non-probabilistic models is, of course, an open question.

Using Bayesian teaching for model explanation is conceptually simple: pick a teaching set size (e.g. 2) to constrain the search space, perform teaching inference for the category to be understood, and then rank the teaching sets by their teaching probabilities. To make the search even more straightforward, we can focus on the unnormalized numerator $\pLPrior{T} \ \pLLikelihood{T}$ instead of the full teaching probability $p(\tsTarget \given \targetModel)$; the denominator is the same for every teaching set, so the ranking is unchanged.
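As a sketch of that search, under the assumption of a uniform teacher prior over size-2 sets, the ranking reduces to sorting pairs by the learner's posterior alone; `learner_posterior` and the other names below are hypothetical, not from the original post.

```python
import itertools

def rank_teaching_pairs(example_indices, learner_posterior, top_k=5):
    """Rank all size-2 teaching sets for one category by their unnormalized
    teaching score. The denominator of the teaching distribution is shared by
    every candidate set, so ranking by the numerator preserves the order; with
    a uniform teacher prior the numerator reduces to p_L(Theta* | D)."""
    scored = [(learner_posterior(pair), pair)
              for pair in itertools.combinations(example_indices, 2)]
    scored.sort(reverse=True)                 # highest teaching score first
    return scored[:top_k], scored[-top_k:]    # best and worst teaching sets
```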

I think presenting actual images will be more persuasive here, so I am leaving details for future posts and papers. If you are really hungry for details, you can check out a recent conference paper I submitted with Wai Keen Vong, Anderson Reyes, Scott Cheng-Hsin Yang, and Patrick Shafto on Bayesian teaching of image categories. This paper as well as a number of other papers on Bayesian teaching can be found on our lab website. This was the first pass we made at the work, but we will be submitting more interesting findings in the coming months.

Concrete Examples with MNIST

I used the MNIST handwritten digits dataset because it is relatively well known. To obtain the results below, I:

  1. created two copies of the same learning model,
  2. trained one of them, which I call the target model, on the MNIST training set, and
  3. conveyed the target model to the untrained copy by performing Bayesian teaching with the teaching set size set to 2 (see the toy sketch after this list).
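To make the pipeline concrete, here is a toy end-to-end sketch. The independent Beta-Bernoulli pixel model below is my own stand-in for illustration, not the learning model I actually used, and the images are assumed to be binarized to 0/1; but the structure mirrors the steps above: fit a target model on all images of a category, then score every size-2 teaching set by the learner's posterior on that target.

```python
import numpy as np
from itertools import combinations
from scipy.stats import beta

# Toy stand-in for the learning model: an independent Beta-Bernoulli model over
# binarized pixels. Not the model used in the paper; only here to make the
# teaching computation concrete end to end.

ALPHA = BETA_ = 1.0  # symmetric Beta prior on each pixel's "ink" probability

def fit_target_model(images):
    """Target model Theta*: posterior-mean pixel probabilities from all images
    of one category (images: array of shape (n, 784) with 0/1 entries)."""
    n = len(images)
    return (ALPHA + images.sum(axis=0)) / (ALPHA + BETA_ + n)

def log_learner_posterior(theta_star, teaching_set):
    """log p_L(Theta* | D): the learner's Beta posterior, computed from the
    teaching set alone, evaluated at the target model's parameters."""
    n = len(teaching_set)
    ones = teaching_set.sum(axis=0)
    return beta.logpdf(theta_star, ALPHA + ones, BETA_ + n - ones).sum()

def rank_pairs(images, top_k=5):
    """Score every size-2 teaching set for one category (uniform teacher prior,
    so ranking by the learner posterior alone is enough). For a real MNIST
    category you would subsample before enumerating all pairs."""
    theta_star = fit_target_model(images)
    scored = [(log_learner_posterior(theta_star, images[list(pair)]), pair)
              for pair in combinations(range(len(images)), 2)]
    scored.sort(reverse=True)
    return scored[:top_k], scored[-top_k:]   # best and worst teaching sets
```

Calling `rank_pairs` on all (or a subsample of) binarized 0s would give the toy-model analogs of Fig. 1 and Fig. 2 below.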

The results below are exactly the sorts of images we are interested in tracking down for conference papers and meetings, but here they are sorted according to a more interpretable metric.

figure_1

Fig. 1: 5 best teaching sets using ground truth labels. Each pair of images represents a teaching set, and the pairs are sorted in descending order, from left to right, according to the teaching probabilities (i.e. leftmost set is best).

figure_2

Fig. 2: 5 worst teaching sets using ground truth labels. Each pair of images represents a teaching set, and the pairs are sorted in ascending order, from left to right, according to the teaching probabilities (i.e. leftmost set is worst).

For example, the images in Fig. 1 and Fig. 2 give us a good idea of what the model finds to be the best and worst representations of the number 0, respectively. Two characteristics stand out in particular: first, the worst teaching sets all contain high-weight (i.e. bold) handwriting, and second, the worst 0s are sloppy, with little space in the middle of the zero. Note that each teaching set is unique, even though some individual images appear in several teaching sets.

To obtain teaching analogs of high- and low-certainty misclassifications, I performed teaching in the same way after relabeling the data according to the fitted model's predictions.
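In code, the relabeling step is a one-liner; the sketch below assumes a fitted classifier `clf` with a scikit-learn-style `predict` method, a training matrix `X_train`, and the hypothetical `rank_pairs` helper from the previous sketch.

```python
# Relabel the training images with the fitted model's own predictions, then run
# the teaching procedure exactly as before. `clf`, `X_train`, and `rank_pairs`
# are hypothetical names carried over from the sketches above.
predicted_labels = clf.predict(X_train)

category = 0                                     # e.g. the model's notion of "0"
images = X_train[predicted_labels == category]   # images the model *calls* a 0
best_sets, worst_sets = rank_pairs(images)
```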

figure_3

Fig. 3: 5 best teaching sets using model predictions as labels. Again, each pair of images represents a teaching set, and the pairs are sorted in descending order, from left to right, according to the teaching probabilities (i.e. leftmost set is best).

figure_4

Fig. 4: 5 worst teaching sets using model predictions as labels. Again, each pair of images represents a teaching set, and the pairs are sorted in ascending order, from left to right, according to the teaching probabilities (i.e. leftmost set is worst).

There are a few things to be said here. First, the model generally classifies 0s correctly. This is not surprising given that the test and cross-validation performance of the learning model I used was ~87%; it just means you need to look further down the list before you encounter misclassifications. What is interesting is that the worst teaching sets (analogous to low-certainty examples) contain 2s that one might rationalize as being very "0-like".

For completeness, I have uploaded teaching sets for the other image categories at the end of this post. Again, note that each teaching set is unique, even though some particular images appear across several teaching sets.

Concluding Remarks

  • These examples are obtained via Bayesian teaching of a target model by using information about the learner’s inference process.
  • The ranking here is also interpretable in that we know that the better the teaching set, the more likely it is to induce the model of interest.
  • In future posts, I will try to show how one might apply this approach to non-probabilistic models. Wherever it works, Bayesian teaching can buy us examples we care about, sorted in ways we care about.


Sets Obtained via Bayesian Teaching for Categories 1-9


Teaching sets using ground truth labels: Category 1

figure_5 figure_6

Teaching sets using model predictions: Category 1

figure_7 figure_8

Teaching sets using ground truth labels: Category 2

figure_9 figure_10

Teaching sets using model predictions: Category 2

figure_11 figure_12

Teaching sets using ground truth labels: Category 3

figure_13 figure_14

Teaching sets using model predictions: Category 3

figure_15 figure_16

Teaching sets using ground truth labels: Category 4

figure_17 figure_18

Teaching sets using model predictions: Category 4

figure_19 figure_20

Teaching sets using ground truth labels: Category 5

figure_21 figure_22

Teaching sets using model predictions: Category 5

figure_23 figure_24

Teaching sets using ground truth labels: Category 6

figure_25 figure_26

Teaching sets using model predictions: Category 6

figure_27 figure_28

Teaching sets using ground truth labels: Category 7

figure_29 figure_30

Teaching sets using model predictions: Category 7

figure_31 figure_32

Teaching sets using ground truth labels: Category 8

figure_33 figure_34

Teaching sets using model predictions: Category 8

figure_35 figure_36

Teaching sets using ground truth labels: Category 9

figure_37 figure_38

Teaching sets using model predictions: Category 9

figure_39 figure_40