Most businesses these days rely on Machine Learning (ML) predictions and results because, over the years, ML models have proven to be valuable tools for making business decisions. These models are built from data using state-of-the-art ML algorithms and ample computing resources, and are then deployed to production, where they interact with users and receive new real-time data.
These models need constant monitoring while in production because the data they are fed can change at any time. Such changes can lead to inaccurate predictions that degrade the model's performance. These changes in data distribution are what we call Data Drift.
Data drift has problematic consequences for businesses and user experience, so ML engineers and data scientists need ways to detect drift and manage it properly using data drift thresholding methods. In this article, we discuss data drift: its causes, its types, how to detect it, data drift thresholding, and the tools for implementing automated data drift detection in production-level ML models.
Data drift is one of the most common issues in production-level ML models. It occurs when the data a model receives in production changes relative to the data used to train it, rendering the model unreliable. Factors that cause it include changes in business requirements, unreliable input data (e.g., changes in user behavior), and high workloads.
Data drift often results from changes over time, such as shifts in consumer preferences, weather conditions, and political situations. Academic data collected before COVID, for example, shows a weaker preference for online learning than data collected during COVID. Likewise, the demand for hand sanitizers increased considerably during COVID. Models trained on the earlier data would be unreliable in both cases: the change in input data during production also changes the data distribution, which disrupts the model's performance.
Data drift can also occur when a model is trained on narrow data but is exposed to a wider, less specific range of data in production. Spam detection is a good example: if the training data does not include a wide range of spam emails, the deployed model is likely to misclassify unfamiliar spam as legitimate (primary) mail.
Broadly, data drift arises because the model meets inputs it learned too little about, or never saw, during training. The examples above fall into one of the following categories:
- Sample Selection Bias. This occurs when there’s a systematic flaw in data collection, data integration, labeling, or sampling.
- Non-stationary Environments / Events. This occurs when training and testing conditions differ, for example when the model is used in an unfamiliar environment or the economic environment changes (e.g., the Great Recession of 2007; COVID-19).
- Upstream Data Transformation. These are changes that occur due to upstream data processing issues such as a faulty data pipeline, code error, data source changes, etc. All these can potentially affect feature value distributions.
Data drift falls under two types:
- Feature Shift
- Covariate Shift
- Feature Shift. This type of data drift happens when the features used during training change or, most often, become redundant. Over time, such a feature becomes irrelevant and no longer contributes to predicting the target or to model performance.
- Covariate Shift. This happens when the distribution of the input (independent) variables differs between the training data and the production data. It is often caused by changes in smaller, unobservable variables that shape an independent variable (feature). These are called Latent Variables and are sometimes temporal. Their change alters the feature's distribution, which in turn affects the data distribution and the model's performance.
If a model is trained to detect a particular disease and the training samples come from a certain type of patient (e.g., one age group or gender) that isn't representative of the real-world population, a covariate shift can occur. For example, if the model was trained on a dataset containing only children, it is unlikely to give accurate results for patients in their 30s.
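The patient example can be made concrete with a deliberately crude check: what share of production values falls outside the range of a feature seen during training? The sketch below assumes a single numeric feature (age) and uses made-up illustrative values; `out_of_range_share` is a hypothetical helper, not a library function. It only catches shifts that leave the training range, so a shift within that range would still need a distribution test such as K-S.

```python
def out_of_range_share(train_values, prod_values):
    """Share of production values lying outside the range seen during training."""
    low, high = min(train_values), max(train_values)
    outside = sum(1 for v in prod_values if v < low or v > high)
    return outside / len(prod_values)


train_ages = [4, 7, 9, 11, 13, 15, 17]      # illustrative pediatric training set
production_ages = [31, 34, 36, 38, 33, 12]  # illustrative production patients
print(out_of_range_share(train_ages, production_ages))
```

Here five of the six production ages fall outside the 4 to 17 range seen in training, a strong hint of covariate shift in the age feature.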
Data Drift Thresholding
Data drift thresholding is a method of determining the acceptable drift boundary (baseline) for a particular data distribution and use case. The goal is to measure how much the testing data or real-time production data has changed compared with the training data, and then decide what action to take when the change exceeds that boundary.
Data drift thresholding is a continuous process, especially for models in production because the input data is susceptible to change, and sometimes the goal (business KPI) of the model also changes.
In practice, it means detecting drift in a data distribution as it occurs and deciding how much deviation is acceptable.
Detecting and Analyzing Data Drift for Thresholding
Data drift can be detected in different ways. One simple and common approach is to use statistical tests and distance measures to compare the distribution of the training data (a.k.a. the baseline or reference) with that of live (production) data. If the difference between the two distributions is significant, drift has occurred. Examples include the Kolmogorov-Smirnov test, the Population Stability Index, the Jensen-Shannon divergence, and the Wasserstein distance.
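As a small illustration of a distance measure, the Wasserstein (earth mover's) distance between two one-dimensional samples of equal size reduces to the average absolute difference between their sorted values. The sketch below assumes equal-sized samples to keep it short; `wasserstein_1d` is a hypothetical helper. For unequal sizes or weighted samples, SciPy's `scipy.stats.wasserstein_distance` is the usual choice.

```python
def wasserstein_1d(reference, current):
    """1-D Wasserstein distance between two equal-sized samples."""
    if len(reference) != len(current):
        raise ValueError("this sketch only handles equal-sized samples")
    # With equal sizes, the optimal transport plan matches the i-th smallest
    # value of one sample with the i-th smallest value of the other.
    pairs = zip(sorted(reference), sorted(current))
    return sum(abs(a - b) for a, b in pairs) / len(reference)


print(wasserstein_1d([3, 1, 2], [2, 4, 3]))  # → 1.0 (every value moved up by 1)
```

Unlike a p-value, the result is expressed in the feature's own units, which can make a threshold easier to reason about for a single feature.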
Another common approach is to use an ML model that monitors the data quality for distribution changes and then finds the difference between the data points at different moments in time.
Kolmogorov-Smirnov Test (K-S)
The K-S test is one of the most commonly used because it is non-parametric: it makes no assumptions about the distributions it compares.
The K-S test compares the cumulative distributions of two datasets, usually the training data and the post-training (production) data. Its null hypothesis is that both samples come from the same distribution. If the test does not reject the null hypothesis, there is no evidence of drift; if it rejects it (a small p-value), the two distributions differ significantly and the data has drifted. This makes it an efficient way to determine whether two distributions are significantly different from each other. It is suitable for data drift thresholding on unidimensional data (one feature column at a time).
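The two-sample K-S statistic is:
$$D_{n,m} = \sup_x \left| F_{1,n}(x) - F_{2,m}(x) \right|$$
where:
$F_{1,n}(x)$ = the empirical cumulative distribution function of the reference sample of $n$ observations, evaluated at $x$;
$F_{2,m}(x)$ = the empirical cumulative distribution function of the current sample of $m$ observations; and
$\sup_x$ = the largest value of the absolute difference over all $x$.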
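In practice one would typically call `scipy.stats.ks_2samp(reference, current)`, which returns the statistic and a p-value. To show what happens underneath, the sketch below computes the two-sample K-S statistic (the largest gap between the two empirical cumulative distribution functions) using only the standard library, and flags drift with the common large-sample critical value c(α)·sqrt((n + m) / (n·m)). That approximation is an assumption of this sketch: it is asymptotic and less exact for small samples. The helper names are hypothetical.

```python
import math
import random
from bisect import bisect_right


def ks_statistic(reference, current):
    """Two-sample K-S statistic: the largest gap between the two empirical CDFs."""
    ref, cur = sorted(reference), sorted(current)
    n, m = len(ref), len(cur)
    d = 0.0
    # Both ECDFs are step functions, so the largest gap occurs at an observed
    # value; evaluating both ECDFs at every observation is enough.
    for x in ref + cur:
        gap = abs(bisect_right(ref, x) / n - bisect_right(cur, x) / m)
        d = max(d, gap)
    return d


def ks_drift(reference, current, alpha=0.05):
    """Return (statistic, critical value, drift flag) at significance level alpha."""
    n, m = len(reference), len(current)
    d = ks_statistic(reference, current)
    c_alpha = math.sqrt(-math.log(alpha / 2) / 2)  # roughly 1.358 when alpha = 0.05
    critical = c_alpha * math.sqrt((n + m) / (n * m))
    return d, critical, d > critical


random.seed(0)
train = [random.gauss(0, 1) for _ in range(1000)]
live = [random.gauss(0.5, 1) for _ in range(1000)]  # mean shifted by half a std dev
d, critical, drifted = ks_drift(train, live)
print(f"D={d:.3f}, critical={critical:.3f}, drift={drifted}")
```

With 1,000 observations on each side, the critical value is about 0.06, so a half-standard-deviation shift in the mean is flagged comfortably, while two identical samples give a statistic of exactly 0.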
Population Stability Index (PSI)
PSI is another statistical measure that helps us perform data drift thresholding, thanks to its ability to quantify the difference between two related data distributions.
The PSI is a computed value that shows how the population distribution of a dataset changes between two points in time (e.g., training versus production, or two periods in production). Suppose a model predicts the churn rate for a streaming service using data from while a popular show is airing, and we then test it on a sample from after the show has ended. The population distribution might have changed significantly: churn may now be high, but the model might not capture it and would therefore produce erroneous predictions. Using PSI, we can compare the population distribution at development time with the current one and get a fair idea of whether the model's results are still reliable.
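PSI can be calculated as:
$$\mathrm{PSI} = \sum_{i=1}^{B} \left(\mathrm{Actual}_i - \mathrm{Expected}_i\right) \ln\left(\frac{\mathrm{Actual}_i}{\mathrm{Expected}_i}\right)$$
where:
$B$ = the number of groups (bins);
$\mathrm{Actual}_i$ = the share of the target data distribution in group $i$; and
$\mathrm{Expected}_i$ = the share of the reference data distribution in group $i$.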
STEPS TO CALCULATE PSI
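- Step 1: Sort the scoring variable in descending order in the scoring sample.
- Step 2: Split the data into 10 or 20 groups, using the same group boundaries for both samples.
- Step 3: Calculate the % of records in each group based on the scoring sample (Actual).
- Step 4: Calculate the % of records in each group based on the development sample (Expected).
- Step 5: Calculate the difference between Step 3 and Step 4.
- Step 6: Take the natural log of (Step 3 / Step 4).
- Step 7: Multiply the result from Step 5 by the result from Step 6.
- Step 8: Sum the results from Step 7 across all groups to get the PSI.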
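The steps above can be sketched in Python as follows. This is a minimal version under a few assumptions: group boundaries are taken from quantiles of the development sample and reused for the scoring sample, values are sorted ascending (the sort direction does not change the result), and empty groups get a tiny share `eps` so the logarithm stays defined. Implementations differ on that last point. The function name `psi` is hypothetical, and the per-group terms are summed to obtain the final PSI.

```python
import math
from bisect import bisect_right


def psi(expected, actual, bins=10, eps=1e-4):
    """PSI of `actual` (scoring sample) against `expected` (development sample)."""
    ref = sorted(expected)
    # Group boundaries are quantiles of the development sample; the same
    # boundaries are reused for the scoring sample so both are grouped alike.
    edges = [ref[len(ref) * k // bins] for k in range(1, bins)]

    def shares(values):
        counts = [0] * bins
        for v in values:
            counts[bisect_right(edges, v)] += 1
        # Empty groups get a tiny share `eps` so the log and division stay defined.
        return [max(c / len(values), eps) for c in counts]

    exp_pct, act_pct = shares(expected), shares(actual)
    # Per group: (Actual - Expected) * ln(Actual / Expected), then sum over groups.
    return sum((a - e) * math.log(a / e) for a, e in zip(act_pct, exp_pct))


development = list(range(1000))
scoring = list(range(500, 1500))  # the whole population has moved upward
print(round(psi(development, development), 4), round(psi(development, scoring), 3))
```

Comparing a sample with itself gives a PSI of 0, while the shifted scoring sample lands far above the 0.2 level, because half of its records pile into the top development group and five groups become empty.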
INTERPRETATION OF PSI THRESHOLD
- PSI < 0.1: No significant change. You can continue using the existing model.
- 0.1 <= PSI < 0.2: Slight change. Some adjustment to the model may be required.
- PSI >= 0.2: Significant change. Ideally, the model should no longer be used as-is; it should be recalibrated or redeveloped.
When the PSI jumps above 0.2 for a batch, for example, the alert is meant to prompt you to inspect it. The cause may be a broken feeding pipeline, or genuine drift that means the model needs retraining. Either way, nothing guarantees that 0.2 is the right cutoff for your data. Arbitrarily chosen thresholds are therefore rarely a good solution: if the threshold is too high, alerts that should be raised may be missed, leading to false negatives; if it is too low, alerts are raised when they should not be, leading to more false positives.
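One way to keep the bands above without hard-coding them is to make the cutoffs parameters, and optionally derive the alert level from PSI values observed on batches already known to be healthy. The calibration idea, the 0.95 quantile, and the example history below are assumptions for illustration, not a standard rule; `psi_band` and `calibrate_alert` are hypothetical helpers.

```python
def psi_band(value, warn=0.1, alert=0.2):
    """Map a PSI value to the interpretation bands; warn/alert are tunable."""
    if value < warn:
        return "no change"
    if value < alert:
        return "slight change"
    return "significant change"


def calibrate_alert(healthy_psi_values, quantile=0.95):
    """Set the alert level at a high quantile of PSI values from healthy batches,
    so routine batch-to-batch noise rarely crosses it."""
    ranked = sorted(healthy_psi_values)
    index = min(len(ranked) - 1, int(quantile * len(ranked)))
    return ranked[index]


history = [0.02, 0.03, 0.05, 0.04, 0.06, 0.03, 0.08, 0.05, 0.04, 0.07]  # illustrative
alert_level = calibrate_alert(history)
print(alert_level, psi_band(0.15), psi_band(0.15, warn=0.05, alert=alert_level))
```

With the default cutoffs, a PSI of 0.15 is a "slight change"; against this particular history, where healthy batches never exceeded 0.08, the same value is treated as significant. The point is that the threshold follows the variation actually observed for the data rather than a fixed number.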
Jensen-Shannon Divergence (JSD)
This metric is a symmetric measure of how similar or different two distributions are.
JSD is a symmetrized and smoothed version of the Kullback-Leibler (KL) divergence: it averages the KL divergence of each distribution from their mixture M = (P + Q) / 2. Compared to the KL divergence, the JSD score is easier to interpret because it always lies between 0 (the distributions are identical) and 1 (the distributions do not overlap) when using log base 2. Its square root is known as the Jensen-Shannon distance.
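A minimal sketch of the JSD for two histograms that share the same bins, assuming the bins have already been chosen (for example from the training data). The counts are made-up illustrative values and `js_divergence` is a hypothetical helper. Note that SciPy's `scipy.spatial.distance.jensenshannon` returns the Jensen-Shannon distance, that is, the square root of this value.

```python
import math


def js_divergence(p_counts, q_counts):
    """Jensen-Shannon divergence with log base 2, so the result lies in [0, 1]."""
    p_total, q_total = sum(p_counts), sum(q_counts)
    p = [c / p_total for c in p_counts]
    q = [c / q_total for c in q_counts]
    m = [(a + b) / 2 for a, b in zip(p, q)]  # the mixture distribution M

    def kl(a, b):
        # KL(a || b); zero-probability terms in `a` contribute nothing, and
        # b = m is positive wherever a is, so the division is always defined.
        return sum(x * math.log2(x / y) for x, y in zip(a, b) if x > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


reference_hist = [120, 300, 380, 200]  # illustrative counts per bin at training time
current_hist = [60, 180, 400, 360]     # illustrative counts in the same bins later
print(round(js_divergence(reference_hist, current_hist), 4))
```

Because each distribution is compared with the mixture rather than with the other, swapping the arguments gives the same score, and a bin that is empty in one histogram does not make the result infinite, as it can with KL divergence.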
Which statistical method to use depends on the use case, and thresholds are adjusted according to business relevance, but the idea remains the same: to keep watch over population shifts.
Calculating these tests manually can be cumbersome, so there is a need to automate them. Moreover, these statistical methods tend to struggle with high-dimensional or sparse features (the curse of dimensionality: the more features there are, the harder it is to measure the difference). This can be mitigated with good computing power, intelligent algorithms, and ML methods.
Data Drift Detection with Machine Learning
There are 3 main stages in detecting data drift in Machine Learning:
- Data Quality Monitoring
- Model Quality Monitoring
- Drift Evaluation
- Data Quality Monitoring. Here, metadata about the input data (e.g., data type, size, source) is recorded while the model is being trained. Incoming data can then be compared repeatedly against this metadata, so that even slight deviations, which may indicate drift in the input data, can be observed.
- Model Quality Monitoring. This involves capturing real-world outcomes as they become available and comparing them with the model's predictions. A good example is comparing weekly demand predictions with the actual demand observed one week later.
- Drift Evaluation. This means closely monitoring the data and systems (through data quality and model quality monitoring) to detect changes and trigger the appropriate follow-up actions.
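Data quality monitoring, as described above, can be sketched as recording a small profile of each feature at training time and comparing every incoming batch against it. Which metadata to record (value types, missing share, numeric range) and the 5% tolerance on missing values are assumptions chosen for illustration; `profile` and `compare_profiles` are hypothetical helpers, and the example values are made up.

```python
def profile(column):
    """Record simple metadata for one feature column (None marks a missing value)."""
    present = [v for v in column if v is not None]
    numeric = [v for v in present if isinstance(v, (int, float))]
    return {
        "types": sorted({type(v).__name__ for v in present}),
        "missing_share": 1 - len(present) / len(column),
        "min": min(numeric) if numeric else None,
        "max": max(numeric) if numeric else None,
    }


def compare_profiles(reference, incoming, missing_tolerance=0.05):
    """List deviations of an incoming batch's profile from the training-time profile."""
    issues = []
    if incoming["types"] != reference["types"]:
        issues.append(f"type change: {reference['types']} -> {incoming['types']}")
    if incoming["missing_share"] - reference["missing_share"] > missing_tolerance:
        issues.append(f"missing values rose to {incoming['missing_share']:.0%}")
    if reference["min"] is not None and incoming["min"] is not None:
        if incoming["min"] < reference["min"] or incoming["max"] > reference["max"]:
            issues.append("numeric values outside the training range")
    return issues


training_profile = profile([12, 15, 20, 18, 14])
batch_profile = profile([13, None, "19", 40])
for issue in compare_profiles(training_profile, batch_profile):
    print(issue)
```

The example batch trips all three checks: a value arrived as a string, a quarter of the values are missing, and 40 lies above anything seen in training.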
Automating Data Drift Thresholding
Manually observing the drift metrics a system produces becomes burdensome, and eventually impractical, over time. The go-to solution is to set alerts that fire when a drift metric crosses a certain threshold.
Drift should be dealt with promptly, ideally through automated techniques that detect the anomaly (drift) and apply the changes the model requires.
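A threshold-based alert can be automated with a small monitor that scores each new batch against the reference data. Two assumptions are worth stating: the drift score here is a simple stand-in (the shift of the batch mean, measured in reference standard deviations), and an alert fires only after `patience` consecutive batches exceed the threshold, which trades a little detection delay for fewer false positives from one noisy batch. Any of the metrics discussed earlier could replace the score. `DriftMonitor` and its parameters are hypothetical, and the threshold of 0.5 is arbitrary.

```python
from statistics import mean, pstdev


class DriftMonitor:
    """Alert only after `patience` consecutive batches exceed `threshold`."""

    def __init__(self, reference, threshold=0.5, patience=2):
        self.ref_mean = mean(reference)
        self.ref_std = pstdev(reference) or 1.0  # guard against a constant feature
        self.threshold = threshold
        self.patience = patience
        self.streak = 0

    def score(self, batch):
        # Stand-in drift score: mean shift in units of the reference std dev.
        return abs(mean(batch) - self.ref_mean) / self.ref_std

    def check(self, batch):
        # Count consecutive high-scoring batches; a normal batch resets the count.
        if self.score(batch) > self.threshold:
            self.streak += 1
        else:
            self.streak = 0
        return self.streak >= self.patience


monitor = DriftMonitor(reference=[10, 12, 11, 13, 9, 10, 12, 11])
batches = [[11, 12, 10], [14, 15, 13], [15, 14, 16]]
alerts = [monitor.check(batch) for batch in batches]
print(alerts)
```

The first shifted batch only starts the streak; the alert is raised on the second consecutive shifted batch.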
There are several ways to automate data drift thresholding. Here are some:
Online Learning. Online learning is a way of handling data drift in which the algorithm updates the model in real time as data arrives sequentially, as opposed to traditional ML techniques, in which the model is trained on the entire training dataset at once.
Online learning is good when it is necessary for the algorithm to dynamically adapt to new patterns in the data, or when the data itself is generated as a function of time (e.g., stock price prediction).
A well-known Python library for online learning is creme, which has since been merged into the River library.
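Without depending on any library, the core idea of online learning can be sketched as a linear model updated by one stochastic gradient step per incoming observation. The method names `learn_one` and `predict_one` mirror the naming style used by creme/River, but this class is hand-written for illustration and is not that library's API. The data stream, including the change in the true relationship halfway through, is simulated and hypothetical.

```python
import random


class OnlineLinearRegression:
    """y ≈ w * x + b, updated one observation at a time with stochastic gradient descent."""

    def __init__(self, learning_rate=0.1):
        self.w, self.b, self.lr = 0.0, 0.0, learning_rate

    def predict_one(self, x):
        return self.w * x + self.b

    def learn_one(self, x, y):
        # Gradient step on the squared error of this single observation.
        error = self.predict_one(x) - y
        self.w -= self.lr * error * x
        self.b -= self.lr * error


random.seed(1)
model = OnlineLinearRegression()
for step in range(4000):
    x = random.random()
    # Simulated drift: the true relationship changes halfway through the stream.
    y = 2 * x + 1 if step < 2000 else -x + 3
    model.learn_one(x, y)
print(round(model.w, 2), round(model.b, 2))
```

Because the model keeps learning from each new observation, it first fits y = 2x + 1 and then moves to the new relationship y = -x + 3 without a separate retraining job. The learning rate controls how quickly it forgets the old pattern.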
For data and ML engineering teams, there is a constant need to build reliable production models. One way to achieve this is to automate drift thresholding, which helps identify drift, determine the thresholds to apply, and inform the decisions made when drift is detected.
Data drift thresholding is an important practice for models in production. In this article, we discussed what data drift is, its causes, and its types. We also explored the main statistical methods used to detect it and paid special attention to ML approaches and tools for detecting and analyzing data drift.
We hope this helped. Happy thresholding!