Machine Learning System Testing: A Guide to Writing Unit Tests :robot:

In this series of blog posts, I’ll guide you through writing tests for machine learning systems. This post covers the basics of writing unit tests for machine learning functions. We’ll start with a simple introduction to pytest and gradually move on to more advanced concepts, including fixtures and mocking. It is aimed at folks who are new to testing ML systems :rocket:.

You can install pytest using the following command:

pip install -U pytest

Next up, make a file called test_functions.py. This file will contain the functions we’re going to talk about in this post.

When you’re ready to test, simply run pytest in your terminal to kick off the unit tests :collision:.

:snake: Pytest basics :snake:

Create a test function to verify that the target function returns the string hello world! as its output.

def greet():
  return "hello world!"

Now we need to test the function above:

def test_greet():
  assert greet() == "hello world!"

Running pytest should print output like this:

collected 1 item

test_sample.py .                                                    [100%]
========================== 1 passed in 0.00s ==========================

You should see the test pass :heavy_check_mark:. Next, change the greeting message inside the function and run the tests again. You should now see them fail :x:.

A few more tests:

Create a test function to verify that a function adds two numbers correctly.

def add(a, b):
  return a + b

def test_add():
  result = add(1, 4)
  assert isinstance(result, int)
  assert result == 5

Create a test function for a function that produces n-grams for a given sentence and a specified value of n.

def get_ngrams(sentence, n):
  tokens = sentence.split(" ")
  n_grams = []
  for i in range(len(tokens) - n + 1):
    n_grams.append(tuple(tokens[i:i + n]))
  return n_grams

def test_get_ngram():
  result = get_ngrams("Good morning world", 2)
  assert len(result) == 2
  assert result == [("Good", "morning"), ("morning", "world")]

  result = get_ngrams("Good morning world", 3)
  assert len(result) == 1
  assert result == [("Good", "morning", "world")]

Create a test function for a function that calculates the word count for a given sentence.

from collections import Counter

def get_word_count(sentence):
  words = sentence.split(" ")
  word_count = Counter(words)
  return word_count

def test_get_word_count():
  result = get_word_count("Good world world")
  assert result["Good"] == 1
  assert result["world"] == 2

How to improve the quality of your unit tests?

  • Diverse Test Cases: Define a range of test cases that should produce both successful and failed outcomes. This ensures thorough testing of the unit’s behavior.

  • Floating Point Precision: If the function returns floating-point values, compare them within a tolerance rather than with ==. Changes in dependent libraries or optimizations can shift the low-order digits, so exact comparisons make tests brittle. Check out np.isclose for tolerance-based comparisons :eyes:.

    You should do something like this:

    import numpy as np

    def test_precision():
      observed = compute_pi()  # compute_pi is the function under test
      expected = 3.142
      # 3.142 is pi rounded to three decimals, so allow an absolute error of 1e-3
      assert np.isclose(observed, expected, atol=1e-3)
    
  • Consider Algorithm Changes: When using imported functions from libraries that undergo algorithmic changes or optimizations, be aware that results might differ. Adjust your tests accordingly to handle such scenarios.

  • Test Return Types: Ensure your tests check that return values have the expected type, such as float, int, str, np.ndarray, or torch.Tensor, to validate the correctness of the unit’s output.
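
To make the points about diverse test cases and return types concrete, here is a minimal sketch using pytest.mark.parametrize and pytest.raises. It repeats the logic of get_ngrams from earlier so that it runs on its own. The test names and the chosen cases are my own illustration, not a definitive list of what you should cover.

```python
import pytest


def get_ngrams(sentence, n):
    # Same logic as the get_ngrams function earlier in this post
    tokens = sentence.split(" ")
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


@pytest.mark.parametrize(
    "sentence, n, expected",
    [
        ("Good morning world", 1, [("Good",), ("morning",), ("world",)]),
        ("Good morning world", 3, [("Good", "morning", "world")]),
        ("Good morning world", 4, []),  # n larger than the sentence: no n-grams
    ],
)
def test_get_ngrams_cases(sentence, n, expected):
    result = get_ngrams(sentence, n)
    # Check the return types as well as the values
    assert isinstance(result, list)
    assert all(isinstance(gram, tuple) for gram in result)
    assert result == expected


def test_get_ngrams_rejects_non_string():
    # An int has no .split method, so the call should fail loudly
    with pytest.raises(AttributeError):
        get_ngrams(12345, 2)
```

pytest runs test_get_ngrams_cases once per tuple in the list, so adding a new case is a one-line change.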

:monkey: Fixtures :monkey:

Fixtures are essential for enhancing data reusability in testing. Instead of defining data locally for each test, fixtures allow you to centrally define and import data into individual unit tests. This approach offers the advantage of standardizing the data used across all tests.

A fixture function is defined using the @pytest.fixture decorator, and it returns (or yields) a value that is used as the setup for a test. Test functions request a fixture by naming it as an argument, and pytest will automatically invoke the fixture and pass its value to the test function :innocent:.

import pytest

@pytest.fixture
def get_sentences():
  return {
    1: "This house is small",
    2: "das haus ist klein" 
  }

Now pass get_sentences as an argument to test_get_ngram and test_get_word_count:

def test_get_ngram(get_sentences):
  result = get_ngrams(get_sentences[1], 2)
  assert result == [("This", "house"), ("house", "is"), ("is", "small")]

  # The fixture sentence is all lowercase, so the expected tokens are too
  result = get_ngrams(get_sentences[2], 3)
  assert result == [("das", "haus", "ist"), ("haus", "ist", "klein")]

def test_get_word_count(get_sentences):
  result = get_word_count(get_sentences[1])
  assert result["house"] == 1
  assert result["is"] == 1
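
Fixtures can also yield instead of return, which lets you run cleanup code after the test finishes. The sketch below assumes a hypothetical helper, count_words_in_file, that is not part of this post's code. It relies on tmp_path, a built-in pytest fixture that provides a temporary directory.

```python
from collections import Counter

import pytest


@pytest.fixture
def corpus_file(tmp_path):
    # Setup: write a small corpus to a temporary file
    path = tmp_path / "corpus.txt"
    path.write_text("Good world world", encoding="utf-8")
    yield path
    # Teardown: runs after the test that used the fixture has finished
    path.unlink()


def count_words_in_file(path):
    # Hypothetical helper for illustration: counts whitespace-separated words
    return Counter(path.read_text(encoding="utf-8").split())


def test_count_words_in_file(corpus_file):
    result = count_words_in_file(corpus_file)
    assert result["world"] == 2
    assert result["Good"] == 1
```

Everything before the yield is setup and everything after it is teardown, so the test itself stays free of file-handling boilerplate.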

:owl: Mocking :owl:

Mocks are crucial when testing time- and resource-intensive methods :bomb:. For example, if you are testing a function responsible for training a model, you can use mocks to simulate the training procedure instead of actually training the model.

MagicMock provides a powerful and flexible way to mock objects in unit tests, allowing you to control the behavior of dependencies and focus on testing the specific functionality of the code under test :dizzy:.
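
As a small warm-up, here is a sketch of MagicMock on its own. The function predict_labels and its threshold parameter are hypothetical and only serve as something to test. The mock stands in for a model object, so no real model is loaded.

```python
from unittest.mock import MagicMock


def predict_labels(model, features, threshold=0.6):
    # Hypothetical function under test: thresholds the model's scores
    scores = model.predict(features)
    return [int(score >= threshold) for score in scores]


def test_predict_labels():
    # The mock replaces the model and returns canned scores
    model = MagicMock()
    model.predict.return_value = [0.5, 0.7]

    labels = predict_labels(model, [[1.0], [2.0]])

    assert labels == [0, 1]
    # The mock records how it was called, so the arguments can be checked too
    model.predict.assert_called_once_with([[1.0], [2.0]])
```

Setting return_value controls what the dependency returns, and the assert_called_* methods check how the code under test used it.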

Suppose you have created a custom machine learning class CustomLinearRegression that handles training and saving models to disk. Although you don’t need to test the actual training procedure, you want to ensure that the internal methods called during the process are functioning correctly. In the provided example, the class uses LinearRegression as the model and provides two methods: train and save_model.

  • The train method internally calls the fit method,
  • The save_model method internally calls the save method.

The objective is to verify if these internal functions are invoked appropriately without actually fitting the model, as that part is handled by the external library for the LinearRegression model.

Here, we patch _get_model so that it returns a MagicMock rather than an instance of LinearRegression.

from unittest.mock import MagicMock, patch

from sklearn.linear_model import LinearRegression

class CustomLinearRegression():
  def __init__(self):
    self.model = LinearRegression()

  def _get_model(self):
    return self.model

  def train(self, x, y):
    model = self._get_model()
    model.fit(x,y)
  
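  def save_model(self, path):
    model = self._get_model()
    # Note: scikit-learn's LinearRegression has no save() method, so this call
    # only succeeds in the test below because the model is replaced by a mock.
    model.save(path)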

# Patch `_get_model` method
@patch.object(CustomLinearRegression, '_get_model', return_value=MagicMock())
def test_training_procedure(mock_get_model):
    model = CustomLinearRegression()
    model.train([[1], [4], [9]], [100, 200, 500])
    model.save_model("model.pkl")

    # Assert that fit and save were each called exactly once on the mocked model
    mock_get_model.return_value.fit.assert_called_once()
    mock_get_model.return_value.save.assert_called_once()

Finally, we assert that the fit and save methods of the mock returned by _get_model were each called once. This ensures that the train and save_model methods of CustomLinearRegression interacted with the model object as expected.
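
If you also want to check which arguments were forwarded to the model, assert_called_once_with does that. The sketch below uses a stripped-down stand-in class named Trainer (my own name, not from the example above) so that it runs without scikit-learn. It uses patch.object as a context manager instead of a decorator.

```python
from unittest.mock import MagicMock, patch


class Trainer:
    # Dependency-free stand-in for CustomLinearRegression, for illustration only
    def _get_model(self):
        raise NotImplementedError("replaced by a mock in the test")

    def train(self, x, y):
        self._get_model().fit(x, y)

    def save_model(self, path):
        self._get_model().save(path)


def test_arguments_are_forwarded():
    fake_model = MagicMock()
    # Inside the with block, _get_model returns fake_model
    with patch.object(Trainer, "_get_model", return_value=fake_model):
        trainer = Trainer()
        trainer.train([[1], [4], [9]], [100, 200, 500])
        trainer.save_model("model.pkl")

    # Check that the exact arguments reached the model
    fake_model.fit.assert_called_once_with([[1], [4], [9]], [100, 200, 500])
    fake_model.save.assert_called_once_with("model.pkl")
```

Checking the arguments catches mistakes such as swapping x and y, which a plain assert_called_once would miss.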

That concludes this post. Don’t hesitate to reach out if you need further clarification or additional information :wave:.