summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ChangeLog10
-rw-r--r--astroid/as_string.py3
-rw-r--r--astroid/brain/builtin_inference.py74
-rw-r--r--astroid/exceptions.py4
-rw-r--r--astroid/objects.py109
-rw-r--r--astroid/tests/unittest_objects.py380
6 files changed, 578 insertions, 2 deletions
diff --git a/ChangeLog b/ChangeLog
index 84d06b4e..b08d0f6c 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -107,6 +107,16 @@ Change log for the astroid package (used to be astng)
will not return an Instance of frozenset, without having access to its
content, but a new objects.FrozenSet, which can be used just as a nodes.Set.
+ * Add a new *inference object* called Super, which also adds support for understanding
+ super calls. astroid understands the zero-argument form of super, specific to
+ Python 3, where the interpreter fills itself the arguments of the call. Also, we
+ are understanding the 2-argument form of super, both for bounded lookups
+ (super(X, instance)) as well as for unbounded lookups (super(X, Y)),
+ having as well support for validating that the object-or-type is a subtype
+ of the first argument. The unbounded form of super (one argument) is not
+ understood, since it's useless in practice and should be removed from
+ Python's specification. Closes issue #89.
+
2015-03-14 -- 1.3.6
diff --git a/astroid/as_string.py b/astroid/as_string.py
index 065d7c95..bcbc6f8c 100644
--- a/astroid/as_string.py
+++ b/astroid/as_string.py
@@ -443,6 +443,9 @@ class AsStringVisitor(object):
def visit_frozenset(self, node):
return node.parent.accept(self)
+ def visit_super(self, node):
+ return node.parent.accept(self)
+
class AsStringVisitor3k(AsStringVisitor):
"""AsStringVisitor3k overwrites some AsStringVisitor methods"""
diff --git a/astroid/brain/builtin_inference.py b/astroid/brain/builtin_inference.py
index 5d48eee8..6aff555f 100644
--- a/astroid/brain/builtin_inference.py
+++ b/astroid/brain/builtin_inference.py
@@ -52,7 +52,7 @@ def _extend_str(class_node, rvalue):
def rstrip(self, chars=None):
return {rvalue}
def rjust(self, width, fillchar=None):
- return {rvalue}
+ return {rvalue}
def center(self, width, fillchar=None):
return {rvalue}
def ljust(self, width, fillchar=None):
@@ -245,7 +245,79 @@ def infer_dict(node, context=None):
empty.items = items
return empty
+
+def _node_class(node):
+ klass = node.frame()
+ while klass is not None and not isinstance(klass, nodes.Class):
+ if klass.parent is None:
+ klass = None
+ else:
+ klass = klass.parent.frame()
+ return klass
+
+
+def infer_super(node, context=None):
+ """Understand super calls.
+
+ There are some restrictions for what can be understood:
+
+ * unbounded super (one argument form) is not understood.
+
+ * if the super call is not inside a function (classmethod or method),
+ then the default inference will be used.
+
+ * if the super arguments can't be infered, the default inference
+ will be used.
+ """
+ if len(node.args) == 1:
+ # Ignore unbounded super.
+ raise UseInferenceDefault
+
+ scope = node.scope()
+ if not isinstance(scope, nodes.Function):
+ # Ignore non-method uses of super.
+ raise UseInferenceDefault
+ if scope.type not in ('classmethod', 'method'):
+ # Not interested in staticmethods.
+ raise UseInferenceDefault
+
+ cls = _node_class(scope)
+ if not len(node.args):
+ mro_pointer = cls
+ # In we are in a classmethod, the interpreter will fill
+ # automatically the class as the second argument, not an instance.
+ if scope.type == 'classmethod':
+ mro_type = cls
+ else:
+ mro_type = cls.instanciate_class()
+ else:
+ # TODO(cpopa): support flow control (multiple inference values).
+ try:
+ mro_pointer = next(node.args[0].infer(context=context))
+ except InferenceError:
+ raise UseInferenceDefault
+ try:
+ mro_type = next(node.args[1].infer(context=context))
+ except InferenceError:
+ raise UseInferenceDefault
+
+ if mro_pointer is YES or mro_type is YES:
+ # No way we could understand this.
+ raise UseInferenceDefault
+
+ super_obj = objects.Super(mro_pointer=mro_pointer,
+ mro_type=mro_type,
+ self_class=cls)
+ super_obj.parent = node
+ return iter([super_obj])
+
+
# Builtins inference
+MANAGER.register_transform(nodes.CallFunc,
+ inference_tip(infer_super),
+ lambda n: (isinstance(n.func, nodes.Name) and
+ n.func.name == 'super'))
+
register_builtin_transform(infer_tuple, 'tuple')
register_builtin_transform(infer_set, 'set')
register_builtin_transform(infer_list, 'list')
diff --git a/astroid/exceptions.py b/astroid/exceptions.py
index d1094cec..9e3a279f 100644
--- a/astroid/exceptions.py
+++ b/astroid/exceptions.py
@@ -42,6 +42,10 @@ class InconsistentMroError(MroError):
"""Error raised when a class's MRO is inconsistent."""
+class SuperError(ResolveError):
+ """Error raised when there is a problem with a super call."""
+
+
class NotFoundError(ResolveError):
"""raised when we are unable to resolve a name"""
diff --git a/astroid/objects.py b/astroid/objects.py
index ba60bba2..15991590 100644
--- a/astroid/objects.py
+++ b/astroid/objects.py
@@ -27,10 +27,16 @@ leads to an inferred FrozenSet:
"""
from logilab.common.decorators import cachedproperty
+import six
from astroid import MANAGER
-from astroid.bases import BUILTINS, NodeNG, Instance
+from astroid.bases import (
+ BUILTINS, NodeNG, Instance, _infer_stmts,
+ BoundMethod, UnboundMethod,
+)
+from astroid.exceptions import SuperError, NotFoundError, MroError
from astroid.node_classes import const_factory
+from astroid.scoped_nodes import Class, Function
from astroid.mixins import ParentAssignTypeMixin
@@ -56,3 +62,104 @@ class FrozenSet(NodeNG, Instance, ParentAssignTypeMixin):
def _proxied(self):
builtins = MANAGER.astroid_cache[BUILTINS]
return builtins.getattr('frozenset')[0]
+
+
+class Super(NodeNG):
+ """Proxy class over a super call.
+
+ This class offers almost the same behaviour as Python's super,
+ which is MRO lookups for retrieving attributes from the parents.
+ """
+
+ def __init__(self, mro_pointer, mro_type, self_class):
+ self.type = mro_type
+ self.mro_pointer = mro_pointer
+ self._class_based = False
+ self._self_class = self_class
+ self._model = {
+ '__thisclass__': self.mro_pointer,
+ '__self_class__': self._self_class,
+ '__self__': self.type,
+ '__class__': self._proxied,
+ }
+
+ def _infer(self, context=None):
+ yield self
+
+ def super_mro(self):
+ """Get the MRO which will be used to lookup attributes in this super."""
+ if not isinstance(self.mro_pointer, Class):
+ raise SuperError("The first super argument must be type.")
+
+ if isinstance(self.type, Class):
+ # `super(type, type)`, most likely in a class method.
+ self._class_based = True
+ mro_type = self.type
+ else:
+ mro_type = self.type._proxied
+
+ if not mro_type.newstyle:
+ raise SuperError("Unable to call super on old-style classes.")
+
+ mro = mro_type.mro()
+ if self.mro_pointer not in mro:
+ raise SuperError("super(type, obj): obj must be an instance "
+ "or subtype of type")
+
+ index = mro.index(self.mro_pointer)
+ return mro[index + 1:]
+
+ @cachedproperty
+ def _proxied(self):
+ builtins = MANAGER.astroid_cache[BUILTINS]
+ return builtins.getattr('super')[0]
+
+ def pytype(self):
+ return '%s.super' % BUILTINS
+
+ def display_type(self):
+ return 'Super of'
+
+ @property
+ def name(self):
+ """Get the name of the MRO pointer."""
+ return self.mro_pointer.name
+
+ def igetattr(self, name, context=None):
+ """Retrieve the inferred values of the given attribute name."""
+
+ local_name = self._model.get(name)
+ if local_name:
+ yield local_name
+ return
+
+ try:
+ mro = self.super_mro()
+ except (MroError, SuperError) as exc:
+ # Don't let invalid MROs or invalid super calls
+ # to leak out as is from this function.
+ six.raise_from(NotFoundError, exc)
+
+ found = False
+ for cls in mro:
+ if name not in cls.locals:
+ continue
+
+ found = True
+ for infered in _infer_stmts([cls[name]], context, frame=self):
+ if isinstance(infered, Function):
+ if self._class_based:
+ # The second argument to super is class, which
+ # means that we are returning unbound methods
+ # when accessing attributes.
+ yield UnboundMethod(infered)
+ else:
+ yield BoundMethod(infered, cls)
+ else:
+ yield infered
+
+ if not found:
+ raise NotFoundError(name)
+
+ def getattr(self, name, context=None):
+ return list(self.igetattr(name, context=context))
diff --git a/astroid/tests/unittest_objects.py b/astroid/tests/unittest_objects.py
index 007e142c..91296b3b 100644
--- a/astroid/tests/unittest_objects.py
+++ b/astroid/tests/unittest_objects.py
@@ -19,6 +19,7 @@
import unittest
from astroid import bases
+from astroid import exceptions
from astroid import nodes
from astroid import objects
from astroid import test_utils
@@ -45,5 +46,384 @@ class ObjectsTest(unittest.TestCase):
self.assertIsInstance(proxied, nodes.Class)
+class SuperTests(unittest.TestCase):
+
+ def test_inferring_super_outside_methods(self):
+ ast_nodes = test_utils.extract_node('''
+ class Module(object):
+ pass
+ class StaticMethod(object):
+ @staticmethod
+ def static():
+ # valid, but we don't bother with it.
+ return super(StaticMethod, StaticMethod) #@
+ # super outside methods aren't inferred
+ super(Module, Module) #@
+ # no argument super is not recognised outside methods as well.
+ super() #@
+ ''')
+ in_static = next(ast_nodes[0].value.infer())
+ self.assertIsInstance(in_static, bases.Instance)
+ self.assertEqual(in_static.qname(), "%s.super" % bases.BUILTINS)
+
+ module_level = next(ast_nodes[1].infer())
+ self.assertIsInstance(module_level, bases.Instance)
+ self.assertEqual(in_static.qname(), "%s.super" % bases.BUILTINS)
+
+ no_arguments = next(ast_nodes[2].infer())
+ self.assertIsInstance(no_arguments, bases.Instance)
+ self.assertEqual(no_arguments.qname(), "%s.super" % bases.BUILTINS)
+
+ def test_inferring_unbound_super_doesnt_work(self):
+ node = test_utils.extract_node('''
+ class Test(object):
+ def __init__(self):
+ super(Test) #@
+ ''')
+ unbounded = next(node.infer())
+ self.assertIsInstance(unbounded, bases.Instance)
+ self.assertEqual(unbounded.qname(), "%s.super" % bases.BUILTINS)
+
+ def test_use_default_inference_on_not_inferring_args(self):
+ ast_nodes = test_utils.extract_node('''
+ class Test(object):
+ def __init__(self):
+ super(Lala, self) #@
+ super(Test, lala) #@
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, bases.Instance)
+ self.assertEqual(first.qname(), "%s.super" % bases.BUILTINS)
+
+ second = next(ast_nodes[1].infer())
+ self.assertIsInstance(second, bases.Instance)
+ self.assertEqual(second.qname(), "%s.super" % bases.BUILTINS)
+
+ @test_utils.require_version(maxver='3.0')
+ def test_super_on_old_style_class(self):
+ # super doesn't work on old style class, but leave
+ # that as an error for pylint. We'll infer Super objects,
+ # but every call will result in a failure at some point.
+ node = test_utils.extract_node('''
+ class OldStyle:
+ def __init__(self):
+ super(OldStyle, self) #@
+ ''')
+ old = next(node.infer())
+ self.assertIsInstance(old, objects.Super)
+ self.assertIsInstance(old.mro_pointer, nodes.Class)
+ self.assertEqual(old.mro_pointer.name, 'OldStyle')
+ with self.assertRaises(exceptions.SuperError) as cm:
+ old.super_mro()
+ self.assertEqual(str(cm.exception),
+ "Unable to call super on old-style classes.")
+
+ @test_utils.require_version(minver='3.0')
+ def test_no_arguments_super(self):
+ ast_nodes = test_utils.extract_node('''
+ class First(object): pass
+ class Second(First):
+ def test(self):
+ super() #@
+ @classmethod
+ def test_classmethod(cls):
+ super() #@
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, objects.Super)
+ self.assertIsInstance(first.type, bases.Instance)
+ self.assertEqual(first.type.name, 'Second')
+ self.assertIsInstance(first.mro_pointer, nodes.Class)
+ self.assertEqual(first.mro_pointer.name, 'Second')
+
+ second = next(ast_nodes[1].infer())
+ self.assertIsInstance(second, objects.Super)
+ self.assertIsInstance(second.type, nodes.Class)
+ self.assertEqual(second.type.name, 'Second')
+ self.assertIsInstance(second.mro_pointer, nodes.Class)
+ self.assertEqual(second.mro_pointer.name, 'Second')
+
+ def test_super_simple_cases(self):
+ ast_nodes = test_utils.extract_node('''
+ class First(object): pass
+ class Second(First): pass
+ class Third(First):
+ def test(self):
+ super(Third, self) #@
+ super(Second, self) #@
+
+ # mro position and the type
+ super(Third, Third) #@
+ super(Third, Second) #@
+ super(Fourth, Fourth) #@
+
+ class Fourth(Third):
+ pass
+ ''')
+
+ # .type is the object which provides the mro.
+ # .mro_pointer is the position in the mro from where
+ # the lookup should be done.
+
+ # super(Third, self)
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, objects.Super)
+ self.assertIsInstance(first.type, bases.Instance)
+ self.assertEqual(first.type.name, 'Third')
+ self.assertIsInstance(first.mro_pointer, nodes.Class)
+ self.assertEqual(first.mro_pointer.name, 'Third')
+
+ # super(Second, self)
+ second = next(ast_nodes[1].infer())
+ self.assertIsInstance(second, objects.Super)
+ self.assertIsInstance(second.type, bases.Instance)
+ self.assertEqual(second.type.name, 'Third')
+ self.assertIsInstance(first.mro_pointer, nodes.Class)
+ self.assertEqual(second.mro_pointer.name, 'Second')
+
+ # super(Third, Third)
+ third = next(ast_nodes[2].infer())
+ self.assertIsInstance(third, objects.Super)
+ self.assertIsInstance(third.type, nodes.Class)
+ self.assertEqual(third.type.name, 'Third')
+ self.assertIsInstance(third.mro_pointer, nodes.Class)
+ self.assertEqual(third.mro_pointer.name, 'Third')
+
+ # super(Third, second)
+ fourth = next(ast_nodes[3].infer())
+ self.assertIsInstance(fourth, objects.Super)
+ self.assertIsInstance(fourth.type, nodes.Class)
+ self.assertEqual(fourth.type.name, 'Second')
+ self.assertIsInstance(fourth.mro_pointer, nodes.Class)
+ self.assertEqual(fourth.mro_pointer.name, 'Third')
+
+ # Super(Fourth, Fourth)
+ fifth = next(ast_nodes[4].infer())
+ self.assertIsInstance(fifth, objects.Super)
+ self.assertIsInstance(fifth.type, nodes.Class)
+ self.assertEqual(fifth.type.name, 'Fourth')
+ self.assertIsInstance(fifth.mro_pointer, nodes.Class)
+ self.assertEqual(fifth.mro_pointer.name, 'Fourth')
+
+ def test_super_infer(self):
+ node = test_utils.extract_node('''
+ class Super(object):
+ def __init__(self):
+ super(Super, self) #@
+ ''')
+ inferred = next(node.infer())
+ self.assertIsInstance(inferred, objects.Super)
+ reinferred = next(inferred.infer())
+ self.assertIsInstance(reinferred, objects.Super)
+ self.assertIs(inferred, reinferred)
+
+ def test_inferring_invalid_supers(self):
+ ast_nodes = test_utils.extract_node('''
+ class Super(object):
+ def __init__(self):
+ # MRO pointer is not a type
+ super(1, self) #@
+ # MRO type is not a subtype
+ super(Super, 1) #@
+ # self is not a subtype of Bupper
+ super(Bupper, self) #@
+ class Bupper(Super):
+ pass
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, objects.Super)
+ with self.assertRaises(exceptions.SuperError) as cm:
+ first.super_mro()
+ self.assertEqual(str(cm.exception), "The first super argument must be type.")
+
+ for node in ast_nodes[1:]:
+ inferred = next(node.infer())
+ self.assertIsInstance(inferred, objects.Super, node)
+ with self.assertRaises(exceptions.SuperError) as cm:
+ inferred.super_mro()
+ self.assertEqual(str(cm.exception),
+ "super(type, obj): obj must be an instance "
+ "or subtype of type", node)
+
+ def test_proxied(self):
+ node = test_utils.extract_node('''
+ class Super(object):
+ def __init__(self):
+ super(Super, self) #@
+ ''')
+ infered = next(node.infer())
+ proxied = infered._proxied
+ self.assertEqual(proxied.qname(), "%s.super" % bases.BUILTINS)
+ self.assertIsInstance(proxied, nodes.Class)
+
+ def test_super_getattr_single_inheritance(self):
+ ast_nodes = test_utils.extract_node('''
+ class First(object):
+ def test(self): pass
+ class Second(First):
+ def test2(self): pass
+ class Third(Second):
+ test3 = 42
+ def __init__(self):
+ super(Third, self).test2 #@
+ super(Third, self).test #@
+ # test3 is local, no MRO lookup is done.
+ super(Third, self).test3 #@
+ super(Third, self) #@
+
+ # Unbounds.
+ super(Third, Third).test2 #@
+ super(Third, Third).test #@
+
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, bases.BoundMethod)
+ self.assertEqual(first.bound.name, 'Second')
+
+ second = next(ast_nodes[1].infer())
+ self.assertIsInstance(second, bases.BoundMethod)
+ self.assertEqual(second.bound.name, 'First')
+
+ with self.assertRaises(exceptions.InferenceError):
+ next(ast_nodes[2].infer())
+ fourth = next(ast_nodes[3].infer())
+ with self.assertRaises(exceptions.NotFoundError):
+ fourth.getattr('test3')
+ with self.assertRaises(exceptions.NotFoundError):
+ next(fourth.igetattr('test3'))
+
+ first_unbound = next(ast_nodes[4].infer())
+ self.assertIsInstance(first_unbound, bases.UnboundMethod)
+ self.assertEqual(first_unbound._proxied.name, 'test2')
+ self.assertEqual(first_unbound._proxied.parent.name, 'Second')
+
+ second_unbound = next(ast_nodes[5].infer())
+ self.assertIsInstance(second_unbound, bases.UnboundMethod)
+ self.assertEqual(second_unbound._proxied.name, 'test')
+ self.assertEqual(second_unbound._proxied.parent.name, 'First')
+
+ def test_super_invalid_mro(self):
+ node = test_utils.extract_node('''
+ class A(object):
+ test = 42
+ class Super(A, A):
+ def __init__(self):
+ super(Super, self) #@
+ ''')
+ inferred = next(node.infer())
+ with self.assertRaises(exceptions.NotFoundError):
+ next(inferred.getattr('test'))
+
+ def test_super_complex_mro(self):
+ ast_nodes = test_utils.extract_node('''
+ class A(object):
+ def spam(self): return "A"
+ def foo(self): return "A"
+ class B(A):
+ def boo(self): return "B"
+ def spam(self): return "B"
+ class C(A):
+ def boo(self): return "C"
+ class E(C, B):
+ def __init__(self):
+ super(E, self).boo #@
+ super(C, self).boo #@
+ super(E, self).spam #@
+ super(E, self).foo #@
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, bases.BoundMethod)
+ self.assertEqual(first.bound.name, 'C')
+ second = next(ast_nodes[1].infer())
+ self.assertIsInstance(second, bases.BoundMethod)
+ self.assertEqual(second.bound.name, 'B')
+ third = next(ast_nodes[2].infer())
+ self.assertIsInstance(third, bases.BoundMethod)
+ self.assertEqual(third.bound.name, 'B')
+ fourth = next(ast_nodes[3].infer())
+ self.assertEqual(fourth.bound.name, 'A')
+
+ def test_super_data_model(self):
+ ast_nodes = test_utils.extract_node('''
+ class X(object): pass
+ class A(X):
+ def __init__(self):
+ super(A, self) #@
+ super(A, A) #@
+ super(X, A) #@
+ ''')
+ first = next(ast_nodes[0].infer())
+ thisclass = first.getattr('__thisclass__')[0]
+ self.assertIsInstance(thisclass, nodes.Class)
+ self.assertEqual(thisclass.name, 'A')
+ selfclass = first.getattr('__self_class__')[0]
+ self.assertIsInstance(selfclass, nodes.Class)
+ self.assertEqual(selfclass.name, 'A')
+ self_ = first.getattr('__self__')[0]
+ self.assertIsInstance(self_, bases.Instance)
+ self.assertEqual(self_.name, 'A')
+ cls = first.getattr('__class__')[0]
+ self.assertEqual(cls, first._proxied)
+
+ second = next(ast_nodes[1].infer())
+ thisclass = second.getattr('__thisclass__')[0]
+ self.assertEqual(thisclass.name, 'A')
+ self_ = second.getattr('__self__')[0]
+ self.assertIsInstance(self_, nodes.Class)
+ self.assertEqual(self_.name, 'A')
+
+ third = next(ast_nodes[2].infer())
+ thisclass = third.getattr('__thisclass__')[0]
+ self.assertEqual(thisclass.name, 'X')
+ selfclass = third.getattr('__self_class__')[0]
+ self.assertEqual(selfclass.name, 'A')
+
+ def assertEqualMro(self, klass, expected_mro):
+ self.assertEqual(
+ [member.name for member in klass.super_mro()],
+ expected_mro)
+
+ def test_super_mro(self):
+ ast_nodes = test_utils.extract_node('''
+ class A(object): pass
+ class B(A): pass
+ class C(A): pass
+ class E(C, B):
+ def __init__(self):
+ super(E, self) #@
+ super(C, self) #@
+ super(B, self) #@
+
+ super(B, 1) #@
+ super(1, B) #@
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertEqualMro(first, ['C', 'B', 'A', 'object'])
+ second = next(ast_nodes[1].infer())
+ self.assertEqualMro(second, ['B', 'A', 'object'])
+ third = next(ast_nodes[2].infer())
+ self.assertEqualMro(third, ['A', 'object'])
+
+ fourth = next(ast_nodes[3].infer())
+ with self.assertRaises(exceptions.SuperError):
+ fourth.super_mro()
+ fifth = next(ast_nodes[4].infer())
+ with self.assertRaises(exceptions.SuperError):
+ fifth.super_mro()
+
+ def test_super_yes_objects(self):
+ ast_nodes = test_utils.extract_node('''
+ from collections import Missing
+ class A(object):
+ def __init__(self):
+ super(Missing, self) #@
+ super(A, Missing) #@
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, bases.Instance)
+ second = next(ast_nodes[1].infer())
+ self.assertIsInstance(second, bases.Instance)
+
+
if __name__ == '__main__':
unittest.main()