
Machine Learning Testing Principles: Making Sure Your Model Does What it Should

The popularity of Machine Learning (ML) models is credited to the important role they play in a variety of industries. However, classic practices for testing and QA of software systems are not a perfect fit for ensuring that our ML models perform the way we believe they should.

In this post, we cover some of the basics of automated testing for ML models, in a way that can save you the time and effort of debugging your model only after it has already caused significant damage. We will focus on tests that can be applied to a model at the post-training stage, but do still read up on related subjects such as monitoring your model in production and validating your data.

Classic Unit Testing Does Not Apply to ML Models

One reason it can be hard to test ML models is that they are complicated. Classic software systems can be decomposed into simple units where each unit has a well-defined task for which we may ensure correct performance with a small number of tests. On the other hand, an ML model is a somewhat large unit – the product of a process of training – which cannot be decomposed.

Another reason is that in ML testing, we are trying to test something that is inherently non-deterministic and probabilistic. Our model can make mistakes sometimes, yet still be the best possible model. In classical software, there is no tolerance for any incorrect output, so we can test for such errors in a more straightforward manner.

There is a need to define a methodology for QA of ML systems. We will discuss how we can test our models to ensure they fulfill what is required of them. We will divide the different methods into black-box testing and white-box testing.

Many of the technical terms and ideas are borrowed from the paper "Beyond Accuracy: Behavioral Testing of NLP Models with CheckList" (Ribeiro et al.).

Black Box Testing

General Evaluation

The most widely used tests for ML systems are the evaluation metrics (accuracy, precision, F1, AUC) computed on the test data. These metrics are important: if we maintain high results in production, our model is probably doing what it is meant to do. If, on the other hand, the results deteriorate, we will know something is wrong.

However, these results paint a very high-level picture. They do not give us information about what may cause the problem. Additionally, evaluating on the full test set can be computationally expensive, so we might want to run this test less frequently than some other less expensive tests.

General metrics are necessary but not sufficient for understanding our model’s performance. (source)
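As a minimal sketch, assuming a binary classifier with a scikit-learn-style interface and illustrative thresholds (the names model, X_test, and y_test are our own, not from the original post), such a periodic check could look like this:

from sklearn.metrics import accuracy_score, f1_score, precision_score, roc_auc_score

def test_overall_metrics(model, X_test, y_test):
   y_pred = model.predict(X_test)
   y_proba = model.predict_proba(X_test)[:, 1]  # probability of the positive class

   # Thresholds are illustrative; set them based on your own baseline.
   assert accuracy_score(y_test, y_pred) >= 0.85
   assert precision_score(y_test, y_pred) >= 0.80
   assert f1_score(y_test, y_pred) >= 0.80
   assert roc_auc_score(y_test, y_proba) >= 0.90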

Manual Error Analysis

In contrast to the high-level picture of overall performance provided by the evaluation metrics, it is important to dive into the details: sample some examples that our model gets wrong and investigate why it predicts the wrong value for them.

After going through multiple examples, try to detect patterns:
Are the errors similar to each other?
Why is the model making mistakes?
Are there issues with the input data?
Would a human make a similar mistake?
Answering these questions is essential to get a deep understanding of your model’s capabilities.

The confusion matrix can give us an indication of the most common errors (source)
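A minimal sketch of this kind of error analysis, assuming a scikit-learn-style classifier and a labeled validation set (X_val and y_val are illustrative names):

import numpy as np
from sklearn.metrics import confusion_matrix

y_pred = model.predict(X_val)
print(confusion_matrix(y_val, y_pred))  # which classes get confused with which

# Sample a few misclassified examples for manual inspection.
wrong_idx = np.where(y_pred != y_val)[0]
for i in np.random.choice(wrong_idx, size=5, replace=False):
   print(X_val[i], "predicted:", y_pred[i], "actual:", y_val[i])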

It can be helpful to cluster the data points and visualize the data (e.g., using PCA or t-SNE). This can reveal whether your model performs poorly on a certain region of the data.

Visualization of clusters of the MNIST dataset using t-SNE (source)

Note that some digits look similar to others; our model is more likely to make mistakes on the border between the “1” cluster and the “7” cluster, for example.
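A minimal sketch of such a visualization, using scikit-learn's small digits dataset as a stand-in for your own features:

import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE

X, y = load_digits(return_X_y=True)
X_2d = TSNE(n_components=2, random_state=0).fit_transform(X)

# Each point is an example, colored by its label; clusters that sit close
# together hint at classes the model is likely to confuse.
plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y, cmap="tab10", s=5)
plt.colorbar()
plt.show()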

Although this post is focused on automated tests, we recommend running at least one iteration of this manual stage, since familiarity and direct contact with your model and data are essential for creating quality ML models.

Important: Error analysis should be done on a validation set and not on the hold-out set, in order to prevent data leakage.


Naive Single Prediction Tests

Perhaps the most straightforward analogy to unit testing for ML models is to provide a sample that the model must predict correctly and assert that the prediction is indeed correct. For example, if we created a model for sentiment analysis of movie reviews, we would expect the text “…one of the best movies of 2010” to be classified as positive. For such a test, we should choose an example that is “easy” and unambiguous, one that the model should never get wrong if it is going to be used in production.

The problem with such tests is, as previously mentioned, that ML models are allowed to make mistakes. What we usually care about is the global performance, not the prediction for a single example. If we wrote such a test on an arbitrary example, our model might fail it after retraining for no particular reason. That is why we recommend using only “easy” examples that the model really should not get wrong.

Example code for running with pytest:

def test_negative_sentiment(model):
   text = "This movie was a total waste of time"
   sentiment = model.predict(text)
   assert sentiment == 0  # 0 represents the negative class


============================= test session starts =============================
test_sentiment_analysis_model.py::test_negative_sentiment PASSED                                     [100%]
============================== 1 passed in 0.01s ==============================

One slightly more sophisticated option is to create templates for simple examples and then evaluate the model on many examples generated from each template. This is referred to as a Minimum Functionality Test (Ribeiro et al.).
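A minimal sketch of a template-based test, assuming a binary sentiment model where 1 means “positive” (the template and word list are illustrative):

import itertools

TEMPLATE = "The movie was {adj} and the acting was {adj2}."
POSITIVE_WORDS = ["great", "wonderful", "fantastic"]

def test_positive_templates(model):
   for adj, adj2 in itertools.product(POSITIVE_WORDS, repeat=2):
      text = TEMPLATE.format(adj=adj, adj2=adj2)
      assert model.predict(text) == 1, f"Expected positive sentiment for: {text}"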

Directional Expectation Tests

This test enables us to define the expected effects of some data perturbations on the output. For example, when trying to predict the value of a house, we would want to assert that our model predicts a higher price for a larger house when other attributes remain constant.

This kind of test is useful because it is comparative rather than absolute. It allows us to automatically create tests from templates without needing to define the expected prediction value for any specific example, so we can easily create tests with wide coverage.

For sentiment analysis, we want to ensure that adding positive content makes the prediction more positive, and vice versa for negative content (Ribeiro et al.)
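A minimal sketch for the house-price example, assuming the model takes a single-row pandas DataFrame and that “square_feet” is one of its features (both assumptions are ours, for illustration):

def test_larger_house_costs_more(model, base_house):
   # base_house is a single-row DataFrame; only the size changes.
   larger_house = base_house.copy()
   larger_house["square_feet"] += 500

   assert model.predict(larger_house)[0] > model.predict(base_house)[0]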

Invariance Tests

We may define data perturbations that should not affect the model’s output. This resembles data augmentation; however, here we discuss it as a means of testing our model. For example, in an NLP task, swapping a word with a synonym should not change the output dramatically. Similarly, when creating a “fair” model that is meant to be blind to attributes such as gender or race, we can test that changing these features does not affect the prediction in a meaningful way.

Augmented images of a dog should still be classified as “dog” (source)

Example: Swapping the destination should not affect the sentiment prediction (Ribeiro et al.)
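A minimal sketch of the destination-swap case, assuming a model that exposes a predict_proba method returning the probability of positive sentiment, with a tolerance we choose ourselves:

def test_destination_invariance(model):
   original = "The flight to Chicago was delayed and the staff was rude."
   perturbed = "The flight to Denver was delayed and the staff was rude."

   # The prediction should barely move when only the destination changes.
   assert abs(model.predict_proba(original) - model.predict_proba(perturbed)) < 0.05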

Evaluation of Data Slices

As discussed earlier, general evaluation gives us a very high-level picture of our results. How can we get a more fine-grained understanding of our results on different types of examples?

The idea is fairly simple – use data slicing to create many different subsets of your dataset, and then evaluate each subset separately. After gathering this information, we can then investigate slices with low performance and understand the underlying causes for non-optimal results. This is a relatively inexpensive process that can help you significantly increase performance.
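A minimal sketch of per-slice evaluation with pandas, assuming a DataFrame df that holds the features, a “label” column, and a categorical “device_type” column used to define the slices (df, feature_columns, and the column names are illustrative):

from sklearn.metrics import accuracy_score

df["prediction"] = model.predict(df[feature_columns])

for slice_value, slice_df in df.groupby("device_type"):
   acc = accuracy_score(slice_df["label"], slice_df["prediction"])
   print(f"device_type={slice_value}: accuracy={acc:.3f} (n={len(slice_df)})")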

Detecting critical subsets can help you improve your model quality quickly (source)

We recommend checking out the snorkel library, which has features for customizable data slicing and per-slice evaluation. We also recommend the paper “Slice Finder: Automated Data Slicing for Model Validation” and the corresponding GitHub repository for automatic selection of slices with poor performance.

Automatic detection of slices with bad performance using “Slice Finder” (source)

White Box Testing

Explainability

Explainable ML systems give us both the what and the why. They provide an explanation for their predictions. The definition of an “explanation” might be a philosophical question, but there are some commonly accepted definitions that are used in practice. For example, a decision tree is considered to be self-explanatory since we understand the process of the prediction by traversing the edges of the tree that correspond to the different conditions.

Decision trees are “self-explanatory” models (source)
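For example, scikit-learn can print the learned decision rules of a tree directly, which is one concrete form of an “explanation” (the iris dataset here is purely for illustration):

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

X, y = load_iris(return_X_y=True)
tree = DecisionTreeClassifier(max_depth=3).fit(X, y)

# The printed rules show exactly why each prediction is made.
print(export_text(tree, feature_names=load_iris().feature_names))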

In Computer Vision, a common practice is to create a heatmap of the pixels that affect the prediction the most as an “explanation.”

The red pixels have the largest effect on the classification of the image as a “cat”, and therefore provide an explanation for the prediction (source)

For models that are less explainable (such as DNNs), we can try to use knowledge distillation to provide an equivalent model that is explainable, or we can train our model to output an explanation to go along with the prediction.

Why is this important?

It is important to understand the inner logic of the model and how different features affect the prediction. This prepares us for possible future bugs and helps us make more robust ML models.

Examining Weights During Training

When we see that our model is learning (the loss is decreasing and the accuracy is increasing), we are satisfied. But there are additional insights worth checking that can be derived from examining the weights of the network during training.

For example, try to detect the bug in the following piece of code:

def make_convnet(input_image):
   net = slim.conv2d(input_image, 32, [11, 11], scope="conv1_11x11")
   net = slim.conv2d(input_image, 64, [5, 5], scope="conv2_5x5")
   net = slim.max_pool2d(net, [4, 4], stride=4, scope='pool1')
   net = slim.conv2d(input_image, 64, [5, 5], scope="conv3_5x5")
   net = slim.conv2d(input_image, 128, [3, 3], scope="conv4_3x3")
   net = slim.max_pool2d(net, [2, 2], scope='pool2')
   net = slim.conv2d(input_image, 128, [3, 3], scope="conv5_3x3")
   net = slim.max_pool2d(net, [2, 2], scope='pool3')
   net = slim.conv2d(input_image, 32, [1, 1], scope="conv6_1x1")
   return net

Perhaps you noticed that the layers are not really stacked, since we use input_image as the input to each layer. But how would we detect this error? The code would run smoothly, and even the learning process would seem to work, but we would get poor results since, in practice, the output depends on only a single convolutional layer.

If we inspect the weights of the network after each epoch, we would see that most of them stay exactly the same.

import numpy as np
import tensorflow as tf  # TensorFlow 1.x-style API, matching the original example

def test_convnet():
   image = tf.placeholder(tf.float32, (None, 100, 100, 3))
   model = Model(image)  # a model wrapper exposing a `train` op is assumed
   sess = tf.Session()
   sess.run(tf.global_variables_initializer())
   before = sess.run(tf.trainable_variables())
   _ = sess.run(model.train, feed_dict={
      image: np.ones((1, 100, 100, 3)),
   })
   after = sess.run(tf.trainable_variables())
   for b, a in zip(before, after):
      # Make sure something changed after one training step.
      assert (b != a).any()

Example test code to make sure that the weights change after training

Conclusion

Testing is an essential part of any software development process; however, in the field of ML there is still no standard practice that is widely accepted. We have shown some practices that enable you to test your ML models and make sure they do what you expect of them. We hope you enjoyed reading this post, and let us know if you have any thoughts on the subject.
