Mocking in Python

by Aaron Lelevier

About Me

Full-time Python/Django developer for 2.5 years

Programming in Python for 5 years

Other Languages

Recently: machine learning

Inspiration

  • Mock long-running requests so tests stay fast
  • Live requests aren't always possible or desirable:

    • repeating a request can trigger a 406 "duplicate request" error
    • tests need predictable response data

Assumptions for this talk

# pip install mock

from mock import MagicMock, patch

Other Assumptions

  • Python 3 (since Python 3.3 the same API also ships in the standard library as `unittest.mock`)
  • All files are at the same folder level

Patch at the call site (where the name is looked up)

# bar.py
class Bar(object):
    """Collaborator whose ``biz`` method the test replaces with a mock."""

    def biz(self):
        """Do nothing; stands in for real work and returns None."""
        return None

# foo.py
from bar import Bar

def foo():
    """Create a ``Bar`` and call its ``biz`` method, discarding the result."""
    instance = Bar()
    instance.biz()

# test.py
import unittest
from mock import patch
from foo import foo

class MyTest(unittest.TestCase):

    # Patching a *method* on a class: "foo.Bar.biz" and "bar.Bar.biz" reach the
    # very same class object (foo.py does `from bar import Bar`, so
    # foo.Bar is bar.Bar), and either target works here.
    # The "patch where it is looked up" rule matters when replacing the *name*
    # itself: to swap out the whole Bar class used by foo(), patch "foo.Bar";
    # patching "bar.Bar" would not affect foo's already-imported reference.
    @patch("foo.Bar.biz")
    def test_foo(self, mock_biz):
        """foo() should call Bar.biz."""
        self.assertFalse(mock_biz.called)

        foo()

        self.assertTrue(mock_biz.called)

Mock two things at once

# bar.py
import requests

def current_data():
    """Placeholder for the payload that ``Bar.sync`` uploads; returns None."""
    # Must exist: the argument ``current_data()`` is evaluated *before*
    # ``requests.put`` is called, so even a mocked ``put`` would otherwise
    # fail with NameError.
    return None


class Bar(object):

    def sync(self, id, query_first):
        """Optionally GET the remote record, then PUT the current data.

        ``id`` shadows the builtin of the same name; it is kept because
        callers pass it by keyword (``sync(id=42, ...)``).
        """
        if query_first:
            requests.get('/remote/api/{id}'.format(id=id))

        requests.put('/remote/other/api/{id}'.format(id=id),
                     data=current_data())

# test.py
import unittest
from mock import patch
from bar import Bar

class MyTest(unittest.TestCase):

    # Stacked patch decorators apply bottom-up, so the innermost one
    # ("bar.requests.put") supplies the first mock argument.
    @patch("bar.requests.get")
    @patch("bar.requests.put")
    def test_foo(self, mock_put, mock_get):
        """With query_first=False, sync() should PUT but never GET."""
        # sync is an instance method: call it on an instance. Calling it on
        # the class (Bar.sync(...)) raises TypeError for the missing ``self``.
        Bar().sync(id=42, query_first=False)

        self.assertFalse(mock_get.called)
        self.assertTrue(mock_put.called)

Mock as an Argument Captor

# bar.py
class Bar(object):
    """Collaborator that receives the request pieces forwarded by ``foo``."""

    def biz(self, url, method, data, headers):
        """Accept the request pieces and ignore them; returns None."""
        return None

# foo.py
from bar import Bar

def foo(url, method='GET', data=None, headers=None):
    """Forward every argument, positionally and in order, to ``Bar().biz``."""
    target = Bar()
    target.biz(url, method, data, headers)

# test.py
class MyTest(unittest.TestCase):

    @patch("foo.Bar.biz")
    def test_foo(self, mock_biz):
        """Capture the arguments foo() forwards to Bar.biz and check each one."""
        url = '/api/users/{id}'.format(id=1)
        data = {'phone_number': '+17025551000'}
        method = 'PUT'
        headers = {"Authorization": "JWT <your_token>"}

        # Not called yet -- this check only makes sense *before* foo() runs.
        self.assertFalse(mock_biz.called)

        foo(url, method, data=data, headers=headers)

        self.assertTrue(mock_biz.called)
        self.assertEqual(mock_biz.call_count, 1)
        # call_args is an (args, kwargs) pair. foo() forwards everything to
        # biz positionally, so all four values land in ``args`` and ``kwargs``
        # is empty -- even though foo() itself received data/headers by keyword.
        args, kwargs = mock_biz.call_args
        self.assertEqual(args[0], url)
        self.assertEqual(args[1], method)
        self.assertEqual(args[2], data)
        self.assertEqual(args[3], headers)
        self.assertEqual(kwargs, {})

Mock a return value

# bar.py
class Bar(object):
    """Collaborator with a fixed result that the test overrides."""

    def biz(self):
        """Return the constant 1."""
        result = 1
        return result

