summaryrefslogtreecommitdiff
path: root/mox.py
diff options
context:
space:
mode:
Diffstat (limited to 'mox.py')
-rwxr-xr-xmox.py117
1 files changed, 112 insertions, 5 deletions
diff --git a/mox.py b/mox.py
index 2fddc0b..23348fd 100755
--- a/mox.py
+++ b/mox.py
@@ -57,6 +57,7 @@ Suggested usage / workflow:
"""
from collections import deque
+import inspect
import re
import types
import unittest
@@ -289,19 +290,21 @@ class MockAnything:
return self._CreateMockMethod(method_name)
- def _CreateMockMethod(self, method_name):
+ def _CreateMockMethod(self, method_name, method_to_mock=None):
"""Create a new mock method call and return it.
Args:
- # method name: the name of the method being called.
+ # method_name: the name of the method being called.
+ # method_to_mock: The actual method being mocked, used for introspection.
method_name: str
+ method_to_mock: a method object
Returns:
A new MockMethod aware of MockAnything's state (record or replay).
"""
return MockMethod(method_name, self._expected_calls_queue,
- self._replay_mode)
+ self._replay_mode, method_to_mock=method_to_mock)
def __nonzero__(self):
"""Return 1 for nonzero so the mock can be used as a conditional."""
@@ -409,7 +412,9 @@ class MockObject(MockAnything, object):
return getattr(self._class_to_mock, name)
if name in self._known_methods:
- return self._CreateMockMethod(name)
+ return self._CreateMockMethod(
+ name,
+ method_to_mock=getattr(self._class_to_mock, name))
raise UnknownMethodCallError(name)
@@ -531,6 +536,99 @@ class MockObject(MockAnything, object):
return self._class_to_mock
+class MethodCallChecker(object):
+ """Ensures that methods are called correctly."""
+
+ _NEEDED, _DEFAULT, _GIVEN = range(3)
+
+ def __init__(self, method):
+ """Creates a checker.
+
+ Args:
+ # method: A method to check.
+ method: function
+
+ Raises:
+ ValueError: method could not be inspected, so checks aren't possible.
+ Some methods and functions like built-ins can't be inspected.
+ """
+ try:
+ self._args, varargs, varkw, defaults = inspect.getargspec(method)
+ except TypeError:
+ raise ValueError('Could not get argument specification for %r'
+ % (method,))
+ if inspect.ismethod(method):
+ self._args = self._args[1:] # Skip 'self'.
+ self._method = method
+
+ self._has_varargs = varargs is not None
+ self._has_varkw = varkw is not None
+ if defaults is None:
+ self._required_args = self._args
+ self._default_args = []
+ else:
+ self._required_args = self._args[:-len(defaults)]
+ self._default_args = self._args[-len(defaults):]
+
+ def _RecordArgumentGiven(self, arg_name, arg_status):
+ """Mark an argument as being given.
+
+ Args:
+ # arg_name: The name of the argument to mark in arg_status.
+ # arg_status: Maps argument names to one of _NEEDED, _DEFAULT, _GIVEN.
+ arg_name: string
+ arg_status: dict
+
+ Raises:
+ AttributeError: arg_name is already marked as _GIVEN.
+ """
+ if arg_status.get(arg_name, None) == MethodCallChecker._GIVEN:
+ raise AttributeError('%s provided more than once' % (arg_name,))
+ arg_status[arg_name] = MethodCallChecker._GIVEN
+
+ def Check(self, params, named_params):
+ """Ensures that the parameters used while recording a call are valid.
+
+ Args:
+ # params: A list of positional parameters.
+ # named_params: A dict of named parameters.
+ params: list
+ named_params: dict
+
+ Raises:
+ AttributeError: the given parameters don't work with the given method.
+ """
+ arg_status = dict((a, MethodCallChecker._NEEDED)
+ for a in self._required_args)
+ for arg in self._default_args:
+ arg_status[arg] = MethodCallChecker._DEFAULT
+
+ # Check that each positional param is valid.
+ for i in range(len(params)):
+ try:
+ arg_name = self._args[i]
+ except IndexError:
+ if not self._has_varargs:
+ raise AttributeError('%s does not take %d or more positional '
+ 'arguments' % (self._method.__name__, i))
+ else:
+ self._RecordArgumentGiven(arg_name, arg_status)
+
+ # Check each keyword argument.
+ for arg_name in named_params:
+ if arg_name not in arg_status and not self._has_varkw:
+ raise AttributeError('%s is not expecting keyword argument %s'
+ % (self._method.__name__, arg_name))
+ self._RecordArgumentGiven(arg_name, arg_status)
+
+ # Ensure all the required arguments have been given.
+ still_needed = [k for k, v in arg_status.iteritems()
+ if v == MethodCallChecker._NEEDED]
+ if still_needed:
+ raise AttributeError('No values given for arguments %s'
+ % (' '.join(sorted(still_needed))))
+
+
class MockMethod(object):
"""Callable mock method.
@@ -540,7 +638,7 @@ class MockMethod(object):
signature) matches the expected method.
"""
- def __init__(self, method_name, call_queue, replay_mode):
+ def __init__(self, method_name, call_queue, replay_mode, method_to_mock=None):
"""Construct a new mock method.
Args:
@@ -549,9 +647,11 @@ class MockMethod(object):
# this call to the queue.
# replay_mode: False if we are recording, True if we are verifying calls
# against the call queue.
+ # method_to_mock: The actual method being mocked, used for introspection.
method_name: str
call_queue: list or deque
replay_mode: bool
+ method_to_mock: a method object
"""
self._name = method_name
@@ -566,6 +666,11 @@ class MockMethod(object):
self._exception = None
self._side_effects = None
+ try:
+ self._checker = MethodCallChecker(method_to_mock)
+ except ValueError:
+ self._checker = None
+
def __call__(self, *params, **named_params):
"""Log parameters and return the specified return value.
@@ -583,6 +688,8 @@ class MockMethod(object):
self._named_params = named_params
if not self._replay_mode:
+ if self._checker is not None:
+ self._checker.Check(params, named_params)
self._call_queue.append(self)
return self