diff options
-rw-r--r-- | ChangeLog | 10 | ||||
-rw-r--r-- | astroid/as_string.py | 3 | ||||
-rw-r--r-- | astroid/brain/builtin_inference.py | 74 | ||||
-rw-r--r-- | astroid/exceptions.py | 4 | ||||
-rw-r--r-- | astroid/objects.py | 109 | ||||
-rw-r--r-- | astroid/tests/unittest_objects.py | 380 |
6 files changed, 578 insertions, 2 deletions
@@ -107,6 +107,16 @@ Change log for the astroid package (used to be astng) will not return an Instance of frozenset, without having access to its content, but a new objects.FrozenSet, which can be used just as a nodes.Set. + * Add a new *inference object* called Super, which also adds support for understanding + super calls. astroid understands the zero-argument form of super, specific to + Python 3, where the interpreter fills itself the arguments of the call. Also, we + are understanding the 2-argument form of super, both for bounded lookups + (super(X, instance)) as well as for unbounded lookups (super(X, Y)), + having as well support for validating that the object-or-type is a subtype + of the first argument. The unbounded form of super (one argument) is not + understood, since it's useless in practice and should be removed from + Python's specification. Closes issue #89. + 2015-03-14 -- 1.3.6 diff --git a/astroid/as_string.py b/astroid/as_string.py index 065d7c95..bcbc6f8c 100644 --- a/astroid/as_string.py +++ b/astroid/as_string.py @@ -443,6 +443,9 @@ class AsStringVisitor(object): def visit_frozenset(self, node): return node.parent.accept(self) + def visit_super(self, node): + return node.parent.accept(self) + class AsStringVisitor3k(AsStringVisitor): """AsStringVisitor3k overwrites some AsStringVisitor methods""" diff --git a/astroid/brain/builtin_inference.py b/astroid/brain/builtin_inference.py index 5d48eee8..6aff555f 100644 --- a/astroid/brain/builtin_inference.py +++ b/astroid/brain/builtin_inference.py @@ -52,7 +52,7 @@ def _extend_str(class_node, rvalue): def rstrip(self, chars=None): return {rvalue} def rjust(self, width, fillchar=None): - return {rvalue} + return {rvalue} def center(self, width, fillchar=None): return {rvalue} def ljust(self, width, fillchar=None): @@ -245,7 +245,79 @@ def infer_dict(node, context=None): empty.items = items return empty + +def _node_class(node): + klass = node.frame() + while klass is not None and not isinstance(klass, nodes.Class): + if klass.parent is None: + klass = None + else: + klass = klass.parent.frame() + return klass + + +def infer_super(node, context=None): + """Understand super calls. + + There are some restrictions for what can be understood: + + * unbounded super (one argument form) is not understood. + + * if the super call is not inside a function (classmethod or method), + then the default inference will be used. + + * if the super arguments can't be infered, the default inference + will be used. + """ + if len(node.args) == 1: + # Ignore unbounded super. + raise UseInferenceDefault + + scope = node.scope() + if not isinstance(scope, nodes.Function): + # Ignore non-method uses of super. + raise UseInferenceDefault + if scope.type not in ('classmethod', 'method'): + # Not interested in staticmethods. + raise UseInferenceDefault + + cls = _node_class(scope) + if not len(node.args): + mro_pointer = cls + # In we are in a classmethod, the interpreter will fill + # automatically the class as the second argument, not an instance. + if scope.type == 'classmethod': + mro_type = cls + else: + mro_type = cls.instanciate_class() + else: + # TODO(cpopa): support flow control (multiple inference values). + try: + mro_pointer = next(node.args[0].infer(context=context)) + except InferenceError: + raise UseInferenceDefault + try: + mro_type = next(node.args[1].infer(context=context)) + except InferenceError: + raise UseInferenceDefault + + if mro_pointer is YES or mro_type is YES: + # No way we could understand this. + raise UseInferenceDefault + + super_obj = objects.Super(mro_pointer=mro_pointer, + mro_type=mro_type, + self_class=cls) + super_obj.parent = node + return iter([super_obj]) + + # Builtins inference +MANAGER.register_transform(nodes.CallFunc, + inference_tip(infer_super), + lambda n: (isinstance(n.func, nodes.Name) and + n.func.name == 'super')) + register_builtin_transform(infer_tuple, 'tuple') register_builtin_transform(infer_set, 'set') register_builtin_transform(infer_list, 'list') diff --git a/astroid/exceptions.py b/astroid/exceptions.py index d1094cec..9e3a279f 100644 --- a/astroid/exceptions.py +++ b/astroid/exceptions.py @@ -42,6 +42,10 @@ class InconsistentMroError(MroError): """Error raised when a class's MRO is inconsistent.""" +class SuperError(ResolveError): + """Error raised when there is a problem with a super call.""" + + class NotFoundError(ResolveError): """raised when we are unable to resolve a name""" diff --git a/astroid/objects.py b/astroid/objects.py index ba60bba2..15991590 100644 --- a/astroid/objects.py +++ b/astroid/objects.py @@ -27,10 +27,16 @@ leads to an inferred FrozenSet: """ from logilab.common.decorators import cachedproperty +import six from astroid import MANAGER -from astroid.bases import BUILTINS, NodeNG, Instance +from astroid.bases import ( + BUILTINS, NodeNG, Instance, _infer_stmts, + BoundMethod, UnboundMethod, +) +from astroid.exceptions import SuperError, NotFoundError, MroError from astroid.node_classes import const_factory +from astroid.scoped_nodes import Class, Function from astroid.mixins import ParentAssignTypeMixin @@ -56,3 +62,104 @@ class FrozenSet(NodeNG, Instance, ParentAssignTypeMixin): def _proxied(self): builtins = MANAGER.astroid_cache[BUILTINS] return builtins.getattr('frozenset')[0] + + +class Super(NodeNG): + """Proxy class over a super call. + + This class offers almost the same behaviour as Python's super, + which is MRO lookups for retrieving attributes from the parents. + """ + + def __init__(self, mro_pointer, mro_type, self_class): + self.type = mro_type + self.mro_pointer = mro_pointer + self._class_based = False + self._self_class = self_class + self._model = { + '__thisclass__': self.mro_pointer, + '__self_class__': self._self_class, + '__self__': self.type, + '__class__': self._proxied, + } + + def _infer(self, context=None): + yield self + + def super_mro(self): + """Get the MRO which will be used to lookup attributes in this super.""" + if not isinstance(self.mro_pointer, Class): + raise SuperError("The first super argument must be type.") + + if isinstance(self.type, Class): + # `super(type, type)`, most likely in a class method. + self._class_based = True + mro_type = self.type + else: + mro_type = self.type._proxied + + if not mro_type.newstyle: + raise SuperError("Unable to call super on old-style classes.") + + mro = mro_type.mro() + if self.mro_pointer not in mro: + raise SuperError("super(type, obj): obj must be an instance " + "or subtype of type") + + index = mro.index(self.mro_pointer) + return mro[index + 1:] + + @cachedproperty + def _proxied(self): + builtins = MANAGER.astroid_cache[BUILTINS] + return builtins.getattr('super')[0] + + def pytype(self): + return '%s.super' % BUILTINS + + def display_type(self): + return 'Super of' + + @property + def name(self): + """Get the name of the MRO pointer.""" + return self.mro_pointer.name + + def igetattr(self, name, context=None): + """Retrieve the inferred values of the given attribute name.""" + + local_name = self._model.get(name) + if local_name: + yield local_name + return + + try: + mro = self.super_mro() + except (MroError, SuperError) as exc: + # Don't let invalid MROs or invalid super calls + # to leak out as is from this function. + six.raise_from(NotFoundError, exc) + + found = False + for cls in mro: + if name not in cls.locals: + continue + + found = True + for infered in _infer_stmts([cls[name]], context, frame=self): + if isinstance(infered, Function): + if self._class_based: + # The second argument to super is class, which + # means that we are returning unbound methods + # when accessing attributes. + yield UnboundMethod(infered) + else: + yield BoundMethod(infered, cls) + else: + yield infered + + if not found: + raise NotFoundError(name) + + def getattr(self, name, context=None): + return list(self.igetattr(name, context=context)) diff --git a/astroid/tests/unittest_objects.py b/astroid/tests/unittest_objects.py index 007e142c..91296b3b 100644 --- a/astroid/tests/unittest_objects.py +++ b/astroid/tests/unittest_objects.py @@ -19,6 +19,7 @@ import unittest
from astroid import bases
+from astroid import exceptions
from astroid import nodes
from astroid import objects
from astroid import test_utils
@@ -45,5 +46,384 @@ class ObjectsTest(unittest.TestCase): self.assertIsInstance(proxied, nodes.Class)
+class SuperTests(unittest.TestCase):
+
+ def test_inferring_super_outside_methods(self):
+ ast_nodes = test_utils.extract_node('''
+ class Module(object):
+ pass
+ class StaticMethod(object):
+ @staticmethod
+ def static():
+ # valid, but we don't bother with it.
+ return super(StaticMethod, StaticMethod) #@
+ # super outside methods aren't inferred
+ super(Module, Module) #@
+ # no argument super is not recognised outside methods as well.
+ super() #@
+ ''')
+ in_static = next(ast_nodes[0].value.infer())
+ self.assertIsInstance(in_static, bases.Instance)
+ self.assertEqual(in_static.qname(), "%s.super" % bases.BUILTINS)
+
+ module_level = next(ast_nodes[1].infer())
+ self.assertIsInstance(module_level, bases.Instance)
+ self.assertEqual(in_static.qname(), "%s.super" % bases.BUILTINS)
+
+ no_arguments = next(ast_nodes[2].infer())
+ self.assertIsInstance(no_arguments, bases.Instance)
+ self.assertEqual(no_arguments.qname(), "%s.super" % bases.BUILTINS)
+
+ def test_inferring_unbound_super_doesnt_work(self):
+ node = test_utils.extract_node('''
+ class Test(object):
+ def __init__(self):
+ super(Test) #@
+ ''')
+ unbounded = next(node.infer())
+ self.assertIsInstance(unbounded, bases.Instance)
+ self.assertEqual(unbounded.qname(), "%s.super" % bases.BUILTINS)
+
+ def test_use_default_inference_on_not_inferring_args(self):
+ ast_nodes = test_utils.extract_node('''
+ class Test(object):
+ def __init__(self):
+ super(Lala, self) #@
+ super(Test, lala) #@
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, bases.Instance)
+ self.assertEqual(first.qname(), "%s.super" % bases.BUILTINS)
+
+ second = next(ast_nodes[1].infer())
+ self.assertIsInstance(second, bases.Instance)
+ self.assertEqual(second.qname(), "%s.super" % bases.BUILTINS)
+
+ @test_utils.require_version(maxver='3.0')
+ def test_super_on_old_style_class(self):
+ # super doesn't work on old style class, but leave
+ # that as an error for pylint. We'll infer Super objects,
+ # but every call will result in a failure at some point.
+ node = test_utils.extract_node('''
+ class OldStyle:
+ def __init__(self):
+ super(OldStyle, self) #@
+ ''')
+ old = next(node.infer())
+ self.assertIsInstance(old, objects.Super)
+ self.assertIsInstance(old.mro_pointer, nodes.Class)
+ self.assertEqual(old.mro_pointer.name, 'OldStyle')
+ with self.assertRaises(exceptions.SuperError) as cm:
+ old.super_mro()
+ self.assertEqual(str(cm.exception),
+ "Unable to call super on old-style classes.")
+
+ @test_utils.require_version(minver='3.0')
+ def test_no_arguments_super(self):
+ ast_nodes = test_utils.extract_node('''
+ class First(object): pass
+ class Second(First):
+ def test(self):
+ super() #@
+ @classmethod
+ def test_classmethod(cls):
+ super() #@
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, objects.Super)
+ self.assertIsInstance(first.type, bases.Instance)
+ self.assertEqual(first.type.name, 'Second')
+ self.assertIsInstance(first.mro_pointer, nodes.Class)
+ self.assertEqual(first.mro_pointer.name, 'Second')
+
+ second = next(ast_nodes[1].infer())
+ self.assertIsInstance(second, objects.Super)
+ self.assertIsInstance(second.type, nodes.Class)
+ self.assertEqual(second.type.name, 'Second')
+ self.assertIsInstance(second.mro_pointer, nodes.Class)
+ self.assertEqual(second.mro_pointer.name, 'Second')
+
+ def test_super_simple_cases(self):
+ ast_nodes = test_utils.extract_node('''
+ class First(object): pass
+ class Second(First): pass
+ class Third(First):
+ def test(self):
+ super(Third, self) #@
+ super(Second, self) #@
+
+ # mro position and the type
+ super(Third, Third) #@
+ super(Third, Second) #@
+ super(Fourth, Fourth) #@
+
+ class Fourth(Third):
+ pass
+ ''')
+
+ # .type is the object which provides the mro.
+ # .mro_pointer is the position in the mro from where
+ # the lookup should be done.
+
+ # super(Third, self)
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, objects.Super)
+ self.assertIsInstance(first.type, bases.Instance)
+ self.assertEqual(first.type.name, 'Third')
+ self.assertIsInstance(first.mro_pointer, nodes.Class)
+ self.assertEqual(first.mro_pointer.name, 'Third')
+
+ # super(Second, self)
+ second = next(ast_nodes[1].infer())
+ self.assertIsInstance(second, objects.Super)
+ self.assertIsInstance(second.type, bases.Instance)
+ self.assertEqual(second.type.name, 'Third')
+ self.assertIsInstance(first.mro_pointer, nodes.Class)
+ self.assertEqual(second.mro_pointer.name, 'Second')
+
+ # super(Third, Third)
+ third = next(ast_nodes[2].infer())
+ self.assertIsInstance(third, objects.Super)
+ self.assertIsInstance(third.type, nodes.Class)
+ self.assertEqual(third.type.name, 'Third')
+ self.assertIsInstance(third.mro_pointer, nodes.Class)
+ self.assertEqual(third.mro_pointer.name, 'Third')
+
+ # super(Third, second)
+ fourth = next(ast_nodes[3].infer())
+ self.assertIsInstance(fourth, objects.Super)
+ self.assertIsInstance(fourth.type, nodes.Class)
+ self.assertEqual(fourth.type.name, 'Second')
+ self.assertIsInstance(fourth.mro_pointer, nodes.Class)
+ self.assertEqual(fourth.mro_pointer.name, 'Third')
+
+ # Super(Fourth, Fourth)
+ fifth = next(ast_nodes[4].infer())
+ self.assertIsInstance(fifth, objects.Super)
+ self.assertIsInstance(fifth.type, nodes.Class)
+ self.assertEqual(fifth.type.name, 'Fourth')
+ self.assertIsInstance(fifth.mro_pointer, nodes.Class)
+ self.assertEqual(fifth.mro_pointer.name, 'Fourth')
+
+ def test_super_infer(self):
+ node = test_utils.extract_node('''
+ class Super(object):
+ def __init__(self):
+ super(Super, self) #@
+ ''')
+ inferred = next(node.infer())
+ self.assertIsInstance(inferred, objects.Super)
+ reinferred = next(inferred.infer())
+ self.assertIsInstance(reinferred, objects.Super)
+ self.assertIs(inferred, reinferred)
+
+ def test_inferring_invalid_supers(self):
+ ast_nodes = test_utils.extract_node('''
+ class Super(object):
+ def __init__(self):
+ # MRO pointer is not a type
+ super(1, self) #@
+ # MRO type is not a subtype
+ super(Super, 1) #@
+ # self is not a subtype of Bupper
+ super(Bupper, self) #@
+ class Bupper(Super):
+ pass
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, objects.Super)
+ with self.assertRaises(exceptions.SuperError) as cm:
+ first.super_mro()
+ self.assertEqual(str(cm.exception), "The first super argument must be type.")
+
+ for node in ast_nodes[1:]:
+ inferred = next(node.infer())
+ self.assertIsInstance(inferred, objects.Super, node)
+ with self.assertRaises(exceptions.SuperError) as cm:
+ inferred.super_mro()
+ self.assertEqual(str(cm.exception),
+ "super(type, obj): obj must be an instance "
+ "or subtype of type", node)
+
+ def test_proxied(self):
+ node = test_utils.extract_node('''
+ class Super(object):
+ def __init__(self):
+ super(Super, self) #@
+ ''')
+ infered = next(node.infer())
+ proxied = infered._proxied
+ self.assertEqual(proxied.qname(), "%s.super" % bases.BUILTINS)
+ self.assertIsInstance(proxied, nodes.Class)
+
+ def test_super_getattr_single_inheritance(self):
+ ast_nodes = test_utils.extract_node('''
+ class First(object):
+ def test(self): pass
+ class Second(First):
+ def test2(self): pass
+ class Third(Second):
+ test3 = 42
+ def __init__(self):
+ super(Third, self).test2 #@
+ super(Third, self).test #@
+ # test3 is local, no MRO lookup is done.
+ super(Third, self).test3 #@
+ super(Third, self) #@
+
+ # Unbounds.
+ super(Third, Third).test2 #@
+ super(Third, Third).test #@
+
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, bases.BoundMethod)
+ self.assertEqual(first.bound.name, 'Second')
+
+ second = next(ast_nodes[1].infer())
+ self.assertIsInstance(second, bases.BoundMethod)
+ self.assertEqual(second.bound.name, 'First')
+
+ with self.assertRaises(exceptions.InferenceError):
+ next(ast_nodes[2].infer())
+ fourth = next(ast_nodes[3].infer())
+ with self.assertRaises(exceptions.NotFoundError):
+ fourth.getattr('test3')
+ with self.assertRaises(exceptions.NotFoundError):
+ next(fourth.igetattr('test3'))
+
+ first_unbound = next(ast_nodes[4].infer())
+ self.assertIsInstance(first_unbound, bases.UnboundMethod)
+ self.assertEqual(first_unbound._proxied.name, 'test2')
+ self.assertEqual(first_unbound._proxied.parent.name, 'Second')
+
+ second_unbound = next(ast_nodes[5].infer())
+ self.assertIsInstance(second_unbound, bases.UnboundMethod)
+ self.assertEqual(second_unbound._proxied.name, 'test')
+ self.assertEqual(second_unbound._proxied.parent.name, 'First')
+
+ def test_super_invalid_mro(self):
+ node = test_utils.extract_node('''
+ class A(object):
+ test = 42
+ class Super(A, A):
+ def __init__(self):
+ super(Super, self) #@
+ ''')
+ inferred = next(node.infer())
+ with self.assertRaises(exceptions.NotFoundError):
+ next(inferred.getattr('test'))
+
+ def test_super_complex_mro(self):
+ ast_nodes = test_utils.extract_node('''
+ class A(object):
+ def spam(self): return "A"
+ def foo(self): return "A"
+ class B(A):
+ def boo(self): return "B"
+ def spam(self): return "B"
+ class C(A):
+ def boo(self): return "C"
+ class E(C, B):
+ def __init__(self):
+ super(E, self).boo #@
+ super(C, self).boo #@
+ super(E, self).spam #@
+ super(E, self).foo #@
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, bases.BoundMethod)
+ self.assertEqual(first.bound.name, 'C')
+ second = next(ast_nodes[1].infer())
+ self.assertIsInstance(second, bases.BoundMethod)
+ self.assertEqual(second.bound.name, 'B')
+ third = next(ast_nodes[2].infer())
+ self.assertIsInstance(third, bases.BoundMethod)
+ self.assertEqual(third.bound.name, 'B')
+ fourth = next(ast_nodes[3].infer())
+ self.assertEqual(fourth.bound.name, 'A')
+
+ def test_super_data_model(self):
+ ast_nodes = test_utils.extract_node('''
+ class X(object): pass
+ class A(X):
+ def __init__(self):
+ super(A, self) #@
+ super(A, A) #@
+ super(X, A) #@
+ ''')
+ first = next(ast_nodes[0].infer())
+ thisclass = first.getattr('__thisclass__')[0]
+ self.assertIsInstance(thisclass, nodes.Class)
+ self.assertEqual(thisclass.name, 'A')
+ selfclass = first.getattr('__self_class__')[0]
+ self.assertIsInstance(selfclass, nodes.Class)
+ self.assertEqual(selfclass.name, 'A')
+ self_ = first.getattr('__self__')[0]
+ self.assertIsInstance(self_, bases.Instance)
+ self.assertEqual(self_.name, 'A')
+ cls = first.getattr('__class__')[0]
+ self.assertEqual(cls, first._proxied)
+
+ second = next(ast_nodes[1].infer())
+ thisclass = second.getattr('__thisclass__')[0]
+ self.assertEqual(thisclass.name, 'A')
+ self_ = second.getattr('__self__')[0]
+ self.assertIsInstance(self_, nodes.Class)
+ self.assertEqual(self_.name, 'A')
+
+ third = next(ast_nodes[2].infer())
+ thisclass = third.getattr('__thisclass__')[0]
+ self.assertEqual(thisclass.name, 'X')
+ selfclass = third.getattr('__self_class__')[0]
+ self.assertEqual(selfclass.name, 'A')
+
+ def assertEqualMro(self, klass, expected_mro):
+ self.assertEqual(
+ [member.name for member in klass.super_mro()],
+ expected_mro)
+
+ def test_super_mro(self):
+ ast_nodes = test_utils.extract_node('''
+ class A(object): pass
+ class B(A): pass
+ class C(A): pass
+ class E(C, B):
+ def __init__(self):
+ super(E, self) #@
+ super(C, self) #@
+ super(B, self) #@
+
+ super(B, 1) #@
+ super(1, B) #@
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertEqualMro(first, ['C', 'B', 'A', 'object'])
+ second = next(ast_nodes[1].infer())
+ self.assertEqualMro(second, ['B', 'A', 'object'])
+ third = next(ast_nodes[2].infer())
+ self.assertEqualMro(third, ['A', 'object'])
+
+ fourth = next(ast_nodes[3].infer())
+ with self.assertRaises(exceptions.SuperError):
+ fourth.super_mro()
+ fifth = next(ast_nodes[4].infer())
+ with self.assertRaises(exceptions.SuperError):
+ fifth.super_mro()
+
+ def test_super_yes_objects(self):
+ ast_nodes = test_utils.extract_node('''
+ from collections import Missing
+ class A(object):
+ def __init__(self):
+ super(Missing, self) #@
+ super(A, Missing) #@
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, bases.Instance)
+ second = next(ast_nodes[1].infer())
+ self.assertIsInstance(second, bases.Instance)
+
+
if __name__ == '__main__':
unittest.main()
|