# foo.py
from bar import Bar

def foo():
    """Return whatever ``Bar().biz()`` returns, unchanged."""
    value = Bar().biz()
    return value

# test.py
class MyTest(unittest.TestCase):

    @patch("foo.Bar.biz")
    def test_foo(self, mock_biz):
        """foo() should pass through the value configured on the mock."""
        expected = 2
        # Every call to the mock now returns ``expected`` instead of running
        # the real method.
        mock_biz.return_value = expected

        self.assertEqual(foo(), expected)

Mock multiple return values

# bar.py
class Bar(object):
    def biz(self, i):
        """Return ``self.expensive_computation(i)``."""
        # Must go through ``self``: a bare ``expensive_computation`` name is
        # not in scope here (NameError). Looking it up on the instance is also
        # what lets patch("bar.Bar.expensive_computation") intercept the call.
        return self.expensive_computation(i)

    def expensive_computation(self, i):
        """Placeholder for slow work; returns None here."""
        pass

# foo.py
from bar import Bar


def foo():
    """Compute a value via ``Bar.biz`` for i = 0 and 1 and hand each to
    ``process_expensive_value``."""
    bar = Bar()
    for i in range(2):
        value = bar.biz(i)
        process_expensive_value(value)


def process_expensive_value(value):
    # Placeholder. It must exist at module level: foo() calls it, and
    # patch("foo.process_expensive_value") raises AttributeError if the
    # attribute being patched does not exist.
    pass

# test.py
class MyTest(unittest.TestCase):
    # Decorators apply bottom-up: the innermost patch supplies the first mock.
    @patch("bar.Bar.expensive_computation")
    @patch("foo.process_expensive_value")
    def test_foo(self, mock_process_exp_val, mock_exp_comp):
        """Feed two canned results through foo() and check each is processed."""
        canned = [1, 2]
        # A list side_effect makes successive calls return its items in turn.
        mock_exp_comp.side_effect = canned

        foo()

        self.assertTrue(mock_exp_comp.called)
        self.assertEqual(mock_exp_comp.call_count, 2)
        for index, expected in enumerate(canned):
            first_positional = mock_process_exp_val.call_args_list[index][0][0]
            self.assertEqual(first_positional, expected)

Mock an Exception

# bar.py
class Bar(object):
    def biz(self):
        """Raise ``CustomException`` when ``some_condition()`` is truthy;
        otherwise return None."""
        # NOTE(review): some_condition is not defined in this snippet; the test
        # patches biz out entirely, so it is never reached there.
        if not some_condition():
            return
        raise CustomException()

class CustomException(Exception):
    """Error raised by ``Bar.biz``."""

# foo.py
from bar import Bar

def foo():
    """Call ``Bar().biz()``; any exception it raises reaches the caller."""
    worker = Bar()
    worker.biz()

# test.py
import unittest

from mock import patch

from bar import CustomException
from foo import foo


class MyTest(unittest.TestCase):

    @patch("foo.Bar.biz")
    def test_foo(self, mock_biz):
        """foo() should let the exception raised by Bar.biz propagate."""
        # An exception instance (or class) used as side_effect is raised
        # when the mock is called.
        mock_biz.side_effect = CustomException()

        with self.assertRaises(CustomException):
            foo()

Mock a response with a stub

# bar.py
import requests

class Bar(object):
    def biz(self):
        """GET '/api/users/' and return the response object as-is."""
        url = '/api/users/'
        response = requests.get(url)
        return response

# foo.py
import json

from bar import Bar


def foo():
    """Fetch users via ``Bar().biz()`` and, on HTTP 200, pass the decoded JSON
    body to ``process_users_data``. Any other status code is ignored."""
    response = Bar().biz()
    if response.status_code == 200:
        data = json.loads(response.content.decode('utf8'))
        process_users_data(data)


def process_users_data(data):
    # Placeholder; the test patches it as "foo.process_users_data".
    pass

# test.py
import json
import unittest

from mock import patch
from pretend import stub

from foo import foo


class MyTest(unittest.TestCase):
    # Decorators apply bottom-up: the innermost patch supplies the first mock.
    @patch("bar.requests.get")
    @patch("foo.process_users_data")
    def test_foo(self, mock_process_users_data, mock_get):
        """A 200 response's JSON body should be decoded and handed to
        process_users_data."""
        payload = {'users': 'data'}
        # stub() builds a plain object carrying just these attributes, which
        # are the only ones foo() reads from the response.
        fake_response = stub(status_code=200,
                             content=json.dumps(payload).encode('utf8'))
        mock_get.return_value = fake_response

        foo()

        self.assertTrue(mock_process_users_data.called)
        self.assertEqual(mock_process_users_data.call_args[0][0], payload)

The End

Questions?