
Testing Machine Learning Models In Your CI/CD Pipeline

Using Airflow and Deepchecks to validate Machine Learning models in the CI/CD pipeline



Have you ever heard people talk about their CI/CD pipelines and wondered whether there is something similar for Machine Learning? Do you want to automatically run tests on your ML models every time the data or the models change?

In this article, I’m going to explain how you can validate your Machine Learning models with Apache Airflow and the Deepchecks validation package as a step in your CI/CD pipeline, before the model build is considered successful. Assuming your model is built in an Airflow pipeline, this article demonstrates how to integrate Deepchecks' test suites into that workflow.

What is Apache Airflow?

From the docs:

“Airflow is a platform to programmatically author, schedule, and monitor workflows. Use Airflow to author workflows as Directed Acyclic Graphs (DAGs) of tasks. The Airflow scheduler executes your tasks on an array of workers while following the specified dependencies. Rich command line utilities make performing complex surgeries on DAGs a snap. The rich user interface makes it easy to visualize pipelines running in production, monitor progress, and troubleshoot issues when needed.”

Basically, Airflow is a tool for running workflows that consist of multiple stages (a DAG). It is responsible for scheduling and orchestrating DAG runs. Because DAGs are defined in Python code, the workflows become more maintainable, versionable, testable, and collaborative. For more info about DAGs, visit the Airflow docs.

Defining our DAG

Below is an example DAG definition that validates a model using Airflow and Deepchecks, along with the corresponding DAG as displayed in the Airflow UI:

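from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.python import PythonOperator

# The callables must be defined or imported before this block. dataset_integrity_step and
# model_evaluation_step are shown below; load_adult_dataset_step and load_adult_model_step
# are hypothetical names for the loading tasks, which this article does not show.
with DAG(
    dag_id="deepchecks_airflow_integration",  # DAG name assumed
    schedule_interval="@daily",  # on Airflow 2.4+, the `schedule=` argument is preferred
    default_args={
        "owner": "airflow",
        "retries": 1,
        "retry_delay": timedelta(minutes=5),
        "start_date": datetime(2021, 1, 1),
    },
    catchup=False,
) as dag:
    load_adult_dataset = PythonOperator(task_id="load_adult_dataset", python_callable=load_adult_dataset_step)
    integrity_report = PythonOperator(task_id="integrity_report", python_callable=dataset_integrity_step)
    load_adult_model = PythonOperator(task_id="load_adult_model", python_callable=load_adult_model_step)
    evaluation_report = PythonOperator(task_id="evaluation_report", python_callable=model_evaluation_step)

    # Validate the data once it is loaded; evaluate the model once data and model are loaded
    load_adult_dataset >> integrity_report
    load_adult_dataset >> load_adult_model >> evaluation_report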

As can be seen, this DAG defines two validation steps:

  • To validate the data, we define a dataset integrity step that runs after the data is loaded.
  • To evaluate the model, we define a model evaluation step that runs after the data and the pre-trained model are loaded.
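The `>>` operator declares that the task on the left must finish before the task on the right starts. To make the resulting order concrete, here is a stdlib-only sketch (not Airflow code) that uses Python's `graphlib.TopologicalSorter` (Python 3.9+) to group the four tasks into the waves their dependencies allow. It models dependency ordering only; Airflow's real scheduler also accounts for retries, concurrency limits, and other settings.

```python
from graphlib import TopologicalSorter

# Each task maps to the set of tasks it depends on (its upstream tasks),
# mirroring the >> relationships in the DAG definition above.
dependencies = {
    "integrity_report": {"load_adult_dataset"},
    "load_adult_model": {"load_adult_dataset"},
    "evaluation_report": {"load_adult_model"},
}

sorter = TopologicalSorter(dependencies)
sorter.prepare()

batches = []
while sorter.is_active():
    ready = sorted(sorter.get_ready())  # tasks whose upstream tasks have all finished
    batches.append(ready)
    sorter.done(*ready)  # mark the whole wave as completed

for i, batch in enumerate(batches, start=1):
    print(f"wave {i}: {batch}")
```

Tasks in the same wave, such as `integrity_report` and `load_adult_model`, do not depend on each other, so a scheduler is free to run them in parallel.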

In the next section, we will explain how to declare such steps using Deepchecks, the open-source Python library for ML validation.

Testing. CI/CD. Monitoring.

Because ML systems are more fragile than you think. All based on our open-source core.

Deepchecks HubOur GithubOpen Source

What is Deepchecks?

In short, Deepchecks is an open-source Python library for testing ML/DL models and data. It can help with various testing and validation needs throughout a project: verifying data integrity, inspecting distributions, confirming valid data splits (for example, the train/test split), evaluating model performance, and more!

The Deepchecks package is organized around checks and suites. A check performs a single test on the data and/or model (for example, detecting feature drift between the train and the test data), and a suite is an ordered collection of checks. Running a suite produces a concluding report covering all of the checks that ran. Deepchecks comes with several built-in suites, such as the data integrity suite and the model evaluation suite; the full list of built-in suites is available in the Deepchecks docs.
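The check/suite relationship can be illustrated with a toy, stdlib-only model. This is not Deepchecks' implementation or API: the `Check`, `Suite`, and `CheckResult` classes and the two conditions are hypothetical, and the rows are made-up sample data. A check runs one test and returns a result; a suite runs its checks in order and collects the results into one report.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

Row = Dict[str, object]


@dataclass
class CheckResult:
    name: str
    passed: bool
    details: str


@dataclass
class Check:
    """A single test over the data (toy stand-in, not the Deepchecks class)."""
    name: str
    condition: Callable[[List[Row]], Tuple[bool, str]]

    def run(self, rows: List[Row]) -> CheckResult:
        passed, details = self.condition(rows)
        return CheckResult(self.name, passed, details)


class Suite:
    """An ordered collection of checks whose results form one report."""

    def __init__(self, name: str, *checks: Check):
        self.name = name
        self.checks = list(checks)  # run order is preserved

    def run(self, rows: List[Row]) -> List[CheckResult]:
        return [check.run(rows) for check in self.checks]


def no_missing_values(rows):
    missing = sum(1 for row in rows for value in row.values() if value is None)
    return missing == 0, f"{missing} missing value(s)"


def no_duplicate_rows(rows):
    seen, duplicates = set(), 0
    for row in rows:
        key = tuple(sorted(row.items()))  # hashable, column-order-independent row key
        if key in seen:
            duplicates += 1
        seen.add(key)
    return duplicates == 0, f"{duplicates} duplicate row(s)"


rows = [
    {"age": 39, "workclass": "State-gov"},
    {"age": 50, "workclass": None},
    {"age": 38, "workclass": "Private"},
]
suite = Suite(
    "toy integrity",
    Check("No missing values", no_missing_values),
    Check("No duplicate rows", no_duplicate_rows),
)
results = suite.run(rows)
for r in results:
    print(f"[{'PASS' if r.passed else 'FAIL'}] {r.name}: {r.details}")
```

Keeping each check small and independent is what lets a suite mix and match them; the built-in Deepchecks suites follow the same idea with far richer checks and an HTML report.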

Deepchecks Suite of Checks

Deepchecks & Airflow

In this article, we will use Deepchecks in an Airflow workflow to validate a model. Our model is a simple RandomForest model trained on the well-known Adult dataset.

We will follow the Airflow integration tutorial in the Deepchecks docs, in which a Deepchecks suite runs inside an Airflow workflow stage to validate a model.

Validating the Integrity of the Training Data

We can define a workflow stage that validates the integrity of our training data using the built-in integrity suite, as the snippet below demonstrates:

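import os
from datetime import datetime

import pandas as pd


def dataset_integrity_step(**context):
    from deepchecks.tabular.suites import single_dataset_integrity
    from deepchecks.tabular.datasets.classification.adult import _CAT_FEATURES, _target
    from deepchecks.tabular import Dataset

    # Path to the training CSV, pushed to XCom by the upstream loading task
    adult_train = pd.read_csv(context.get("ti").xcom_pull(key="train_path"))

    # Wrap the DataFrame in a Deepchecks Dataset (label column + categorical features)
    ds_train = Dataset(adult_train, label=_target, cat_features=_CAT_FEATURES)

    # Run the built-in data integrity suite on the training data
    train_results = single_dataset_integrity().run(ds_train)

    # Output directory for the HTML reports (path assumed; adjust for your deployment)
    dir_path = "suite_results"
    try:
        os.makedirs(dir_path, exist_ok=True)
    except OSError:
        print("Creation of the directory {} failed".format(dir_path))

    # Timestamp the report so repeated runs don't overwrite each other
    run_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    train_results.save_as_html(os.path.join(dir_path, f'train_integrity_{run_time}.html'))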
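`dataset_integrity_step` receives its input through XCom: it calls `xcom_pull(key="train_path")`, so an upstream loading task must have pushed that key first. The sketch below shows this hand-off outside Airflow. `FakeTaskInstance`, `load_dataset_step`, and `read_train_step` are hypothetical names, the fake imitates only the `xcom_push`/`xcom_pull` methods of Airflow's task instance, and the two CSV rows are placeholder data.

```python
import csv
import os
import tempfile


class FakeTaskInstance:
    """Hypothetical in-memory stand-in for Airflow's TaskInstance (XCom methods only)."""

    def __init__(self):
        self._xcom = {}

    def xcom_push(self, key, value):
        self._xcom[key] = value

    def xcom_pull(self, task_ids=None, key="return_value"):
        # The real method can also filter by task_ids; this stand-in ignores them.
        return self._xcom.get(key)


def load_dataset_step(**context):
    """Hypothetical upstream task: writes a CSV and publishes its path under "train_path"."""
    path = os.path.join(tempfile.mkdtemp(), "train.csv")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["age", "income"])
        writer.writerow([39, "<=50K"])
        writer.writerow([52, ">50K"])
    context["ti"].xcom_push(key="train_path", value=path)


def read_train_step(**context):
    """Downstream task: pulls the path by key, as dataset_integrity_step does."""
    path = context.get("ti").xcom_pull(key="train_path")
    with open(path, newline="") as f:
        return list(csv.DictReader(f))


# In real Airflow each task gets its own TaskInstance and XComs live in the
# metadata database; one shared fake object stands in for both here.
ti = FakeTaskInstance()
load_dataset_step(ti=ti)
rows = read_train_step(ti=ti)
print(rows)
```

Passing file paths rather than the DataFrames themselves keeps XCom payloads small, which is what XCom is designed for. The same fake can be used to unit-test step functions without a running scheduler.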

The output of the stage is an HTML report of the suite, which looks like this:

Validating a Model

Now, we will define a workflow stage that validates the performance of a pre-trained model. We will use the model_evaluation suite for that.

import os
from datetime import datetime

import joblib
import pandas as pd


def model_evaluation_step(**context):
    from deepchecks.tabular.suites import model_evaluation
    from deepchecks.tabular.datasets.classification.adult import _CAT_FEATURES, _target
    from deepchecks.tabular import Dataset

    # Load the pre-trained model and the train/test splits from paths pushed to XCom
    adult_model = joblib.load(context.get("ti").xcom_pull(key="adult_model"))
    adult_train = pd.read_csv(context.get("ti").xcom_pull(key="train_path"))
    adult_test = pd.read_csv(context.get("ti").xcom_pull(key="test_path"))
    ds_train = Dataset(adult_train, label=_target, cat_features=_CAT_FEATURES)
    ds_test = Dataset(adult_test, label=_target, cat_features=_CAT_FEATURES)

    # Run the built-in model evaluation suite on the train/test datasets and the model
    evaluation_results = model_evaluation().run(ds_train, ds_test, adult_model)

    # Report directory (assumed, same as in the integrity step); created if missing
    dir_path = "suite_results"
    os.makedirs(dir_path, exist_ok=True)
    run_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    evaluation_results.save_as_html(os.path.join(dir_path, f'model_evaluation_{run_time}.html'))

This will result in the following report:
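Saving HTML reports is useful for review, but in a CI/CD setting you usually also want a failed validation to stop the run. In Airflow, a task is marked as failed when its Python callable raises an exception, so one option is to raise from the validation step. Below is a minimal, framework-free sketch of such a gate. `fail_on_failed_checks` and `ValidationFailed` are hypothetical names, the outcomes are made up, and the sketch assumes you have already reduced the suite result to a mapping of check name to pass/fail, which depends on your Deepchecks version and is not shown.

```python
from typing import Mapping


class ValidationFailed(Exception):
    """Raising this (or any exception) from a PythonOperator callable fails the task."""


def fail_on_failed_checks(outcomes: Mapping[str, bool]) -> None:
    """Raise if any check did not pass.

    `outcomes` maps a check name to True (passed) or False (failed). Building
    this mapping from a Deepchecks suite result is assumed to happen earlier.
    """
    failed = sorted(name for name, passed in outcomes.items() if not passed)
    if failed:
        raise ValidationFailed(f"{len(failed)} check(s) failed: {', '.join(failed)}")


# Made-up outcomes: the first call returns normally, the second raises
fail_on_failed_checks({"missing_values": True, "duplicate_rows": True})
try:
    fail_on_failed_checks({"missing_values": False, "test_accuracy": True})
    message = "no failure"
except ValidationFailed as err:
    message = str(err)
print(message)
```

Note that with `"retries": 1` in the DAG's default arguments, Airflow retries a failed task once before marking it as failed.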

Wrapping Up

In this short article, we demonstrated an approach to validating ML models and data using Airflow and Deepchecks. We defined two Airflow stages that validate different aspects of the model-building pipeline: first, we checked the data for integrity issues, and then we evaluated the model's performance.

Feel free to use the snippets provided here in your own projects and workflows. I can’t wait to hear what they find!

To learn more and to download the code from this article, visit the Airflow integration page in the Deepchecks docs.
