Machine Learning System Testing: A Guide to Writing Unit Tests
In this series of blog posts, I'll guide you through the process of writing tests for machine learning systems. This post covers the basics of writing unit tests for machine learning functions. We'll start with a simple introduction to pytest and gradually move on to more advanced concepts, including fixtures and mocking techniques. This post is aimed at folks who are new to testing ML systems.
You can install pytest with the following command:
pip install -U pytest
Next, create a file called test_functions.py. This file will contain the functions we're going to talk about in this post. When you're ready to test, simply run pytest in your terminal to kick off the unit tests.
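For example, assuming your tests live in test_functions.py, either of the following will run them (the -v flag just makes the output more verbose):
pytest
pytest test_functions.py -v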
Pytest basics
Create a test function to verify that the target function returns the string hello world! as its output.
def greet():
    return "hello world!"
Now, we need to test the above function:
def test_greet():
    assert greet() == "hello world!"
collected 1 item
test_functions.py .                                                    [100%]
=============================== 1 passed in 0.00s ===============================
You should see the test passing. Afterwards, modify the greeting message within the function and run the tests again. You should then see them failing.
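While iterating on a single test, you can rerun just that test by filtering on its name with pytest's -k option:
pytest -k test_greet -v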
A few more tests:
Create a test function to verify a function that adds two numbers.
def add(a, b):
    return a + b

def test_add():
    result = add(1, 4)
    assert type(result) == int
    assert result == 5
Create a test function to test a function that produces n-grams for a given sentence and a specified value of n.
def get_ngrams(sentence, n):
    tokens = sentence.split(" ")
    n_grams = []
    for i in range(len(tokens) - n + 1):
        n_grams.append(tuple(tokens[i:i+n]))
    return n_grams
def test_get_ngram():
    result = get_ngrams("Good morning world", 2)
    assert len(result) == 2
    assert result == [("Good", "morning"), ("morning", "world")]

    result = get_ngrams("Good morning world", 3)
    assert len(result) == 1
    assert result == [("Good", "morning", "world")]
Create a test function to test a function that calculates the word count for a given sentence.
from collections import Counter

def get_word_count(sentence):
    words = sentence.split(" ")
    word_count = Counter(words)
    return word_count
def test_get_word_count():
    result = get_word_count("Good world world")
    assert result["Good"] == 1
    assert result["world"] == 2
How to improve the quality of your unit tests?
- Diverse Test Cases: Define a range of test cases that cover both successful and failing outcomes. This ensures thorough testing of the unit's behavior (see the parametrized sketch after this list).
- Floating Point Precision: If the function returns floating-point values, compare results within a tolerance to ensure reproducibility. Changes in dependent libraries or optimizations might subtly alter the results, so exact equality checks can be brittle; np.isclose lets you compare values within a tolerance. For example:
import numpy as np

def test_precision():
    observed = compute_pi()  # compute_pi is the function under test
    expected = 3.142
    # Tolerance chosen to match the precision of the expected value
    assert np.isclose(expected, observed, atol=1e-3)
- Consider Algorithm Changes: When using imported functions from libraries that undergo algorithmic changes or optimizations, be aware that results might differ. Adjust your tests accordingly to handle such scenarios.
- Test Return Types: Ensure your tests check return types such as float, int, str, np.ndarray, and torch.Tensor to validate the correctness of the unit's output (the sketch below includes a type check).
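To make the first and last points concrete, here is a minimal sketch using pytest.mark.parametrize with the add function from above. The specific cases and the divide helper are illustrative assumptions, not part of the original example:

import pytest

def divide(a, b):
    # Hypothetical helper used to illustrate testing a failure path.
    return a / b

@pytest.mark.parametrize(
    "a, b, expected",
    [
        (1, 4, 5),    # typical case
        (-1, 1, 0),   # negative input
        (0, 0, 0),    # edge case
    ],
)
def test_add_diverse_cases(a, b, expected):
    result = add(a, b)
    assert isinstance(result, int)  # check the return type as well as the value
    assert result == expected

def test_divide_by_zero_fails():
    # A case expected to fail: dividing by zero should raise an error.
    with pytest.raises(ZeroDivisionError):
        divide(1, 0)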
Fixtures
Fixtures are essential for enhancing data reusability in testing. Instead of defining data locally for each test, fixtures allow you to centrally define and import data into individual unit tests. This approach offers the advantage of standardizing the data used across all tests.
A fixture function is defined using the @pytest.fixture decorator, and it returns (or yields) a value that serves as the setup for a test. Test functions take the fixture name as an argument, and pytest automatically invokes the fixture and passes the returned value to the test function.
import pytest

@pytest.fixture
def get_sentences():
    return {
        1: "This house is small",
        2: "das haus ist klein"
    }
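As mentioned above, a fixture can also yield its value instead of returning it, which lets you run teardown code after the test finishes. A minimal sketch, where the temporary file name is just an illustrative assumption:

import os
import pytest

@pytest.fixture
def temp_predictions_file():
    # Setup: create a throwaway file used by a test.
    path = "predictions_tmp.txt"
    with open(path, "w") as f:
        f.write("0.1,0.9\n")
    yield path  # value handed to the test
    # Teardown: runs after the test finishes.
    os.remove(path)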
Now pass get_sentences as an argument to test_get_ngram and test_get_word_count:
def test_get_ngram(get_sentences):
    result = get_ngrams(get_sentences[1], 2)
    assert result == [("This", "house"), ("house", "is"), ("is", "small")]

    result = get_ngrams(get_sentences[2], 3)
    assert result == [("das", "haus", "ist"), ("haus", "ist", "klein")]

def test_get_word_count(get_sentences):
    result = get_word_count(get_sentences[1])
    assert result["house"] == 1
    assert result["is"] == 1
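If you want to reuse the get_sentences fixture across multiple test files, you can move it into a conftest.py file in your test directory; pytest discovers fixtures defined there automatically, with no import needed. A minimal sketch:

# conftest.py
import pytest

@pytest.fixture
def get_sentences():
    return {
        1: "This house is small",
        2: "das haus ist klein"
    }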
Mocking
Mocks are crucial in testing when dealing with time- and resource-intensive methods. For example, if you are testing a function responsible for training a model, you can use mocks to simulate the model training procedure instead of actually training the model.
MagicMock provides a powerful and flexible way to mock objects in unit tests, allowing you to control the behavior of dependencies and focus on testing the specific functionality of the code under test.
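As a quick, standalone illustration of what MagicMock does (this snippet is separate from the example that follows):

from unittest.mock import MagicMock

# A MagicMock accepts any attribute access or call and records how it was used.
fake_model = MagicMock()
fake_model.predict.return_value = [0, 1, 1]  # stub the return value

predictions = fake_model.predict([[1.0], [2.0], [3.0]])
assert predictions == [0, 1, 1]
fake_model.predict.assert_called_once_with([[1.0], [2.0], [3.0]])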
Suppose you have created a custom machine learning class CustomLinearRegression that handles training and saving models to disk. Although you don't need to test the actual training procedure, you want to ensure that the internal methods called during the process are functioning correctly. In the example below, the class uses LinearRegression as the model and provides two methods: train and save_model.
- The train method internally calls the fit method.
- The save_model method internally calls the save method.
The objective is to verify that these internal methods are invoked appropriately without actually fitting the model, since the actual fitting is handled by the external library that provides the LinearRegression model.
Here, we patch _get_model so that it returns a MagicMock rather than an instance of LinearRegression.
from sklearn.linear_model import LinearRegression
from unittest.mock import MagicMock
from unittest.mock import patch

class CustomLinearRegression:
    def __init__(self):
        self.model = LinearRegression()

    def _get_model(self):
        return self.model

    def train(self, x, y):
        model = self._get_model()
        model.fit(x, y)

    def save_model(self, path):
        model = self._get_model()
        model.save(f"{path}")
# Patch the `_get_model` method so it returns a MagicMock instead of the real model
@patch.object(CustomLinearRegression, '_get_model', return_value=MagicMock())
def test_training_procedure(mock_get_model):
    model = CustomLinearRegression()
    model.train([[1], [4], [9]], [100, 200, 500])
    model.save_model("model.pkl")

    # Assert that fit and save were each called once on the mocked model
    mock_get_model().fit.assert_called_once()
    mock_get_model().save.assert_called_once()
Finally, we assert that the fit and save methods on the mocked model were each called once. This ensures that the train and save_model methods of CustomLinearRegression interacted with the model object as expected.
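If you also want to verify what the mocked methods were called with, MagicMock records the call arguments. A small, optional extension of the test above, where the expected arguments simply mirror the call made in the test:

@patch.object(CustomLinearRegression, '_get_model', return_value=MagicMock())
def test_training_arguments(mock_get_model):
    model = CustomLinearRegression()
    model.train([[1], [4], [9]], [100, 200, 500])

    # Verify the exact arguments passed to fit on the mocked model
    mock_get_model().fit.assert_called_once_with([[1], [4], [9]], [100, 200, 500])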
That concludes this post. Don't hesitate to reach out if you need further clarification or additional information.