summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsmiddlek <smiddlek@b1010a0a-674b-0410-b734-77272b80c875>2009-09-11 19:57:51 +0000
committersmiddlek <smiddlek@b1010a0a-674b-0410-b734-77272b80c875>2009-09-11 19:57:51 +0000
commit7e0f63eb5f3e955acc0e5d9a228ab932b4a8519b (patch)
tree1e7014958e541e2bacb04871c0c5151f99a6fe3c
parente01718f94ebbbe8f3197eac58b005dccb1191231 (diff)
downloadmox-7e0f63eb5f3e955acc0e5d9a228ab932b4a8519b.tar.gz
Numerous fixes from Google:
* Added warning the Mox is not thread-safe - Steve Middlekauff * Fix bug in MultipleTimes group where if a Func is used it is called unnecessarily, due to re-comparing to see if the group has been satisfied - Steve Middlekauff * Use difflib for exceptions - Matt Brown * Add support for mocking iterators - Adam Nadolski * Make __getitem__, __setitem__, and __iter__ work with subclasses of new style clases - Antoine Picard git-svn-id: http://pymox.googlecode.com/svn/trunk@38 b1010a0a-674b-0410-b734-77272b80c875
-rwxr-xr-xmox.py150
-rwxr-xr-xmox_test.py194
2 files changed, 287 insertions, 57 deletions
diff --git a/mox.py b/mox.py
index 09dafa3..8bb2099 100755
--- a/mox.py
+++ b/mox.py
@@ -33,6 +33,11 @@ prematurely without calling some cleanup method!) The verify phase
ensures that every expected method was called; otherwise, an exception
will be raised.
+WARNING! Mock objects created by Mox are not thread-safe. If you are
+call a mock in multiple threads, it should be guarded by a mutex.
+
+TODO(stevepm): Add the option to make mocks thread-safe!
+
Suggested usage / workflow:
# Create Mox factory
@@ -57,6 +62,7 @@ Suggested usage / workflow:
"""
from collections import deque
+import difflib
import inspect
import re
import types
@@ -117,12 +123,17 @@ class UnexpectedMethodCallError(Error):
"""
Error.__init__(self)
- self._unexpected_method = unexpected_method
- self._expected = expected
+ if expected is None:
+ self._str = "Unexpected method call %s" % (unexpected_method,)
+ else:
+ differ = difflib.Differ()
+ diff = differ.compare(str(unexpected_method).splitlines(True),
+ str(expected).splitlines(True))
+ self._str = ("Unexpected method call. unexpected:- expected:+\n%s"
+ % ("\n".join(diff),))
def __str__(self):
- return "Unexpected method call: %s. Expecting: %s" % \
- (self._unexpected_method, self._expected)
+ return self._str
class UnknownMethodCallError(Error):
@@ -174,13 +185,16 @@ class Mox(object):
self._mock_objects.append(new_mock)
return new_mock
- def CreateMockAnything(self):
+ def CreateMockAnything(self, description=None):
"""Create a mock that will accept any method calls.
This does not enforce an interface.
- """
- new_mock = MockAnything()
+ Args:
+ description: str. Optionally, a descriptive name for the mock object being
+ created, for debugging output purposes.
+ """
+ new_mock = MockAnything(description=description)
self._mock_objects.append(new_mock)
return new_mock
@@ -218,10 +232,16 @@ class Mox(object):
"""
attr_to_replace = getattr(obj, attr_name)
+
+ # Check for a MockAnything. This could cause confusing problems later on.
+ if attr_to_replace == MockAnything():
+ raise TypeError('Cannot mock a MockAnything! Did you remember to '
+ 'call UnsetStubs in your previous test?')
+
if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything:
stub = self.CreateMock(attr_to_replace)
else:
- stub = self.CreateMockAnything()
+ stub = self.CreateMockAnything(description='Stub for %s' % attr_to_replace)
self.stubs.Set(obj, attr_name, stub)
@@ -269,15 +289,21 @@ class MockAnything:
This is helpful for mocking classes that do not provide a public interface.
"""
- def __init__(self):
- """ """
+ def __init__(self, description=None):
+ """Initialize a new MockAnything.
+
+ Args:
+ description: str. Optionally, a descriptive name for the mock object being
+ created, for debugging output purposes.
+ """
+ self._description = description
self._Reset()
def __str__(self):
return "<MockAnything instance at %s>" % id(self)
def __repr__(self):
- return self.__str__()
+ return '<MockAnything instance>'
def __getattr__(self, method_name):
"""Intercept method calls on this object.
@@ -310,7 +336,8 @@ class MockAnything:
"""
return MockMethod(method_name, self._expected_calls_queue,
- self._replay_mode, method_to_mock=method_to_mock)
+ self._replay_mode, method_to_mock=method_to_mock,
+ description=self._description)
def __nonzero__(self):
"""Return 1 for nonzero so the mock can be used as a conditional."""
@@ -449,10 +476,8 @@ class MockObject(MockAnything, object):
__setitem__.
"""
- setitem = self._class_to_mock.__dict__.get('__setitem__', None)
-
# Verify the class supports item assignment.
- if setitem is None:
+ if '__setitem__' not in dir(self._class_to_mock):
raise TypeError('object does not support item assignment')
# If we are in replay mode then simply call the mock __setitem__ method.
@@ -477,13 +502,11 @@ class MockObject(MockAnything, object):
Raises:
TypeError if the underlying class is not subscriptable.
UnexpectedMethodCallError if the object does not expect the call to
- __setitem__.
+ __getitem__.
"""
- getitem = self._class_to_mock.__dict__.get('__getitem__', None)
-
# Verify the class supports item assignment.
- if getitem is None:
+ if '__getitem__' not in dir(self._class_to_mock):
raise TypeError('unsubscriptable object')
# If we are in replay mode then simply call the mock __getitem__ method.
@@ -495,6 +518,47 @@ class MockObject(MockAnything, object):
# Otherwise, create a mock method __getitem__.
return self._CreateMockMethod('__getitem__')(key)
+ def __iter__(self):
+ """Provide custom logic for mocking classes that are iterable.
+
+ Returns:
+ Expected return value in replay mode. A MockMethod object for the
+ __iter__ method that has already been called if not in replay mode.
+
+ Raises:
+ TypeError if the underlying class is not iterable.
+ UnexpectedMethodCallError if the object does not expect the call to
+ __iter__.
+
+ """
+ methods = dir(self._class_to_mock)
+
+ # Verify the class supports iteration.
+ if '__iter__' not in methods:
+ # If it doesn't have iter method and we are in replay method, then try to
+ # iterate using subscripts.
+ if '__getitem__' not in methods or not self._replay_mode:
+ raise TypeError('not iterable object')
+ else:
+ results = []
+ index = 0
+ try:
+ while True:
+ results.append(self[index])
+ index += 1
+ except IndexError:
+ return iter(results)
+
+ # If we are in replay mode then simply call the mock __iter__ method.
+ if self._replay_mode:
+ return MockMethod('__iter__', self._expected_calls_queue,
+ self._replay_mode)()
+
+
+ # Otherwise, create a mock method __iter__.
+ return self._CreateMockMethod('__iter__')()
+
+
def __contains__(self, key):
"""Provide custom logic for mocking classes that contain items.
@@ -525,9 +589,9 @@ class MockObject(MockAnything, object):
def __call__(self, *params, **named_params):
"""Provide custom logic for mocking classes that are callable."""
- # Verify the class we are mocking is callable
- callable = self._class_to_mock.__dict__.get('__call__', None)
- if callable is None:
+ # Verify the class we are mocking is callable.
+ callable = hasattr(self._class_to_mock, '__call__')
+ if not callable:
raise TypeError('Not callable')
# Because the call is happening directly on this object instead of a method,
@@ -644,7 +708,8 @@ class MockMethod(object):
signature) matches the expected method.
"""
- def __init__(self, method_name, call_queue, replay_mode, method_to_mock=None):
+ def __init__(self, method_name, call_queue, replay_mode,
+ method_to_mock=None, description=None):
"""Construct a new mock method.
Args:
@@ -654,10 +719,13 @@ class MockMethod(object):
# replay_mode: False if we are recording, True if we are verifying calls
# against the call queue.
# method_to_mock: The actual method being mocked, used for introspection.
+ # description: optionally, a descriptive name for this method. Typically
+ # this is equal to the descriptive name of the method's class.
method_name: str
call_queue: list or deque
replay_mode: bool
method_to_mock: a method object
+ description: str or None
"""
self._name = method_name
@@ -665,6 +733,7 @@ class MockMethod(object):
if not isinstance(call_queue, deque):
self._call_queue = deque(self._call_queue)
self._replay_mode = replay_mode
+ self._description = description
self._params = None
self._named_params = None
@@ -715,6 +784,16 @@ class MockMethod(object):
raise AttributeError('MockMethod has no attribute "%s". '
'Did you remember to put your mocks in replay mode?' % name)
+ def __iter__(self):
+ """Raise a TypeError with a helpful message."""
+ raise TypeError('MockMethod cannot be iterated. '
+ 'Did you remember to put your mocks in replay mode?')
+
+ def next(self):
+ """Raise a TypeError with a helpful message."""
+ raise TypeError('MockMethod cannot be iterated. '
+ 'Did you remember to put your mocks in replay mode?')
+
def _PopNextMethod(self):
"""Pop the next method from our call queue."""
try:
@@ -753,8 +832,10 @@ class MockMethod(object):
params = ', '.join(
[repr(p) for p in self._params or []] +
['%s=%r' % x for x in sorted((self._named_params or {}).items())])
- desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
- return desc
+ full_desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
+ if self._description:
+ full_desc = "%s.%s" % (self._description, full_desc)
+ return full_desc
def __eq__(self, rhs):
"""Test whether this MockMethod is equivalent to another MockMethod.
@@ -1442,7 +1523,7 @@ class MultipleTimesGroup(MethodGroup):
def __init__(self, group_name):
super(MultipleTimesGroup, self).__init__(group_name)
self._methods = set()
- self._methods_called = set()
+ self._methods_left = set()
def AddMethod(self, mock_method):
"""Add a method to this group.
@@ -1452,6 +1533,7 @@ class MultipleTimesGroup(MethodGroup):
"""
self._methods.add(mock_method)
+ self._methods_left.add(mock_method)
def MethodCalled(self, mock_method):
"""Remove a method call from the group.
@@ -1471,10 +1553,9 @@ class MultipleTimesGroup(MethodGroup):
# Check to see if this method exists, and if so add it to the set of
# called methods.
-
for method in self._methods:
if method == mock_method:
- self._methods_called.add(mock_method)
+ self._methods_left.discard(method)
# Always put this group back on top of the queue, because we don't know
# when we are done.
mock_method._call_queue.appendleft(self)
@@ -1488,18 +1569,7 @@ class MultipleTimesGroup(MethodGroup):
def IsSatisfied(self):
"""Return True if all methods in this group are called at least once."""
- # NOTE(psycho): We can't use the simple set difference here because we want
- # to match different parameters which are considered the same e.g. IsA(str)
- # and some string. This solution is O(n^2) but n should be small.
- tmp = self._methods.copy()
- for called in self._methods_called:
- for expected in tmp:
- if called == expected:
- tmp.remove(expected)
- if not tmp:
- return True
- break
- return False
+ return len(self._methods_left) == 0
class MoxMetaTestBase(type):
diff --git a/mox_test.py b/mox_test.py
index 5519a7f..bf806f6 100755
--- a/mox_test.py
+++ b/mox_test.py
@@ -269,6 +269,7 @@ class IsATest(unittest.TestCase):
stringIO = cStringIO.StringIO()
self.assert_(isA == stringIO)
+
class IsAlmostTest(unittest.TestCase):
"""Verify IsAlmost correctly checks equality of floating point numbers."""
@@ -292,6 +293,7 @@ class IsAlmostTest(unittest.TestCase):
self.assertNotEquals(mox.IsAlmost('1.8999999999'), 1.9)
self.assertNotEquals(mox.IsAlmost('1.8999999999'), '1.9')
+
class MockMethodTest(unittest.TestCase):
"""Test class to verify that the MockMethod class is working correctly."""
@@ -425,6 +427,10 @@ class MockAnythingTest(unittest.TestCase):
def setUp(self):
self.mock_object = mox.MockAnything()
+ def testRepr(self):
+ """Calling repr on a MockAnything instance must work."""
+ self.assertEqual('<MockAnything instance>', repr(self.mock_object))
+
def testSetupMode(self):
"""Verify the mock will accept any call."""
self.mock_object.NonsenseCall()
@@ -793,10 +799,7 @@ class MockObjectTest(unittest.TestCase):
self.assertRaises(mox.ExpectedMethodCallsError, dummy._Verify)
def testMockSetItem_ExpectedNoSetItem_Success(self):
- """Test that __setitem__() gets mocked in Dummy.
-
- In this test, _Verify() succeeds.
- """
+ """Test that __setitem__() gets mocked in Dummy."""
dummy = mox.MockObject(TestClass)
# NOT doing dummy['X'] = 'Y'
@@ -805,8 +808,6 @@ class MockObjectTest(unittest.TestCase):
def call(): dummy['X'] = 'Y'
self.assertRaises(mox.UnexpectedMethodCallError, call)
- dummy._Verify()
-
def testMockSetItem_ExpectedNoSetItem_NoSuccess(self):
"""Test that __setitem__() gets mocked in Dummy.
@@ -834,8 +835,25 @@ class MockObjectTest(unittest.TestCase):
dummy._Verify()
+ def testMockSetItem_WithSubClassOfNewStyleClass(self):
+ class NewStyleTestClass(object):
+ def __init__(self):
+ self.my_dict = {}
+
+ def __setitem__(self, key, value):
+ self.my_dict[key], value
+
+ class TestSubClass(NewStyleTestClass):
+ pass
+
+ dummy = mox.MockObject(TestSubClass)
+ dummy[1] = 2
+ dummy._Replay()
+ dummy[1] = 2
+ dummy._Verify()
+
def testMockGetItem_ExpectedGetItem_Success(self):
- """Test that __setitem__() gets mocked in Dummy.
+ """Test that __getitem__() gets mocked in Dummy.
In this test, _Verify() succeeds.
"""
@@ -849,7 +867,7 @@ class MockObjectTest(unittest.TestCase):
dummy._Verify()
def testMockGetItem_ExpectedGetItem_NoSuccess(self):
- """Test that __setitem__() gets mocked in Dummy.
+ """Test that __getitem__() gets mocked in Dummy.
In this test, _Verify() fails.
"""
@@ -863,10 +881,7 @@ class MockObjectTest(unittest.TestCase):
self.assertRaises(mox.ExpectedMethodCallsError, dummy._Verify)
def testMockGetItem_ExpectedNoGetItem_NoSuccess(self):
- """Test that __setitem__() gets mocked in Dummy.
-
- In this test, _Verify() succeeds.
- """
+ """Test that __getitem__() gets mocked in Dummy."""
dummy = mox.MockObject(TestClass)
# NOT doing dummy['X']
@@ -875,10 +890,8 @@ class MockObjectTest(unittest.TestCase):
def call(): return dummy['X']
self.assertRaises(mox.UnexpectedMethodCallError, call)
- dummy._Verify()
-
def testMockGetItem_ExpectedGetItem_NonmatchingParameters(self):
- """Test that __setitem__() fails if other parameters are expected."""
+ """Test that __getitem__() fails if other parameters are expected."""
dummy = mox.MockObject(TestClass)
dummy['X'].AndReturn('value')
@@ -890,6 +903,34 @@ class MockObjectTest(unittest.TestCase):
dummy._Verify()
+ def testMockGetItem_WithSubClassOfNewStyleClass(self):
+ class NewStyleTestClass(object):
+ def __getitem__(self, key):
+ return {1: '1', 2: '2'}[key]
+
+ class TestSubClass(NewStyleTestClass):
+ pass
+
+ dummy = mox.MockObject(TestSubClass)
+ dummy[1].AndReturn('3')
+
+ dummy._Replay()
+ self.assertEquals('3', dummy.__getitem__(1))
+ dummy._Verify()
+
+ def testMockIter_ExpectedIter_Success(self):
+ """Test that __iter__() gets mocked in Dummy.
+
+ In this test, _Verify() succeeds.
+ """
+ dummy = mox.MockObject(TestClass)
+ iter(dummy).AndReturn(iter(['X', 'Y']))
+
+ dummy._Replay()
+
+ self.assertEqual([x for x in dummy], ['X', 'Y'])
+
+ dummy._Verify()
def testMockContains_ExpectedContains_Success(self):
"""Test that __contains__ gets mocked in Dummy.
@@ -931,6 +972,65 @@ class MockObjectTest(unittest.TestCase):
dummy._Verify()
+ def testMockIter_ExpectedIter_NoSuccess(self):
+ """Test that __iter__() gets mocked in Dummy.
+
+ In this test, _Verify() fails.
+ """
+ dummy = mox.MockObject(TestClass)
+ iter(dummy).AndReturn(iter(['X', 'Y']))
+
+ dummy._Replay()
+
+ # NOT doing self.assertEqual([x for x in dummy], ['X', 'Y'])
+
+ self.assertRaises(mox.ExpectedMethodCallsError, dummy._Verify)
+
+ def testMockIter_ExpectedNoIter_NoSuccess(self):
+ """Test that __iter__() gets mocked in Dummy."""
+ dummy = mox.MockObject(TestClass)
+ # NOT doing iter(dummy)
+
+ dummy._Replay()
+
+ def call(): return [x for x in dummy]
+ self.assertRaises(mox.UnexpectedMethodCallError, call)
+
+ def testMockIter_ExpectedGetItem_Success(self):
+ """Test that __iter__() gets mocked in Dummy using getitem."""
+ dummy = mox.MockObject(SubscribtableNonIterableClass)
+ dummy[0].AndReturn('a')
+ dummy[1].AndReturn('b')
+ dummy[2].AndRaise(IndexError)
+
+ dummy._Replay()
+ self.assertEquals(['a', 'b'], [x for x in dummy])
+ dummy._Verify()
+
+ def testMockIter_ExpectedNoGetItem_NoSuccess(self):
+ """Test that __iter__() gets mocked in Dummy using getitem."""
+ dummy = mox.MockObject(SubscribtableNonIterableClass)
+ # NOT doing dummy[index]
+
+ dummy._Replay()
+ function = lambda: [x for x in dummy]
+ self.assertRaises(mox.UnexpectedMethodCallError, function)
+
+ def testMockGetIter_WithSubClassOfNewStyleClass(self):
+ class NewStyleTestClass(object):
+ def __iter__(self):
+ return iter([1, 2, 3])
+
+ class TestSubClass(NewStyleTestClass):
+ pass
+
+ dummy = mox.MockObject(TestSubClass)
+ iter(dummy).AndReturn(iter(['a', 'b']))
+ dummy._Replay()
+ self.assertEquals(['a', 'b'], [x for x in dummy])
+ dummy._Verify()
+
+
class MoxTest(unittest.TestCase):
"""Verify Mox works correctly."""
@@ -979,6 +1079,16 @@ class MoxTest(unittest.TestCase):
self.assertEquals("qux", ret_val)
self.mox.VerifyAll()
+ def testInheritedCallableObject(self):
+ """Test recording calls to an object inheriting from a callable object."""
+ mock_obj = self.mox.CreateMock(InheritsFromCallable)
+ mock_obj("foo").AndReturn("qux")
+ self.mox.ReplayAll()
+
+ ret_val = mock_obj("foo")
+ self.assertEquals("qux", ret_val)
+ self.mox.VerifyAll()
+
def testCallOnNonCallableObject(self):
"""Test that you cannot call a non-callable object."""
mock_obj = self.mox.CreateMock(TestClass)
@@ -1106,13 +1216,13 @@ class MoxTest(unittest.TestCase):
mock_obj.Method(3)
mock_obj.Method(3)
+ self.mox.VerifyAll()
+
self.assertEquals(9, actual_one)
self.assertEquals(9, second_one) # Repeated calls should return same number.
self.assertEquals(10, actual_two)
self.assertEquals(42, actual_three)
- self.mox.VerifyAll()
-
def testMultipleTimesUsingIsAParameter(self):
"""Test if MultipleTimesGroup works with a IsA parameter."""
mock_obj = self.mox.CreateMockAnything()
@@ -1126,11 +1236,40 @@ class MoxTest(unittest.TestCase):
second_one = mock_obj.Method("2") # This tests MultipleTimes.
mock_obj.Close()
+ self.mox.VerifyAll()
+
self.assertEquals(9, actual_one)
self.assertEquals(9, second_one) # Repeated calls should return same number.
+ def testMutlipleTimesUsingFunc(self):
+ """Test that the Func is not evaluated more times than necessary.
+
+ If a Func() has side effects, it can cause a passing test to fail.
+ """
+
+ self.counter = 0
+ def MyFunc(actual_str):
+ """Increment the counter if actual_str == 'foo'."""
+ if actual_str == 'foo':
+ self.counter += 1
+ return True
+
+ mock_obj = self.mox.CreateMockAnything()
+ mock_obj.Open()
+ mock_obj.Method(mox.Func(MyFunc)).MultipleTimes()
+ mock_obj.Close()
+ self.mox.ReplayAll()
+
+ mock_obj.Open()
+ mock_obj.Method('foo')
+ mock_obj.Method('foo')
+ mock_obj.Method('not-foo')
+ mock_obj.Close()
+
self.mox.VerifyAll()
+ self.assertEquals(2, self.counter)
+
def testMultipleTimesThreeMethods(self):
"""Test if MultipleTimesGroup works with three or more methods."""
mock_obj = self.mox.CreateMockAnything()
@@ -1268,6 +1407,12 @@ class MoxTest(unittest.TestCase):
self.assertEquals('foo', actual)
self.failIf(isinstance(test_obj.OtherValidCall, mox.MockAnything))
+ def testWarnsUserIfMockingMock(self):
+ """Test that user is warned if they try to stub out a MockAnything."""
+ self.mox.StubOutWithMock(TestClass, 'MyStaticMethod')
+ self.assertRaises(TypeError, self.mox.StubOutWithMock, TestClass,
+ 'MyStaticMethod')
+
def testStubOutObject(self):
"""Test than object is replaced with a Mock."""
@@ -1301,6 +1446,7 @@ class MoxTest(unittest.TestCase):
self.assertEquals('MockMethod has no attribute "ShowMeTheMoney". '
'Did you remember to put your mocks in replay mode?', str(e))
+
class ReplayTest(unittest.TestCase):
"""Verify Replay works properly."""
@@ -1311,6 +1457,7 @@ class ReplayTest(unittest.TestCase):
mox.Replay(mock_obj)
self.assertTrue(mock_obj._replay_mode)
+
class MoxTestBaseTest(unittest.TestCase):
"""Verify that all tests in a class derived from MoxTestBase are wrapped."""
@@ -1524,6 +1671,8 @@ class TestClass:
"""Returns True if d contains the key."""
return key in self.d
+ def __iter__(self):
+ pass
class ChildClass(TestClass):
"""This inherits from TestClass."""
@@ -1544,5 +1693,16 @@ class CallableClass(object):
return param
+class SubscribtableNonIterableClass(object):
+ def __getitem__(self, index):
+ raise IndexError
+
+
+class InheritsFromCallable(CallableClass):
+ """This class should also be mockable; it inherits from a callable class."""
+
+ pass
+
+
if __name__ == '__main__':
unittest.main()