# Bayesian Teaching as Model Explanation: An MNIST Example

When folks show off their classification models at conference meetings, they typically present examples their models classified correctly and incorrectly. Proud authors will further remark that the examples their models misclassify seem challenging for people as well, though such comments are often post-hoc rationalizations. Either way, the examples are subjectively cherry-picked to some extent for presentation.

One intuitive solution is to use model certainty estimates because there are at least four sorts of examples we care about:

1. Those classified correctly with high confidence,
2. those classified correctly with low confidence,
3. those classified incorrectly with low confidence, and
4. those classified incorrectly with high confidence.

That is, in addition to knowing when our models are right and wrong, we also want to know when they are overconfident and underconfident. However, models used in practice often do not offer interpretable certainty estimates, if they offer any at all, and in laboratory settings, models are improved by studying the inference processes that lead to overcertainty and undercertainty in the first place.
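To make the four sorts of examples concrete, here is a minimal sketch that buckets examples by correctness and confidence. The function name `bucket_examples`, the toy labels, and the `threshold` of 0.8 are all hypothetical choices for illustration; nothing here comes from the model discussed in this post.

```python
def bucket_examples(labels, predictions, confidences, threshold=0.8):
    """Group example indices by correctness and confidence.

    `threshold` is an arbitrary, assumed cut-off between "low" and "high"
    confidence; a real analysis would need to justify this choice.
    """
    buckets = {
        ("correct", "high"): [],
        ("correct", "low"): [],
        ("incorrect", "low"): [],
        ("incorrect", "high"): [],
    }
    for i, (y, y_hat, c) in enumerate(zip(labels, predictions, confidences)):
        outcome = "correct" if y == y_hat else "incorrect"
        level = "high" if c >= threshold else "low"
        buckets[(outcome, level)].append(i)
    return buckets


# Toy data: four examples, all predicted as class 0.
labels = [0, 0, 2, 2]
predictions = [0, 0, 0, 0]
confidences = [0.95, 0.55, 0.60, 0.90]
buckets = bucket_examples(labels, predictions, confidences)
```

Note that the split depends entirely on the confidence numbers being meaningful, which is exactly the assumption questioned above.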

Bayesian teaching is an alternative, more interpretable, more relevant, and more generalizable approach to finding and ranking these sorts of examples.


# Bayesian Teaching

The goal of Bayesian teaching is to induce a target model in the learner by presenting teaching sets of data. This requires tracking two kinds of inference: the teacher’s inference, which is done on the space of possible teaching sets, and the learner’s inference, which is done on the space of possible target models.

Formally, a Bayesian teacher, with a set of training data $D = \{d_1, d_2, \dots, d_N\}$ and a teaching set size $n < N$, teaches a target model $\targetModel$ by sampling a teaching set $\tsTarget \subseteq D$ from $\allts = \{\ts \given \ts \in \mathscr{P}(D) \land \vert \ts \vert \leq n\},$ according to

\begin{align} p(\getTS{T} \given \targetModel) &= \frac{\pLPrior{T} \ \pLLikelihood{T}} {p(\targetModel)} \\ &= \frac{\pLPrior{T} \ \pLLikelihood{T}} {\sum\limits_{\ts \in \allts} \pLPrior{} \ \pLLikelihood{}}, \end{align}

where $\mathscr{P}(D)$ is the power set of $D$ and $\allts$ is the space of teaching sets. $\pLLikelihood{}$ is the probability the learner will infer the target model $\targetModel$ given a particular teaching set $\ts$ (i.e. the learner’s posterior probability given that teaching set), and $\pLPrior{}$ is the teacher’s prior probability on the same teaching set $\ts$. For computationally efficient learning, one might select priors that assign higher probabilities to smaller teaching sets.
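As a rough illustration of the teaching distribution, the sketch below enumerates every teaching set of size at most $n$ for a toy problem and normalizes the product of the teacher's prior and the learner's posterior on the target. Everything specific here is an assumption made for the example: the data are coin flips, the learner is a Bayesian learner over three candidate coin biases, the teacher's prior defaults to uniform, and the names `learner_posterior` and `teaching_distribution` are hypothetical. It is not the model used for the MNIST results.

```python
from itertools import combinations


def learner_posterior(teaching_set, hypotheses, prior):
    """Learner's posterior over candidate coin biases after seeing 0/1 flips."""
    unnormalized = []
    for theta, p in zip(hypotheses, prior):
        likelihood = 1.0
        for flip in teaching_set:
            likelihood *= theta if flip == 1 else 1.0 - theta
        unnormalized.append(p * likelihood)
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


def teaching_distribution(data, target_index, hypotheses, prior, n,
                          set_prior=lambda idx: 1.0):
    """Teaching probability of every subset of `data` with at most n items.

    Keys are tuples of indices into `data`. `set_prior` is the teacher's
    (unnormalized) prior on a teaching set; uniform by default (assumed).
    """
    scores = {}
    for size in range(n + 1):  # size 0 is the empty set, part of the power set
        for idx in combinations(range(len(data)), size):
            subset = tuple(data[i] for i in idx)
            posterior = learner_posterior(subset, hypotheses, prior)
            scores[idx] = set_prior(idx) * posterior[target_index]
    normalizer = sum(scores.values())  # the sum over all teaching sets
    return {idx: s / normalizer for idx, s in scores.items()}


data = [1, 1, 0, 1]                # observed flips (1 = heads)
hypotheses = [0.2, 0.5, 0.8]       # candidate biases the learner considers
prior = [1 / 3, 1 / 3, 1 / 3]      # learner's prior over those biases
dist = teaching_distribution(data, target_index=2, hypotheses=hypotheses,
                             prior=prior, n=2)
best = max(dist, key=dist.get)     # most probable teaching set (index tuple)
```

To favor smaller teaching sets, one could pass something like `set_prior=lambda idx: 0.5 ** len(idx)`, which is again only one of many possible choices.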

# Bayesian Teaching as Model Explanation

The idea of using Bayesian teaching for probabilistic model explanation was conceived by Scott Cheng-Hsin Yang. The intuition is that subsets of training data that lead a model to the same (or approximately similar) inference as the model trained on all the data should be useful for understanding the fitted model.

In my opinion, there are several important benefits of this approach versus using certainty estimates.

1. Model certainty can be just as mysterious as model predictions, and such numbers are not informative if we do not understand what they represent. Teaching probabilities, however, are more intuitive in that they are computed with some objective on the actual parameter values—not just on the data point.
2. Since all useful models are trained using some sort of update rule, we always have access to the learner’s inference process. This means we should be able to hack a teaching solution together, even when the learning model is not probabilistic. How well this works for non-probabilistic models is, of course, an open question.

Using Bayesian teaching for model explanation is conceptually simple: pick a teaching set size (e.g. 2) to constrain the search space, perform teaching inference for the category to be understood, and then rank the teaching sets based on the teaching probabilities. To make the search even more straightforward, we can focus on $\vert \ts \vert = n$ instead of $\vert \ts \vert \leq n$.
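The ranking step can be sketched as follows, restricting the search to teaching sets of exactly size $n$ and assuming a uniform teacher prior. The learner here is a toy stand-in (a Bayesian learner choosing among three candidate means of a unit-variance Gaussian), and the names `rank_teaching_sets` and `toy_posterior_of_target` are hypothetical; any function returning the learner's posterior on the target model could be plugged in instead.

```python
import math
from itertools import combinations


def toy_posterior_of_target(points, means=(-1.0, 0.0, 1.0), target=0.0):
    """Posterior on `target` for a learner with a uniform prior over `means`
    and a unit-variance Gaussian likelihood (an assumed toy learner)."""
    def likelihood(mu):
        return math.exp(-0.5 * sum((x - mu) ** 2 for x in points))
    return likelihood(target) / sum(likelihood(mu) for mu in means)


def rank_teaching_sets(data, n, posterior_of_target, top_k=5):
    """Return the top_k best and top_k worst size-n teaching sets.

    Each entry is (index_tuple, teaching_probability). Best sets are in
    descending order, worst sets in ascending order.
    """
    scores = {
        idx: posterior_of_target([data[i] for i in idx])
        for idx in combinations(range(len(data)), n)
    }
    normalizer = sum(scores.values())
    ranked = sorted(((idx, s / normalizer) for idx, s in scores.items()),
                    key=lambda pair: pair[1], reverse=True)
    best = ranked[:top_k]
    worst = ranked[::-1][:top_k]
    return best, worst


data = [-0.9, -0.2, 0.0, 0.3, 1.1]
best, worst = rank_teaching_sets(data, n=2,
                                 posterior_of_target=toy_posterior_of_target,
                                 top_k=3)
```

Enumerating all pairs is only feasible for small pools; for a dataset the size of MNIST, some form of sampling or pre-filtering would presumably be needed, which this sketch does not attempt.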

I think presenting actual images will be more persuasive here, so I am leaving details for future posts and papers. If you are really hungry for details, you can check out a recent conference paper I submitted with Wai Keen Vong, Anderson Reyes, Scott Cheng-Hsin Yang, and Patrick Shafto on Bayesian teaching of image categories. This paper as well as a number of other papers on Bayesian teaching can be found on our lab website. This was the first pass we made at the work, but we will be submitting more interesting findings in the coming months.

# Concrete examples with MNIST

I used the MNIST handwritten digits dataset because it is relatively well known. To obtain the results below, I:

1. created two copies of the same learning model,
2. used the MNIST training set to train one of them, which I call the target model,
3. and then conveyed the target model to the untrained model by performing Bayesian teaching with teaching set size set to 2.

The results below are exactly the sorts of images we are interested in tracking down for conference papers and meetings, but here they are sorted according to a more interpretable metric.

Fig. 1: 5 best teaching sets using ground truth labels. Each pair of images represents a teaching set, and the pairs are sorted in descending order, from left to right, according to the teaching probabilities (i.e. leftmost set is best).

Fig. 2: 5 worst teaching sets using ground truth labels. Each pair of images represents a teaching set, and the pairs are sorted in ascending order, from left to right, according to the teaching probabilities (i.e. leftmost set is worst).

For example, the images in Fig. 1 and Fig. 2 give us a good idea of what the model finds to be the best and worst representations of the number 0, respectively. Two characteristics stand out in particular: first, the worst teaching sets are all high-weight (i.e. bold) handwriting, and second, the worst 0s are sloppy, with little space in the middle of the zero. Note that each teaching set is unique, even though some particular images within sets tend to be present in several teaching sets.

To obtain teaching analogs of high and low certainty misclassifications, I perform teaching in the same way after relabeling the data according to the fitted model’s predictions.
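One way to read the relabeling step is that the pool of candidate images for a category changes: with ground truth labels it contains only true members of the category, while with model predictions it contains everything the model assigns to that category, including misclassified images. The sketch below illustrates this reading on toy labels; the helper `candidate_pool` and the label lists are hypothetical.

```python
def candidate_pool(labels, category):
    """Indices of the examples that carry the given label."""
    return [i for i, y in enumerate(labels) if y == category]


true_labels = [0, 0, 2, 2, 1]
predicted_labels = [0, 0, 0, 2, 1]   # example 2 is a 2 predicted as a 0

pool_truth = candidate_pool(true_labels, 0)       # pool under ground truth
pool_model = candidate_pool(predicted_labels, 0)  # pool under model predictions

# Misclassified examples that enter the pool only after relabeling.
misclassified_in_pool = [i for i in pool_model if true_labels[i] != 0]
```

Teaching sets drawn from `pool_model` can therefore contain images from other categories, which is what makes them analogs of misclassifications.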

Fig. 3: 5 best teaching sets using model predictions as labels. Again, each pair of images represents a teaching set, and the pairs are sorted in descending order, from left to right, according to the teaching probabilities (i.e. leftmost set is best).

Fig. 4: 5 worst teaching sets using model predictions as labels. Again, each pair of images represents a teaching set, and the pairs are sorted in ascending order, from left to right, according to the teaching probabilities (i.e. leftmost set is worst).

There are a few things to be said here. First, the model generally classifies 0s correctly. Test and cross-validation performance of the learning model I used was ~87%, so this is not surprising; all it means is that you need to look further down the list before you encounter misclassifications. What is interesting is that the worst teaching sets (analogous to low certainty examples) contain 2s that one might rationalize as being very "0-like".

For completeness, I have uploaded teaching sets for the other image categories at the end of this post. Again, note that each teaching set is unique, even though some particular images appear across several teaching sets.

# Concluding Remarks

• These examples are obtained via Bayesian teaching of a target model by using information about the learner’s inference process.
• The ranking here is also interpretable in that we know that the better the teaching set, the more likely it is to induce the model of interest.
• In future posts, I will try to show how one might apply this approach to non-probabilistic models. Wherever it works, Bayesian teaching can buy us examples we care about, sorted in ways we care about.

# Sets Obtained via Bayesian Teaching for Categories 1-9


Teaching sets using ground truth labels: Category 1

Teaching sets using model predictions: Category 1

Teaching sets using ground truth labels: Category 2

Teaching sets using model predictions: Category 2

Teaching sets using ground truth labels: Category 3

Teaching sets using model predictions: Category 3

Teaching sets using ground truth labels: Category 4

Teaching sets using model predictions: Category 4

Teaching sets using ground truth labels: Category 5

Teaching sets using model predictions: Category 5

Teaching sets using ground truth labels: Category 6

Teaching sets using model predictions: Category 6

Teaching sets using ground truth labels: Category 7

Teaching sets using model predictions: Category 7

Teaching sets using ground truth labels: Category 8

Teaching sets using model predictions: Category 8

Teaching sets using ground truth labels: Category 9

Teaching sets using model predictions: Category 9