From 4ea08c31b9acea16a21d751c4ca419346323af11 Mon Sep 17 00:00:00 2001 From: smiddlek Date: Wed, 12 Nov 2008 19:01:43 +0000 Subject: Patch from Matt Brown (mdbrow@gmail.com) to inspect the arguments of mocked methods. Alright! :) --- mox.py | 117 ++++++++++++++++++++++++++++++++++++++++++++++++--- mox_test.py | 138 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 250 insertions(+), 5 deletions(-) diff --git a/mox.py b/mox.py index 2fddc0b..23348fd 100755 --- a/mox.py +++ b/mox.py @@ -57,6 +57,7 @@ Suggested usage / workflow: """ from collections import deque +import inspect import re import types import unittest @@ -289,19 +290,21 @@ class MockAnything: return self._CreateMockMethod(method_name) - def _CreateMockMethod(self, method_name): + def _CreateMockMethod(self, method_name, method_to_mock=None): """Create a new mock method call and return it. Args: - # method name: the name of the method being called. + # method_name: the name of the method being called. + # method_to_mock: The actual method being mocked, used for introspection. method_name: str + method_to_mock: a method object Returns: A new MockMethod aware of MockAnything's state (record or replay). """ return MockMethod(method_name, self._expected_calls_queue, - self._replay_mode) + self._replay_mode, method_to_mock=method_to_mock) def __nonzero__(self): """Return 1 for nonzero so the mock can be used as a conditional.""" @@ -409,7 +412,9 @@ class MockObject(MockAnything, object): return getattr(self._class_to_mock, name) if name in self._known_methods: - return self._CreateMockMethod(name) + return self._CreateMockMethod( + name, + method_to_mock=getattr(self._class_to_mock, name)) raise UnknownMethodCallError(name) @@ -531,6 +536,99 @@ class MockObject(MockAnything, object): return self._class_to_mock +class MethodCallChecker(object): + """Ensures that methods are called correctly.""" + + _NEEDED, _DEFAULT, _GIVEN = range(3) + + def __init__(self, method): + """Creates a checker. + + Args: + # method: A method to check. + method: function + + Raises: + ValueError: method could not be inspected, so checks aren't possible. + Some methods and functions like built-ins can't be inspected. + """ + try: + self._args, varargs, varkw, defaults = inspect.getargspec(method) + except TypeError: + raise ValueError('Could not get argument specification for %r' + % (method,)) + if inspect.ismethod(method): + self._args = self._args[1:] # Skip 'self'. + self._method = method + + self._has_varargs = varargs is not None + self._has_varkw = varkw is not None + if defaults is None: + self._required_args = self._args + self._default_args = [] + else: + self._required_args = self._args[:-len(defaults)] + self._default_args = self._args[-len(defaults):] + + def _RecordArgumentGiven(self, arg_name, arg_status): + """Mark an argument as being given. + + Args: + # arg_name: The name of the argument to mark in arg_status. + # arg_status: Maps argument names to one of _NEEDED, _DEFAULT, _GIVEN. + arg_name: string + arg_status: dict + + Raises: + AttributeError: arg_name is already marked as _GIVEN. + """ + if arg_status.get(arg_name, None) == MethodCallChecker._GIVEN: + raise AttributeError('%s provided more than once' % (arg_name,)) + arg_status[arg_name] = MethodCallChecker._GIVEN + + def Check(self, params, named_params): + """Ensures that the parameters used while recording a call are valid. + + Args: + # params: A list of positional parameters. + # named_params: A dict of named parameters. + params: list + named_params: dict + + Raises: + AttributeError: the given parameters don't work with the given method. + """ + arg_status = dict((a, MethodCallChecker._NEEDED) + for a in self._required_args) + for arg in self._default_args: + arg_status[arg] = MethodCallChecker._DEFAULT + + # Check that each positional param is valid. + for i in range(len(params)): + try: + arg_name = self._args[i] + except IndexError: + if not self._has_varargs: + raise AttributeError('%s does not take %d or more positional ' + 'arguments' % (self._method.__name__, i)) + else: + self._RecordArgumentGiven(arg_name, arg_status) + + # Check each keyword argument. + for arg_name in named_params: + if arg_name not in arg_status and not self._has_varkw: + raise AttributeError('%s is not expecting keyword argument %s' + % (self._method.__name__, arg_name)) + self._RecordArgumentGiven(arg_name, arg_status) + + # Ensure all the required arguments have been given. + still_needed = [k for k, v in arg_status.iteritems() + if v == MethodCallChecker._NEEDED] + if still_needed: + raise AttributeError('No values given for arguments %s' + % (' '.join(sorted(still_needed)))) + + class MockMethod(object): """Callable mock method. @@ -540,7 +638,7 @@ class MockMethod(object): signature) matches the expected method. """ - def __init__(self, method_name, call_queue, replay_mode): + def __init__(self, method_name, call_queue, replay_mode, method_to_mock=None): """Construct a new mock method. Args: @@ -549,9 +647,11 @@ class MockMethod(object): # this call to the queue. # replay_mode: False if we are recording, True if we are verifying calls # against the call queue. + # method_to_mock: The actual method being mocked, used for introspection. method_name: str call_queue: list or deque replay_mode: bool + method_to_mock: a method object """ self._name = method_name @@ -566,6 +666,11 @@ class MockMethod(object): self._exception = None self._side_effects = None + try: + self._checker = MethodCallChecker(method_to_mock) + except ValueError: + self._checker = None + def __call__(self, *params, **named_params): """Log parameters and return the specified return value. @@ -583,6 +688,8 @@ class MockMethod(object): self._named_params = named_params if not self._replay_mode: + if self._checker is not None: + self._checker.Check(params, named_params) self._call_queue.append(self) return self diff --git a/mox_test.py b/mox_test.py index b626889..52b5905 100755 --- a/mox_test.py +++ b/mox_test.py @@ -534,6 +534,144 @@ class MockAnythingTest(unittest.TestCase): self.mock_object._Verify() +class MethodCheckerTest(unittest.TestCase): + """Tests MockMethod's use of MethodChecker method.""" + + def testNoParameters(self): + method = mox.MockMethod('NoParameters', [], False, + CheckCallTestClass.NoParameters) + method() + self.assertRaises(AttributeError, method, 1) + self.assertRaises(AttributeError, method, 1, 2) + self.assertRaises(AttributeError, method, a=1) + self.assertRaises(AttributeError, method, 1, b=2) + + def testOneParameter(self): + method = mox.MockMethod('OneParameter', [], False, + CheckCallTestClass.OneParameter) + self.assertRaises(AttributeError, method) + method(1) + method(a=1) + self.assertRaises(AttributeError, method, b=1) + self.assertRaises(AttributeError, method, 1, 2) + self.assertRaises(AttributeError, method, 1, a=2) + self.assertRaises(AttributeError, method, 1, b=2) + + def testTwoParameters(self): + method = mox.MockMethod('TwoParameters', [], False, + CheckCallTestClass.TwoParameters) + self.assertRaises(AttributeError, method) + self.assertRaises(AttributeError, method, 1) + self.assertRaises(AttributeError, method, a=1) + self.assertRaises(AttributeError, method, b=1) + method(1, 2) + method(1, b=2) + method(a=1, b=2) + method(b=2, a=1) + self.assertRaises(AttributeError, method, b=2, c=3) + self.assertRaises(AttributeError, method, a=1, b=2, c=3) + self.assertRaises(AttributeError, method, 1, 2, 3) + self.assertRaises(AttributeError, method, 1, 2, 3, 4) + self.assertRaises(AttributeError, method, 3, a=1, b=2) + + def testOneDefaultValue(self): + method = mox.MockMethod('OneDefaultValue', [], False, + CheckCallTestClass.OneDefaultValue) + method() + method(1) + method(a=1) + self.assertRaises(AttributeError, method, b=1) + self.assertRaises(AttributeError, method, 1, 2) + self.assertRaises(AttributeError, method, 1, a=2) + self.assertRaises(AttributeError, method, 1, b=2) + + def testTwoDefaultValues(self): + method = mox.MockMethod('TwoDefaultValues', [], False, + CheckCallTestClass.TwoDefaultValues) + self.assertRaises(AttributeError, method) + self.assertRaises(AttributeError, method, c=3) + self.assertRaises(AttributeError, method, 1) + self.assertRaises(AttributeError, method, 1, d=4) + self.assertRaises(AttributeError, method, 1, d=4, c=3) + method(1, 2) + method(a=1, b=2) + method(1, 2, 3) + method(1, 2, 3, 4) + method(1, 2, c=3) + method(1, 2, c=3, d=4) + method(1, 2, d=4, c=3) + method(d=4, c=3, a=1, b=2) + self.assertRaises(AttributeError, method, 1, 2, 3, 4, 5) + self.assertRaises(AttributeError, method, 1, 2, e=9) + self.assertRaises(AttributeError, method, a=1, b=2, e=9) + + def testArgs(self): + method = mox.MockMethod('Args', [], False, CheckCallTestClass.Args) + self.assertRaises(AttributeError, method) + self.assertRaises(AttributeError, method, 1) + method(1, 2) + method(a=1, b=2) + method(1, 2, 3) + method(1, 2, 3, 4) + self.assertRaises(AttributeError, method, 1, 2, a=3) + self.assertRaises(AttributeError, method, 1, 2, c=3) + + def testKwargs(self): + method = mox.MockMethod('Kwargs', [], False, CheckCallTestClass.Kwargs) + self.assertRaises(AttributeError, method) + method(1) + method(1, 2) + method(a=1, b=2) + method(b=2, a=1) + self.assertRaises(AttributeError, method, 1, 2, 3) + self.assertRaises(AttributeError, method, 1, 2, a=3) + method(1, 2, c=3) + method(a=1, b=2, c=3) + method(c=3, a=1, b=2) + method(a=1, b=2, c=3, d=4) + self.assertRaises(AttributeError, method, 1, 2, 3, 4) + + def testArgsAndKwargs(self): + method = mox.MockMethod('ArgsAndKwargs', [], False, + CheckCallTestClass.ArgsAndKwargs) + self.assertRaises(AttributeError, method) + method(1) + method(1, 2) + method(1, 2, 3) + method(a=1) + method(1, b=2) + self.assertRaises(AttributeError, method, 1, a=2) + method(b=2, a=1) + method(c=3, b=2, a=1) + method(1, 2, c=3) + + +class CheckCallTestClass(object): + def NoParameters(self): + pass + + def OneParameter(self, a): + pass + + def TwoParameters(self, a, b): + pass + + def OneDefaultValue(self, a=1): + pass + + def TwoDefaultValues(self, a, b, c=1, d=2): + pass + + def Args(self, a, b, *args): + pass + + def Kwargs(self, a, b=2, **kwargs): + pass + + def ArgsAndKwargs(self, a, *args, **kwargs): + pass + + class MockObjectTest(unittest.TestCase): """Verify that the MockObject class works as exepcted.""" -- cgit v1.2.1