From 7e0f63eb5f3e955acc0e5d9a228ab932b4a8519b Mon Sep 17 00:00:00 2001 From: smiddlek Date: Fri, 11 Sep 2009 19:57:51 +0000 Subject: Numerous fixes from Google: * Added warning the Mox is not thread-safe - Steve Middlekauff * Fix bug in MultipleTimes group where if a Func is used it is called unnecessarily, due to re-comparing to see if the group has been satisfied - Steve Middlekauff * Use difflib for exceptions - Matt Brown * Add support for mocking iterators - Adam Nadolski * Make __getitem__, __setitem__, and __iter__ work with subclasses of new style clases - Antoine Picard git-svn-id: http://pymox.googlecode.com/svn/trunk@38 b1010a0a-674b-0410-b734-77272b80c875 --- mox.py | 150 +++++++++++++++++++++++++++++++++------------- mox_test.py | 194 ++++++++++++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 287 insertions(+), 57 deletions(-) diff --git a/mox.py b/mox.py index 09dafa3..8bb2099 100755 --- a/mox.py +++ b/mox.py @@ -33,6 +33,11 @@ prematurely without calling some cleanup method!) The verify phase ensures that every expected method was called; otherwise, an exception will be raised. +WARNING! Mock objects created by Mox are not thread-safe. If you are +call a mock in multiple threads, it should be guarded by a mutex. + +TODO(stevepm): Add the option to make mocks thread-safe! + Suggested usage / workflow: # Create Mox factory @@ -57,6 +62,7 @@ Suggested usage / workflow: """ from collections import deque +import difflib import inspect import re import types @@ -117,12 +123,17 @@ class UnexpectedMethodCallError(Error): """ Error.__init__(self) - self._unexpected_method = unexpected_method - self._expected = expected + if expected is None: + self._str = "Unexpected method call %s" % (unexpected_method,) + else: + differ = difflib.Differ() + diff = differ.compare(str(unexpected_method).splitlines(True), + str(expected).splitlines(True)) + self._str = ("Unexpected method call. unexpected:- expected:+\n%s" + % ("\n".join(diff),)) def __str__(self): - return "Unexpected method call: %s. Expecting: %s" % \ - (self._unexpected_method, self._expected) + return self._str class UnknownMethodCallError(Error): @@ -174,13 +185,16 @@ class Mox(object): self._mock_objects.append(new_mock) return new_mock - def CreateMockAnything(self): + def CreateMockAnything(self, description=None): """Create a mock that will accept any method calls. This does not enforce an interface. - """ - new_mock = MockAnything() + Args: + description: str. Optionally, a descriptive name for the mock object being + created, for debugging output purposes. + """ + new_mock = MockAnything(description=description) self._mock_objects.append(new_mock) return new_mock @@ -218,10 +232,16 @@ class Mox(object): """ attr_to_replace = getattr(obj, attr_name) + + # Check for a MockAnything. This could cause confusing problems later on. + if attr_to_replace == MockAnything(): + raise TypeError('Cannot mock a MockAnything! Did you remember to ' + 'call UnsetStubs in your previous test?') + if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything: stub = self.CreateMock(attr_to_replace) else: - stub = self.CreateMockAnything() + stub = self.CreateMockAnything(description='Stub for %s' % attr_to_replace) self.stubs.Set(obj, attr_name, stub) @@ -269,15 +289,21 @@ class MockAnything: This is helpful for mocking classes that do not provide a public interface. """ - def __init__(self): - """ """ + def __init__(self, description=None): + """Initialize a new MockAnything. + + Args: + description: str. Optionally, a descriptive name for the mock object being + created, for debugging output purposes. + """ + self._description = description self._Reset() def __str__(self): return "" % id(self) def __repr__(self): - return self.__str__() + return '' def __getattr__(self, method_name): """Intercept method calls on this object. @@ -310,7 +336,8 @@ class MockAnything: """ return MockMethod(method_name, self._expected_calls_queue, - self._replay_mode, method_to_mock=method_to_mock) + self._replay_mode, method_to_mock=method_to_mock, + description=self._description) def __nonzero__(self): """Return 1 for nonzero so the mock can be used as a conditional.""" @@ -449,10 +476,8 @@ class MockObject(MockAnything, object): __setitem__. """ - setitem = self._class_to_mock.__dict__.get('__setitem__', None) - # Verify the class supports item assignment. - if setitem is None: + if '__setitem__' not in dir(self._class_to_mock): raise TypeError('object does not support item assignment') # If we are in replay mode then simply call the mock __setitem__ method. @@ -477,13 +502,11 @@ class MockObject(MockAnything, object): Raises: TypeError if the underlying class is not subscriptable. UnexpectedMethodCallError if the object does not expect the call to - __setitem__. + __getitem__. """ - getitem = self._class_to_mock.__dict__.get('__getitem__', None) - # Verify the class supports item assignment. - if getitem is None: + if '__getitem__' not in dir(self._class_to_mock): raise TypeError('unsubscriptable object') # If we are in replay mode then simply call the mock __getitem__ method. @@ -495,6 +518,47 @@ class MockObject(MockAnything, object): # Otherwise, create a mock method __getitem__. return self._CreateMockMethod('__getitem__')(key) + def __iter__(self): + """Provide custom logic for mocking classes that are iterable. + + Returns: + Expected return value in replay mode. A MockMethod object for the + __iter__ method that has already been called if not in replay mode. + + Raises: + TypeError if the underlying class is not iterable. + UnexpectedMethodCallError if the object does not expect the call to + __iter__. + + """ + methods = dir(self._class_to_mock) + + # Verify the class supports iteration. + if '__iter__' not in methods: + # If it doesn't have iter method and we are in replay method, then try to + # iterate using subscripts. + if '__getitem__' not in methods or not self._replay_mode: + raise TypeError('not iterable object') + else: + results = [] + index = 0 + try: + while True: + results.append(self[index]) + index += 1 + except IndexError: + return iter(results) + + # If we are in replay mode then simply call the mock __iter__ method. + if self._replay_mode: + return MockMethod('__iter__', self._expected_calls_queue, + self._replay_mode)() + + + # Otherwise, create a mock method __iter__. + return self._CreateMockMethod('__iter__')() + + def __contains__(self, key): """Provide custom logic for mocking classes that contain items. @@ -525,9 +589,9 @@ class MockObject(MockAnything, object): def __call__(self, *params, **named_params): """Provide custom logic for mocking classes that are callable.""" - # Verify the class we are mocking is callable - callable = self._class_to_mock.__dict__.get('__call__', None) - if callable is None: + # Verify the class we are mocking is callable. + callable = hasattr(self._class_to_mock, '__call__') + if not callable: raise TypeError('Not callable') # Because the call is happening directly on this object instead of a method, @@ -644,7 +708,8 @@ class MockMethod(object): signature) matches the expected method. """ - def __init__(self, method_name, call_queue, replay_mode, method_to_mock=None): + def __init__(self, method_name, call_queue, replay_mode, + method_to_mock=None, description=None): """Construct a new mock method. Args: @@ -654,10 +719,13 @@ class MockMethod(object): # replay_mode: False if we are recording, True if we are verifying calls # against the call queue. # method_to_mock: The actual method being mocked, used for introspection. + # description: optionally, a descriptive name for this method. Typically + # this is equal to the descriptive name of the method's class. method_name: str call_queue: list or deque replay_mode: bool method_to_mock: a method object + description: str or None """ self._name = method_name @@ -665,6 +733,7 @@ class MockMethod(object): if not isinstance(call_queue, deque): self._call_queue = deque(self._call_queue) self._replay_mode = replay_mode + self._description = description self._params = None self._named_params = None @@ -715,6 +784,16 @@ class MockMethod(object): raise AttributeError('MockMethod has no attribute "%s". ' 'Did you remember to put your mocks in replay mode?' % name) + def __iter__(self): + """Raise a TypeError with a helpful message.""" + raise TypeError('MockMethod cannot be iterated. ' + 'Did you remember to put your mocks in replay mode?') + + def next(self): + """Raise a TypeError with a helpful message.""" + raise TypeError('MockMethod cannot be iterated. ' + 'Did you remember to put your mocks in replay mode?') + def _PopNextMethod(self): """Pop the next method from our call queue.""" try: @@ -753,8 +832,10 @@ class MockMethod(object): params = ', '.join( [repr(p) for p in self._params or []] + ['%s=%r' % x for x in sorted((self._named_params or {}).items())]) - desc = "%s(%s) -> %r" % (self._name, params, self._return_value) - return desc + full_desc = "%s(%s) -> %r" % (self._name, params, self._return_value) + if self._description: + full_desc = "%s.%s" % (self._description, full_desc) + return full_desc def __eq__(self, rhs): """Test whether this MockMethod is equivalent to another MockMethod. @@ -1442,7 +1523,7 @@ class MultipleTimesGroup(MethodGroup): def __init__(self, group_name): super(MultipleTimesGroup, self).__init__(group_name) self._methods = set() - self._methods_called = set() + self._methods_left = set() def AddMethod(self, mock_method): """Add a method to this group. @@ -1452,6 +1533,7 @@ class MultipleTimesGroup(MethodGroup): """ self._methods.add(mock_method) + self._methods_left.add(mock_method) def MethodCalled(self, mock_method): """Remove a method call from the group. @@ -1471,10 +1553,9 @@ class MultipleTimesGroup(MethodGroup): # Check to see if this method exists, and if so add it to the set of # called methods. - for method in self._methods: if method == mock_method: - self._methods_called.add(mock_method) + self._methods_left.discard(method) # Always put this group back on top of the queue, because we don't know # when we are done. mock_method._call_queue.appendleft(self) @@ -1488,18 +1569,7 @@ class MultipleTimesGroup(MethodGroup): def IsSatisfied(self): """Return True if all methods in this group are called at least once.""" - # NOTE(psycho): We can't use the simple set difference here because we want - # to match different parameters which are considered the same e.g. IsA(str) - # and some string. This solution is O(n^2) but n should be small. - tmp = self._methods.copy() - for called in self._methods_called: - for expected in tmp: - if called == expected: - tmp.remove(expected) - if not tmp: - return True - break - return False + return len(self._methods_left) == 0 class MoxMetaTestBase(type): diff --git a/mox_test.py b/mox_test.py index 5519a7f..bf806f6 100755 --- a/mox_test.py +++ b/mox_test.py @@ -269,6 +269,7 @@ class IsATest(unittest.TestCase): stringIO = cStringIO.StringIO() self.assert_(isA == stringIO) + class IsAlmostTest(unittest.TestCase): """Verify IsAlmost correctly checks equality of floating point numbers.""" @@ -292,6 +293,7 @@ class IsAlmostTest(unittest.TestCase): self.assertNotEquals(mox.IsAlmost('1.8999999999'), 1.9) self.assertNotEquals(mox.IsAlmost('1.8999999999'), '1.9') + class MockMethodTest(unittest.TestCase): """Test class to verify that the MockMethod class is working correctly.""" @@ -425,6 +427,10 @@ class MockAnythingTest(unittest.TestCase): def setUp(self): self.mock_object = mox.MockAnything() + def testRepr(self): + """Calling repr on a MockAnything instance must work.""" + self.assertEqual('', repr(self.mock_object)) + def testSetupMode(self): """Verify the mock will accept any call.""" self.mock_object.NonsenseCall() @@ -793,10 +799,7 @@ class MockObjectTest(unittest.TestCase): self.assertRaises(mox.ExpectedMethodCallsError, dummy._Verify) def testMockSetItem_ExpectedNoSetItem_Success(self): - """Test that __setitem__() gets mocked in Dummy. - - In this test, _Verify() succeeds. - """ + """Test that __setitem__() gets mocked in Dummy.""" dummy = mox.MockObject(TestClass) # NOT doing dummy['X'] = 'Y' @@ -805,8 +808,6 @@ class MockObjectTest(unittest.TestCase): def call(): dummy['X'] = 'Y' self.assertRaises(mox.UnexpectedMethodCallError, call) - dummy._Verify() - def testMockSetItem_ExpectedNoSetItem_NoSuccess(self): """Test that __setitem__() gets mocked in Dummy. @@ -834,8 +835,25 @@ class MockObjectTest(unittest.TestCase): dummy._Verify() + def testMockSetItem_WithSubClassOfNewStyleClass(self): + class NewStyleTestClass(object): + def __init__(self): + self.my_dict = {} + + def __setitem__(self, key, value): + self.my_dict[key], value + + class TestSubClass(NewStyleTestClass): + pass + + dummy = mox.MockObject(TestSubClass) + dummy[1] = 2 + dummy._Replay() + dummy[1] = 2 + dummy._Verify() + def testMockGetItem_ExpectedGetItem_Success(self): - """Test that __setitem__() gets mocked in Dummy. + """Test that __getitem__() gets mocked in Dummy. In this test, _Verify() succeeds. """ @@ -849,7 +867,7 @@ class MockObjectTest(unittest.TestCase): dummy._Verify() def testMockGetItem_ExpectedGetItem_NoSuccess(self): - """Test that __setitem__() gets mocked in Dummy. + """Test that __getitem__() gets mocked in Dummy. In this test, _Verify() fails. """ @@ -863,10 +881,7 @@ class MockObjectTest(unittest.TestCase): self.assertRaises(mox.ExpectedMethodCallsError, dummy._Verify) def testMockGetItem_ExpectedNoGetItem_NoSuccess(self): - """Test that __setitem__() gets mocked in Dummy. - - In this test, _Verify() succeeds. - """ + """Test that __getitem__() gets mocked in Dummy.""" dummy = mox.MockObject(TestClass) # NOT doing dummy['X'] @@ -875,10 +890,8 @@ class MockObjectTest(unittest.TestCase): def call(): return dummy['X'] self.assertRaises(mox.UnexpectedMethodCallError, call) - dummy._Verify() - def testMockGetItem_ExpectedGetItem_NonmatchingParameters(self): - """Test that __setitem__() fails if other parameters are expected.""" + """Test that __getitem__() fails if other parameters are expected.""" dummy = mox.MockObject(TestClass) dummy['X'].AndReturn('value') @@ -890,6 +903,34 @@ class MockObjectTest(unittest.TestCase): dummy._Verify() + def testMockGetItem_WithSubClassOfNewStyleClass(self): + class NewStyleTestClass(object): + def __getitem__(self, key): + return {1: '1', 2: '2'}[key] + + class TestSubClass(NewStyleTestClass): + pass + + dummy = mox.MockObject(TestSubClass) + dummy[1].AndReturn('3') + + dummy._Replay() + self.assertEquals('3', dummy.__getitem__(1)) + dummy._Verify() + + def testMockIter_ExpectedIter_Success(self): + """Test that __iter__() gets mocked in Dummy. + + In this test, _Verify() succeeds. + """ + dummy = mox.MockObject(TestClass) + iter(dummy).AndReturn(iter(['X', 'Y'])) + + dummy._Replay() + + self.assertEqual([x for x in dummy], ['X', 'Y']) + + dummy._Verify() def testMockContains_ExpectedContains_Success(self): """Test that __contains__ gets mocked in Dummy. @@ -931,6 +972,65 @@ class MockObjectTest(unittest.TestCase): dummy._Verify() + def testMockIter_ExpectedIter_NoSuccess(self): + """Test that __iter__() gets mocked in Dummy. + + In this test, _Verify() fails. + """ + dummy = mox.MockObject(TestClass) + iter(dummy).AndReturn(iter(['X', 'Y'])) + + dummy._Replay() + + # NOT doing self.assertEqual([x for x in dummy], ['X', 'Y']) + + self.assertRaises(mox.ExpectedMethodCallsError, dummy._Verify) + + def testMockIter_ExpectedNoIter_NoSuccess(self): + """Test that __iter__() gets mocked in Dummy.""" + dummy = mox.MockObject(TestClass) + # NOT doing iter(dummy) + + dummy._Replay() + + def call(): return [x for x in dummy] + self.assertRaises(mox.UnexpectedMethodCallError, call) + + def testMockIter_ExpectedGetItem_Success(self): + """Test that __iter__() gets mocked in Dummy using getitem.""" + dummy = mox.MockObject(SubscribtableNonIterableClass) + dummy[0].AndReturn('a') + dummy[1].AndReturn('b') + dummy[2].AndRaise(IndexError) + + dummy._Replay() + self.assertEquals(['a', 'b'], [x for x in dummy]) + dummy._Verify() + + def testMockIter_ExpectedNoGetItem_NoSuccess(self): + """Test that __iter__() gets mocked in Dummy using getitem.""" + dummy = mox.MockObject(SubscribtableNonIterableClass) + # NOT doing dummy[index] + + dummy._Replay() + function = lambda: [x for x in dummy] + self.assertRaises(mox.UnexpectedMethodCallError, function) + + def testMockGetIter_WithSubClassOfNewStyleClass(self): + class NewStyleTestClass(object): + def __iter__(self): + return iter([1, 2, 3]) + + class TestSubClass(NewStyleTestClass): + pass + + dummy = mox.MockObject(TestSubClass) + iter(dummy).AndReturn(iter(['a', 'b'])) + dummy._Replay() + self.assertEquals(['a', 'b'], [x for x in dummy]) + dummy._Verify() + + class MoxTest(unittest.TestCase): """Verify Mox works correctly.""" @@ -979,6 +1079,16 @@ class MoxTest(unittest.TestCase): self.assertEquals("qux", ret_val) self.mox.VerifyAll() + def testInheritedCallableObject(self): + """Test recording calls to an object inheriting from a callable object.""" + mock_obj = self.mox.CreateMock(InheritsFromCallable) + mock_obj("foo").AndReturn("qux") + self.mox.ReplayAll() + + ret_val = mock_obj("foo") + self.assertEquals("qux", ret_val) + self.mox.VerifyAll() + def testCallOnNonCallableObject(self): """Test that you cannot call a non-callable object.""" mock_obj = self.mox.CreateMock(TestClass) @@ -1106,13 +1216,13 @@ class MoxTest(unittest.TestCase): mock_obj.Method(3) mock_obj.Method(3) + self.mox.VerifyAll() + self.assertEquals(9, actual_one) self.assertEquals(9, second_one) # Repeated calls should return same number. self.assertEquals(10, actual_two) self.assertEquals(42, actual_three) - self.mox.VerifyAll() - def testMultipleTimesUsingIsAParameter(self): """Test if MultipleTimesGroup works with a IsA parameter.""" mock_obj = self.mox.CreateMockAnything() @@ -1126,11 +1236,40 @@ class MoxTest(unittest.TestCase): second_one = mock_obj.Method("2") # This tests MultipleTimes. mock_obj.Close() + self.mox.VerifyAll() + self.assertEquals(9, actual_one) self.assertEquals(9, second_one) # Repeated calls should return same number. + def testMutlipleTimesUsingFunc(self): + """Test that the Func is not evaluated more times than necessary. + + If a Func() has side effects, it can cause a passing test to fail. + """ + + self.counter = 0 + def MyFunc(actual_str): + """Increment the counter if actual_str == 'foo'.""" + if actual_str == 'foo': + self.counter += 1 + return True + + mock_obj = self.mox.CreateMockAnything() + mock_obj.Open() + mock_obj.Method(mox.Func(MyFunc)).MultipleTimes() + mock_obj.Close() + self.mox.ReplayAll() + + mock_obj.Open() + mock_obj.Method('foo') + mock_obj.Method('foo') + mock_obj.Method('not-foo') + mock_obj.Close() + self.mox.VerifyAll() + self.assertEquals(2, self.counter) + def testMultipleTimesThreeMethods(self): """Test if MultipleTimesGroup works with three or more methods.""" mock_obj = self.mox.CreateMockAnything() @@ -1268,6 +1407,12 @@ class MoxTest(unittest.TestCase): self.assertEquals('foo', actual) self.failIf(isinstance(test_obj.OtherValidCall, mox.MockAnything)) + def testWarnsUserIfMockingMock(self): + """Test that user is warned if they try to stub out a MockAnything.""" + self.mox.StubOutWithMock(TestClass, 'MyStaticMethod') + self.assertRaises(TypeError, self.mox.StubOutWithMock, TestClass, + 'MyStaticMethod') + def testStubOutObject(self): """Test than object is replaced with a Mock.""" @@ -1301,6 +1446,7 @@ class MoxTest(unittest.TestCase): self.assertEquals('MockMethod has no attribute "ShowMeTheMoney". ' 'Did you remember to put your mocks in replay mode?', str(e)) + class ReplayTest(unittest.TestCase): """Verify Replay works properly.""" @@ -1311,6 +1457,7 @@ class ReplayTest(unittest.TestCase): mox.Replay(mock_obj) self.assertTrue(mock_obj._replay_mode) + class MoxTestBaseTest(unittest.TestCase): """Verify that all tests in a class derived from MoxTestBase are wrapped.""" @@ -1524,6 +1671,8 @@ class TestClass: """Returns True if d contains the key.""" return key in self.d + def __iter__(self): + pass class ChildClass(TestClass): """This inherits from TestClass.""" @@ -1544,5 +1693,16 @@ class CallableClass(object): return param +class SubscribtableNonIterableClass(object): + def __getitem__(self, index): + raise IndexError + + +class InheritsFromCallable(CallableClass): + """This class should also be mockable; it inherits from a callable class.""" + + pass + + if __name__ == '__main__': unittest.main() -- cgit v1.2.1