summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsmiddlek <smiddlek@b1010a0a-674b-0410-b734-77272b80c875>2008-11-12 19:01:43 +0000
committersmiddlek <smiddlek@b1010a0a-674b-0410-b734-77272b80c875>2008-11-12 19:01:43 +0000
commit3394b66def83f9779150ba53f8d462300d7b279a (patch)
tree48fe326d2c57c5129fd52c20ca57b3905dadb33f
parent5ac2e500588d617b2b18777d202b5fd36081adb3 (diff)
downloadmox-3394b66def83f9779150ba53f8d462300d7b279a.tar.gz
Patch from Matt Brown (mdbrow@gmail.com) to inspect the arguments of mocked methods. Alright! :)
git-svn-id: http://pymox.googlecode.com/svn/trunk@26 b1010a0a-674b-0410-b734-77272b80c875
-rwxr-xr-xmox.py117
-rwxr-xr-xmox_test.py138
2 files changed, 250 insertions, 5 deletions
diff --git a/mox.py b/mox.py
index 2fddc0b..23348fd 100755
--- a/mox.py
+++ b/mox.py
@@ -57,6 +57,7 @@ Suggested usage / workflow:
"""
from collections import deque
+import inspect
import re
import types
import unittest
@@ -289,19 +290,21 @@ class MockAnything:
return self._CreateMockMethod(method_name)
- def _CreateMockMethod(self, method_name):
+ def _CreateMockMethod(self, method_name, method_to_mock=None):
"""Create a new mock method call and return it.
Args:
- # method name: the name of the method being called.
+ # method_name: the name of the method being called.
+ # method_to_mock: The actual method being mocked, used for introspection.
method_name: str
+ method_to_mock: a method object
Returns:
A new MockMethod aware of MockAnything's state (record or replay).
"""
return MockMethod(method_name, self._expected_calls_queue,
- self._replay_mode)
+ self._replay_mode, method_to_mock=method_to_mock)
def __nonzero__(self):
"""Return 1 for nonzero so the mock can be used as a conditional."""
@@ -409,7 +412,9 @@ class MockObject(MockAnything, object):
return getattr(self._class_to_mock, name)
if name in self._known_methods:
- return self._CreateMockMethod(name)
+ return self._CreateMockMethod(
+ name,
+ method_to_mock=getattr(self._class_to_mock, name))
raise UnknownMethodCallError(name)
@@ -531,6 +536,99 @@ class MockObject(MockAnything, object):
return self._class_to_mock
+class MethodCallChecker(object):
+ """Ensures that methods are called correctly."""
+
+ _NEEDED, _DEFAULT, _GIVEN = range(3)
+
+ def __init__(self, method):
+ """Creates a checker.
+
+ Args:
+ # method: A method to check.
+ method: function
+
+ Raises:
+ ValueError: method could not be inspected, so checks aren't possible.
+ Some methods and functions like built-ins can't be inspected.
+ """
+ try:
+ self._args, varargs, varkw, defaults = inspect.getargspec(method)
+ except TypeError:
+ raise ValueError('Could not get argument specification for %r'
+ % (method,))
+ if inspect.ismethod(method):
+ self._args = self._args[1:] # Skip 'self'.
+ self._method = method
+
+ self._has_varargs = varargs is not None
+ self._has_varkw = varkw is not None
+ if defaults is None:
+ self._required_args = self._args
+ self._default_args = []
+ else:
+ self._required_args = self._args[:-len(defaults)]
+ self._default_args = self._args[-len(defaults):]
+
+ def _RecordArgumentGiven(self, arg_name, arg_status):
+ """Mark an argument as being given.
+
+ Args:
+ # arg_name: The name of the argument to mark in arg_status.
+ # arg_status: Maps argument names to one of _NEEDED, _DEFAULT, _GIVEN.
+ arg_name: string
+ arg_status: dict
+
+ Raises:
+ AttributeError: arg_name is already marked as _GIVEN.
+ """
+ if arg_status.get(arg_name, None) == MethodCallChecker._GIVEN:
+ raise AttributeError('%s provided more than once' % (arg_name,))
+ arg_status[arg_name] = MethodCallChecker._GIVEN
+
+ def Check(self, params, named_params):
+ """Ensures that the parameters used while recording a call are valid.
+
+ Args:
+ # params: A list of positional parameters.
+ # named_params: A dict of named parameters.
+ params: list
+ named_params: dict
+
+ Raises:
+ AttributeError: the given parameters don't work with the given method.
+ """
+ arg_status = dict((a, MethodCallChecker._NEEDED)
+ for a in self._required_args)
+ for arg in self._default_args:
+ arg_status[arg] = MethodCallChecker._DEFAULT
+
+ # Check that each positional param is valid.
+ for i in range(len(params)):
+ try:
+ arg_name = self._args[i]
+ except IndexError:
+ if not self._has_varargs:
+ raise AttributeError('%s does not take %d or more positional '
+ 'arguments' % (self._method.__name__, i))
+ else:
+ self._RecordArgumentGiven(arg_name, arg_status)
+
+ # Check each keyword argument.
+ for arg_name in named_params:
+ if arg_name not in arg_status and not self._has_varkw:
+ raise AttributeError('%s is not expecting keyword argument %s'
+ % (self._method.__name__, arg_name))
+ self._RecordArgumentGiven(arg_name, arg_status)
+
+ # Ensure all the required arguments have been given.
+ still_needed = [k for k, v in arg_status.iteritems()
+ if v == MethodCallChecker._NEEDED]
+ if still_needed:
+ raise AttributeError('No values given for arguments %s'
+ % (' '.join(sorted(still_needed))))
+
+
class MockMethod(object):
"""Callable mock method.
@@ -540,7 +638,7 @@ class MockMethod(object):
signature) matches the expected method.
"""
- def __init__(self, method_name, call_queue, replay_mode):
+ def __init__(self, method_name, call_queue, replay_mode, method_to_mock=None):
"""Construct a new mock method.
Args:
@@ -549,9 +647,11 @@ class MockMethod(object):
# this call to the queue.
# replay_mode: False if we are recording, True if we are verifying calls
# against the call queue.
+ # method_to_mock: The actual method being mocked, used for introspection.
method_name: str
call_queue: list or deque
replay_mode: bool
+ method_to_mock: a method object
"""
self._name = method_name
@@ -566,6 +666,11 @@ class MockMethod(object):
self._exception = None
self._side_effects = None
+ try:
+ self._checker = MethodCallChecker(method_to_mock)
+ except ValueError:
+ self._checker = None
+
def __call__(self, *params, **named_params):
"""Log parameters and return the specified return value.
@@ -583,6 +688,8 @@ class MockMethod(object):
self._named_params = named_params
if not self._replay_mode:
+ if self._checker is not None:
+ self._checker.Check(params, named_params)
self._call_queue.append(self)
return self
diff --git a/mox_test.py b/mox_test.py
index b626889..52b5905 100755
--- a/mox_test.py
+++ b/mox_test.py
@@ -534,6 +534,144 @@ class MockAnythingTest(unittest.TestCase):
self.mock_object._Verify()
+class MethodCheckerTest(unittest.TestCase):
+ """Tests MockMethod's use of MethodChecker method."""
+
+ def testNoParameters(self):
+ method = mox.MockMethod('NoParameters', [], False,
+ CheckCallTestClass.NoParameters)
+ method()
+ self.assertRaises(AttributeError, method, 1)
+ self.assertRaises(AttributeError, method, 1, 2)
+ self.assertRaises(AttributeError, method, a=1)
+ self.assertRaises(AttributeError, method, 1, b=2)
+
+ def testOneParameter(self):
+ method = mox.MockMethod('OneParameter', [], False,
+ CheckCallTestClass.OneParameter)
+ self.assertRaises(AttributeError, method)
+ method(1)
+ method(a=1)
+ self.assertRaises(AttributeError, method, b=1)
+ self.assertRaises(AttributeError, method, 1, 2)
+ self.assertRaises(AttributeError, method, 1, a=2)
+ self.assertRaises(AttributeError, method, 1, b=2)
+
+ def testTwoParameters(self):
+ method = mox.MockMethod('TwoParameters', [], False,
+ CheckCallTestClass.TwoParameters)
+ self.assertRaises(AttributeError, method)
+ self.assertRaises(AttributeError, method, 1)
+ self.assertRaises(AttributeError, method, a=1)
+ self.assertRaises(AttributeError, method, b=1)
+ method(1, 2)
+ method(1, b=2)
+ method(a=1, b=2)
+ method(b=2, a=1)
+ self.assertRaises(AttributeError, method, b=2, c=3)
+ self.assertRaises(AttributeError, method, a=1, b=2, c=3)
+ self.assertRaises(AttributeError, method, 1, 2, 3)
+ self.assertRaises(AttributeError, method, 1, 2, 3, 4)
+ self.assertRaises(AttributeError, method, 3, a=1, b=2)
+
+ def testOneDefaultValue(self):
+ method = mox.MockMethod('OneDefaultValue', [], False,
+ CheckCallTestClass.OneDefaultValue)
+ method()
+ method(1)
+ method(a=1)
+ self.assertRaises(AttributeError, method, b=1)
+ self.assertRaises(AttributeError, method, 1, 2)
+ self.assertRaises(AttributeError, method, 1, a=2)
+ self.assertRaises(AttributeError, method, 1, b=2)
+
+ def testTwoDefaultValues(self):
+ method = mox.MockMethod('TwoDefaultValues', [], False,
+ CheckCallTestClass.TwoDefaultValues)
+ self.assertRaises(AttributeError, method)
+ self.assertRaises(AttributeError, method, c=3)
+ self.assertRaises(AttributeError, method, 1)
+ self.assertRaises(AttributeError, method, 1, d=4)
+ self.assertRaises(AttributeError, method, 1, d=4, c=3)
+ method(1, 2)
+ method(a=1, b=2)
+ method(1, 2, 3)
+ method(1, 2, 3, 4)
+ method(1, 2, c=3)
+ method(1, 2, c=3, d=4)
+ method(1, 2, d=4, c=3)
+ method(d=4, c=3, a=1, b=2)
+ self.assertRaises(AttributeError, method, 1, 2, 3, 4, 5)
+ self.assertRaises(AttributeError, method, 1, 2, e=9)
+ self.assertRaises(AttributeError, method, a=1, b=2, e=9)
+
+ def testArgs(self):
+ method = mox.MockMethod('Args', [], False, CheckCallTestClass.Args)
+ self.assertRaises(AttributeError, method)
+ self.assertRaises(AttributeError, method, 1)
+ method(1, 2)
+ method(a=1, b=2)
+ method(1, 2, 3)
+ method(1, 2, 3, 4)
+ self.assertRaises(AttributeError, method, 1, 2, a=3)
+ self.assertRaises(AttributeError, method, 1, 2, c=3)
+
+ def testKwargs(self):
+ method = mox.MockMethod('Kwargs', [], False, CheckCallTestClass.Kwargs)
+ self.assertRaises(AttributeError, method)
+ method(1)
+ method(1, 2)
+ method(a=1, b=2)
+ method(b=2, a=1)
+ self.assertRaises(AttributeError, method, 1, 2, 3)
+ self.assertRaises(AttributeError, method, 1, 2, a=3)
+ method(1, 2, c=3)
+ method(a=1, b=2, c=3)
+ method(c=3, a=1, b=2)
+ method(a=1, b=2, c=3, d=4)
+ self.assertRaises(AttributeError, method, 1, 2, 3, 4)
+
+ def testArgsAndKwargs(self):
+ method = mox.MockMethod('ArgsAndKwargs', [], False,
+ CheckCallTestClass.ArgsAndKwargs)
+ self.assertRaises(AttributeError, method)
+ method(1)
+ method(1, 2)
+ method(1, 2, 3)
+ method(a=1)
+ method(1, b=2)
+ self.assertRaises(AttributeError, method, 1, a=2)
+ method(b=2, a=1)
+ method(c=3, b=2, a=1)
+ method(1, 2, c=3)
+
+
+class CheckCallTestClass(object):
+ def NoParameters(self):
+ pass
+
+ def OneParameter(self, a):
+ pass
+
+ def TwoParameters(self, a, b):
+ pass
+
+ def OneDefaultValue(self, a=1):
+ pass
+
+ def TwoDefaultValues(self, a, b, c=1, d=2):
+ pass
+
+ def Args(self, a, b, *args):
+ pass
+
+ def Kwargs(self, a, b=2, **kwargs):
+ pass
+
+ def ArgsAndKwargs(self, a, *args, **kwargs):
+ pass
+
+
class MockObjectTest(unittest.TestCase):
"""Verify that the MockObject class works as exepcted."""