summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsmiddlek <smiddlek@b1010a0a-674b-0410-b734-77272b80c875>2010-08-18 21:18:37 +0000
committersmiddlek <smiddlek@b1010a0a-674b-0410-b734-77272b80c875>2010-08-18 21:18:37 +0000
commit994820b0d3ac0b0f93249177ab90010d26d3c292 (patch)
tree440f7c5d7f3ad97401d87533016bc45a4c277f48
parent4821b508508bc99a47eb2826bd1eacd94f1274ed (diff)
downloadmox-994820b0d3ac0b0f93249177ab90010d26d3c292.tar.gz
Fix for issue #27 by stevepm@google.com / steve.middlekauff@gmail.com
git-svn-id: http://pymox.googlecode.com/svn/trunk@55 b1010a0a-674b-0410-b734-77272b80c875
-rwxr-xr-xmox.py33
-rwxr-xr-xmox_test.py15
-rwxr-xr-xmox_test_helper.py12
3 files changed, 56 insertions, 4 deletions
diff --git a/mox.py b/mox.py
index 14ea121..a4a4fb5 100755
--- a/mox.py
+++ b/mox.py
@@ -899,8 +899,17 @@ class MethodSignatureChecker(object):
if inspect.ismethod(self._method):
# The extra param accounts for the bound instance.
if len(params) == len(self._args) + 1:
- clazz = getattr(self._method, 'im_class', None)
- if isinstance(params[0], clazz) or params[0] == clazz:
+ expected = getattr(self._method, 'im_class', None)
+
+ # Check if the param is an instance of the expected class,
+ # or check equality (useful for checking Comparators).
+ if isinstance(params[0], expected) or params[0] == expected:
+ params = params[1:]
+ # If the IsA() comparator is being used, we need to check the
+ # inverse of the usual case - that the given instance is a subclass
+ # of the expected class. For example, the code under test does
+ # late binding to a subclass.
+ elif isinstance(params[0], IsA) and params[0]._IsSubClass(expected):
params = params[1:]
# Check that each positional param is valid.
@@ -1284,8 +1293,26 @@ class IsA(Comparator):
# things like cStringIO.StringIO.
return type(rhs) == type(self._class_name)
+ def _IsSubClass(self, clazz):
+ """Check to see if the IsA comparators class is a subclass of clazz.
+
+ Args:
+ # clazz: a class object
+
+ Returns:
+ bool
+ """
+
+ try:
+ return issubclass(self._class_name, clazz)
+ except TypeError:
+ # Check raw types if there was a type error. This is helpful for
+ # things like cStringIO.StringIO.
+ return type(clazz) == type(self._class_name)
+
def __repr__(self):
- return str(self._class_name)
+ return 'mox.IsA(%s) ' % str(self._class_name)
+
class IsAlmost(Comparator):
"""Comparison class used to check whether a parameter is nearly equal
diff --git a/mox_test.py b/mox_test.py
index 611f8fb..bb684de 100755
--- a/mox_test.py
+++ b/mox_test.py
@@ -1575,6 +1575,20 @@ class MoxTest(unittest.TestCase):
self.mox.UnsetStubs()
self.assertEquals('foo', actual)
+ def testStubOutMethod_CalledAsUnboundMethod_Subclass_Comparator(self):
+ print 'this test'
+ self.mox.StubOutWithMock(mox_test_helper.TestClassFromAnotherModule, 'Value')
+ mox_test_helper.TestClassFromAnotherModule.Value(
+ mox.IsA(mox_test_helper.ChildClassFromAnotherModule)).AndReturn('foo')
+ self.mox.ReplayAll()
+
+ instance = mox_test_helper.ChildClassFromAnotherModule()
+ actual = mox_test_helper.TestClassFromAnotherModule.Value(instance)
+
+ self.mox.VerifyAll()
+ self.mox.UnsetStubs()
+ self.assertEquals('foo', actual)
+
def testStubOutMethod_CalledAsUnboundMethod_ActualInstance(self):
instance = TestClass()
self.mox.StubOutWithMock(TestClass, 'OtherValidCall')
@@ -1599,7 +1613,6 @@ class MoxTest(unittest.TestCase):
self.assertRaises(mox.UnexpectedMethodCallError,
TestClass.OtherValidCall, "wrong self")
-
self.mox.VerifyAll()
self.mox.UnsetStubs()
diff --git a/mox_test_helper.py b/mox_test_helper.py
index 60f72aa..0ccd484 100755
--- a/mox_test_helper.py
+++ b/mox_test_helper.py
@@ -95,6 +95,17 @@ class TestClassFromAnotherModule(object):
return 'Not mock'
+class ChildClassFromAnotherModule(TestClassFromAnotherModule):
+ """A child class of TestClassFromAnotherModule.
+
+ Used to test stubbing out unbound methods, where child classes
+ are eventually bound.
+ """
+
+ def __init__(self):
+ TestClassFromAnotherModule.__init__(self)
+
+
class CallableClass(object):
def __init__(self, one, two, nine=None):
@@ -110,6 +121,7 @@ class CallableClass(object):
def MyTestFunction(one, two, nine=None):
pass
+
class ExampleClass(object):
def TestMethod(self, one, two, nine=None):
pass