How To Detect Concept Drift With Machine Learning Monitoring


What Is Concept Drift

Let us start by asking: what is concept drift, and why does it matter? Wikipedia gives the following definition:

“In predictive analytics and machine learning, concept drift means that the statistical properties of the target variable, which the model is trying to predict, change over time in unforeseen ways. This causes problems because the predictions become less accurate as time passes.”

Concept drift (a.k.a. model drift) is part of the machine learning lifecycle, and it is perhaps the central reason ML models need to be refreshed and retrained. As the incoming data drifts away from the historical data that was used for training, the relationships and correlations between features change as well. Mathematically, concept drift is defined as a change in the distribution P(y|X), where y is the true label and X is the vector of available features. The problem with concept drift is that it breaks a core assumption of machine learning: that the training distribution reflects the “real-world” distribution. Once that assumption fails, nothing ensures that the trained model is fit for the target task.
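
To make the definition concrete, here is a minimal, purely illustrative sketch in which the feature distribution P(X) stays the same while the labelling rule P(y|X) changes between two time periods (the thresholds and sample sizes are arbitrary choices for illustration):

import numpy as np

rng = np.random.default_rng(0)

# The feature distribution P(X) is identical in both periods
x_before = rng.normal(0, 1, 10_000)
x_after = rng.normal(0, 1, 10_000)

# But the labelling rule P(y|X) changes: the decision threshold moves
y_before = (x_before > 1.0).astype(int)   # "old reality"
y_after = (x_after > 0.0).astype(int)     # "new reality"

# A model trained on the "before" period still sees the same inputs,
# yet its learned mapping from X to y no longer matches reality
print(f'Positive rate before: {y_before.mean():.2f}')  # ~0.16
print(f'Positive rate after:  {y_after.mean():.2f}')   # ~0.50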

To bring this idea down to earth, let’s consider an example from the cybersecurity field. Imagine you are developing a system that notifies you of potential Denial-of-Service attacks. Your model might use a feature such as the number of requests received by the server per minute. When the model was trained, 1,000 requests per minute may have seemed like an extremely large number that indicated something fishy was going on. But what if your company then launched an advertising campaign and your website became much more popular? This is an example of the concept of suspicious behavior drifting because reality itself has changed.

Detecting concept drift early on is essential for maintaining up-to-date models in production that continuously provide value to your company. Ideally, this should be incorporated as part of a robust framework for monitoring ML models in production.

How to Detect Concept Drift

In order to detect concept drift, we begin by selecting an appropriate drift detection algorithm. For streaming data, a popular choice is ADWIN (ADaptive WINdowing), while for batched data popular choices include the Kolmogorov–Smirnov test, the chi-squared test, and adversarial validation.
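
For the batched case, a simple approach is to run a two-sample Kolmogorov–Smirnov test between a reference window (for example, the training data) and a recent production window. The sketch below uses scipy.stats.ks_2samp; the window sizes, distributions, and significance threshold are illustrative choices, not fixed recommendations:

import numpy as np
from scipy import stats

rng = np.random.default_rng(42)

# Reference window (e.g. training data) vs. a recent production window
reference = rng.normal(0.8, 0.05, 1000)
current = rng.normal(0.6, 0.10, 1000)

# Two-sample KS test: the null hypothesis is that both samples
# come from the same underlying distribution
statistic, p_value = stats.ks_2samp(reference, current)

if p_value < 0.05:  # the significance level is a project-specific choice
    print(f'Drift detected (KS statistic={statistic:.3f}, p-value={p_value:.3g})')
else:
    print('No significant drift detected')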

Next, we apply the selected algorithm separately to the labels, the model’s predictions, and the data features. Drift in any one of these categories may be significant in its own way. Drift in the labels, known as label drift, indicates a change in how the classes are represented in the real world or in your sampling or processing method, and possibly a concept drift as well. Similarly, drift in your model’s predictions indicates a data drift in important features, and perhaps a concept drift too. Finally, drift in any individual feature is worth noting as well, although for some features it may have little effect on your model’s quality. In mathematical notation: concept drift is a change in P(y|X), label drift is a change in P(y), prediction drift is a change in P(ŷ) where ŷ is the model’s output, and data drift is a change in P(X).

To sum it up: during training, our model learns to approximate P(y|X), so concept drift by definition implies that our model will no longer be fit for the task. Label drift, prediction drift, and data drift are easier to measure directly and can serve as strong indicators of concept drift.
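
As a rough sketch of how this could look in practice, the helper below runs the same two-sample KS test on every column shared by two hypothetical DataFrames, `reference_df` and `current_df`, which are assumed to contain the model’s numeric features along with prediction and label columns where available. Categorical columns would need a different test, such as the chi-squared test mentioned above.

import pandas as pd
from scipy import stats

def drift_report(reference_df, current_df, alpha=0.05):
    """Run a two-sample KS test on every column shared by both frames."""
    rows = []
    for col in reference_df.columns.intersection(current_df.columns):
        statistic, p_value = stats.ks_2samp(
            reference_df[col].dropna(), current_df[col].dropna()
        )
        rows.append({
            'column': col,
            'ks_statistic': statistic,
            'p_value': p_value,
            'drift': p_value < alpha,  # flag columns whose distribution shifted
        })
    return pd.DataFrame(rows).sort_values('p_value')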

 

Example

We will show a basic example of detecting concept drift with the ADaptive WINdowing (ADWIN) algorithm, using the river Python library for online ML. We begin by defining three different distributions for the data, which we then concatenate to simulate a signal that drifts over time. Think of the data as being the true labels, the predictions, or an individual feature.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

# Generate data for 3 distributions
random_state = np.random.RandomState(seed=42)
dist_a = random_state.normal(0.8, 0.05, 1000)
dist_b = random_state.normal(0.4, 0.02, 1000)
dist_c = random_state.normal(0.6, 0.1, 1000)

# Concatenate data to simulate a data stream with 2 drifts
stream = np.concatenate((dist_a, dist_b, dist_c))

Next, we plot the data:

# Auxiliary function to plot the stream and the three distributions
def plot_data(dist_a, dist_b, dist_c, drifts=None):
    fig = plt.figure(figsize=(7, 3), tight_layout=True)
    gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
    ax1, ax2 = plt.subplot(gs[0]), plt.subplot(gs[1])
    ax1.grid()
    ax1.plot(stream, label='Stream')  # `stream` is the concatenated signal defined above
    ax2.grid(axis='y')
    ax2.hist(dist_a, label=r'$dist_a$')
    ax2.hist(dist_b, label=r'$dist_b$')
    ax2.hist(dist_c, label=r'$dist_c$')
    if drifts is not None:
        # Mark each detected drift point on the stream plot
        for drift_detected in drifts:
            ax1.axvline(drift_detected, color='red')
    plt.show()

plot_data(dist_a, dist_b, dist_c)

Which results in the following graph:


On the left is the synthetic signal, while on the right are the histograms of samples drawn from each of the three distributions. As we can see, the signal has two points of significant drift.

 

Finally, we try to detect the drift using the ADWIN algorithm.

from river import drift

drift_detector = drift.ADWIN()
drifts = []

for i, val in enumerate(stream):
    drift_detector.update(val)   # Data is processed one sample at a time
    if drift_detector.change_detected:
        # The drift detector indicates after each sample whether there is a drift in the data
        # (note: in newer versions of river this attribute is named `drift_detected`)
        print(f'Change detected at index {i}')
        drifts.append(i)
        drift_detector.reset()   # As a best practice, we reset the detector after a drift is found

plot_data(dist_a, dist_b, dist_c, drifts)

Output:
Change detected at index 1055
Change detected at index 2079

As we can see, the algorithm detected two change points that are quite close to the actual indices where the drifts occur.

Conclusion

Concept drift, or ML model drift, is a common issue with machine learning models in production that is often not dealt with properly. Incorporating basic monitoring mechanisms like the ones we have seen can help you detect potential issues early on and keep your models fresh and relevant. Deepchecks offers services to assist with this process, enabling your data science team to focus on researching new and exciting problems.

Further Reading

Concept drift detection

Model drift in ML monitoring

Sethi, Tegjyot Singh, and Mehmed Kantardzic. “On the Reliable Detection of Concept Drift from Streaming Unlabeled Data.” Expert Systems with Applications (2017).

Žliobaitė, Indrė. “Learning under Concept Drift: an Overview.” (2010).

 

 
