summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudiu Popa <cpopa@cloudbasesolutions.com>2015-06-26 22:50:55 +0300
committerClaudiu Popa <cpopa@cloudbasesolutions.com>2015-06-26 22:50:55 +0300
commit024109e7c955265bd6e1841063f30192ab5de0d9 (patch)
treeb7eeda0d1ec1e2be4908b1b7989d96f255033d46
parentd683e22f663d568b034214810aa93b038ba6dee2 (diff)
downloadastroid-024109e7c955265bd6e1841063f30192ab5de0d9.tar.gz
Add helpers.is_supertype and helpers.is_subtype, two functions for checking if an object is a super/sub type of another.
-rw-r--r--ChangeLog3
-rw-r--r--astroid/helpers.py93
-rw-r--r--astroid/tests/unittest_helpers.py123
-rw-r--r--pylintrc2
4 files changed, 187 insertions, 34 deletions
diff --git a/ChangeLog b/ChangeLog
index af8eacc..c14cce6 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -188,6 +188,9 @@ Change log for the astroid package (used to be astng)
This uses the recently added *astroid.helpers.object_type* in order to
retrieve the Python type of the first argument of the call.
+ * Add helpers.is_supertype and helpers.is_subtype, two functions for
+ checking if an object is a super/sub type of another.
+
2015-03-14 -- 1.3.6
diff --git a/astroid/helpers.py b/astroid/helpers.py
index 5996c7c..fa941d9 100644
--- a/astroid/helpers.py
+++ b/astroid/helpers.py
@@ -1,19 +1,19 @@
-# copyright 2003-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
-# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
-#
-# This file is part of astroid.
-#
-# astroid is free software: you can redistribute it and/or modify it
-# under the terms of the GNU Lesser General Public License as published by the
-# Free Software Foundation, either version 2.1 of the License, or (at your
-# option) any later version.
-#
-# astroid is distributed in the hope that it will be useful, but
-# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
-# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
-# for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License along
+# copyright 2003-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of astroid.
+#
+# astroid is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License as published by the
+# Free Software Foundation, either version 2.1 of the License, or (at your
+# option) any later version.
+#
+# astroid is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
+# for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
# with astroid. If not, see <http://www.gnu.org/licenses/>.
"""
@@ -93,3 +93,64 @@ def object_type(node, context=None):
if len(types) > 1:
return bases.YES
return list(types)[0]
+
+
+def safe_infer(node, context=None):
+ """Return the inferred value for the given node.
+
+ Return None if inference failed or if there is some ambiguity (more than
+ one node has been inferred).
+ """
+ try:
+ inferit = node.infer(context=context)
+ value = next(inferit)
+ except exceptions.InferenceError:
+ return
+ try:
+ next(inferit)
+ return # None if there is ambiguity on the inferred node
+ except exceptions.InferenceError:
+ return # there is some kind of ambiguity
+ except StopIteration:
+ return value
+
+
+def has_known_bases(klass, context=None):
+ """Return true if all base classes of a class could be inferred."""
+ try:
+ return klass._all_bases_known
+ except AttributeError:
+ pass
+ for base in klass.bases:
+ result = safe_infer(base, context=context)
+ # TODO: check for A->B->A->B pattern in class structure too?
+ if (not isinstance(result, scoped_nodes.Class) or
+ result is klass or
+ not has_known_bases(result, context=context)):
+ klass._all_bases_known = False
+ return False
+ klass._all_bases_known = True
+ return True
+
+
+def _type_check(type1, type2):
+ if not all(map(has_known_bases, (type1, type2))):
+ return bases.YES
+
+ if not all([type1.newstyle, type2.newstyle]):
+ return False
+ try:
+ return type1 in type2.mro()[:-1]
+ except exceptions.MroError:
+ # The MRO is invalid.
+ return bases.YES
+
+
+def is_subtype(type1, type2):
+ """Check if *type1* is a subtype of *typ2*."""
+ return _type_check(type2, type1)
+
+
+def is_supertype(type1, type2):
+ """Check if *type2* is a supertype of *type1*."""
+ return _type_check(type1, type2)
diff --git a/astroid/tests/unittest_helpers.py b/astroid/tests/unittest_helpers.py
index 7645420..9ba3e6b 100644
--- a/astroid/tests/unittest_helpers.py
+++ b/astroid/tests/unittest_helpers.py
@@ -1,19 +1,19 @@
-# copyright 2003-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
-# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
-#
-# This file is part of astroid.
-#
-# astroid is free software: you can redistribute it and/or modify it
-# under the terms of the GNU Lesser General Public License as published by the
-# Free Software Foundation, either version 2.1 of the License, or (at your
-# option) any later version.
-#
-# astroid is distributed in the hope that it will be useful, but
-# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
-# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
-# for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License along
+# copyright 2003-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of astroid.
+#
+# astroid is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License as published by the
+# Free Software Foundation, either version 2.1 of the License, or (at your
+# option) any later version.
+#
+# astroid is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
+# for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
# with astroid. If not, see <http://www.gnu.org/licenses/>.
import unittest
@@ -155,7 +155,7 @@ class TestHelpers(unittest.TestCase):
''')
self.assertEqual(helpers.object_type(node), bases.YES)
- def test_too_many_types(self):
+ def test_object_type_too_many_types(self):
node = test_utils.extract_node('''
from unknown import Unknown
def test(x):
@@ -167,6 +167,95 @@ class TestHelpers(unittest.TestCase):
''')
self.assertEqual(helpers.object_type(node), bases.YES)
+ def test_is_subtype(self):
+ ast_nodes = test_utils.extract_node('''
+ class int_subclass(int):
+ pass
+ class A(object): pass #@
+ class B(A): pass #@
+ class C(A): pass #@
+ int_subclass() #@
+ ''')
+ cls_a = ast_nodes[0]
+ cls_b = ast_nodes[1]
+ cls_c = ast_nodes[2]
+ int_subclass = ast_nodes[3]
+ int_subclass = helpers.object_type(next(int_subclass.infer()))
+ base_int = self._extract('int')
+ self.assertTrue(helpers.is_subtype(int_subclass, base_int))
+ self.assertTrue(helpers.is_supertype(base_int, int_subclass))
+
+ self.assertTrue(helpers.is_supertype(cls_a, cls_b))
+ self.assertTrue(helpers.is_supertype(cls_a, cls_c))
+ self.assertTrue(helpers.is_subtype(cls_b, cls_a))
+ self.assertTrue(helpers.is_subtype(cls_c, cls_a))
+ self.assertFalse(helpers.is_subtype(cls_a, cls_b))
+ self.assertFalse(helpers.is_subtype(cls_a, cls_b))
+
+ @test_utils.require_version(maxver='3.0')
+ def test_is_subtype_supertype_old_style_classes(self):
+ cls_a, cls_b = test_utils.extract_node('''
+ class A: #@
+ pass
+ class B(A): #@
+ pass
+ ''')
+ self.assertFalse(helpers.is_subtype(cls_a, cls_b))
+ self.assertFalse(helpers.is_subtype(cls_b, cls_a))
+ self.assertFalse(helpers.is_supertype(cls_a, cls_b))
+ self.assertFalse(helpers.is_supertype(cls_b, cls_a))
+
+ def test_is_subtype_supertype_mro_error(self):
+ cls_e, cls_f = test_utils.extract_node('''
+ class A(object): pass
+ class B(A): pass
+ class C(A): pass
+ class D(B, C): pass
+ class E(C, B): pass #@
+ class F(D, E): pass #@
+ ''')
+ self.assertFalse(helpers.is_subtype(cls_e, cls_f))
+ self.assertEqual(helpers.is_subtype(cls_f, cls_e), bases.YES)
+ self.assertEqual(helpers.is_supertype(cls_e, cls_f), bases.YES)
+ self.assertFalse(helpers.is_supertype(cls_f, cls_e))
+
+ def test_is_subtype_supertype_unknown_bases(self):
+ cls_a, cls_b = test_utils.extract_node('''
+ from unknown import Unknown
+ class A(Unknown): pass #@
+ class B(A): pass #@
+ ''')
+ self.assertTrue(helpers.is_subtype(cls_b, cls_a))
+ self.assertTrue(helpers.is_supertype(cls_a, cls_b))
+
+ def test_is_subtype_supertype_unrelated_classes(self):
+ cls_a, cls_b = test_utils.extract_node('''
+ class A(object): pass #@
+ class B(object): pass #@
+ ''')
+ self.assertFalse(helpers.is_subtype(cls_a, cls_b))
+ self.assertFalse(helpers.is_subtype(cls_b, cls_a))
+ self.assertFalse(helpers.is_supertype(cls_a, cls_b))
+ self.assertFalse(helpers.is_supertype(cls_b, cls_a))
+
+ def test_is_subtype_supertype_classes_no_type_ancestor(self):
+ cls_a = test_utils.extract_node('''
+ class A(object): #@
+ pass
+ ''')
+ builtin_type = self._extract('type')
+ self.assertFalse(helpers.is_supertype(builtin_type, cls_a))
+ self.assertFalse(helpers.is_subtype(cls_a, builtin_type))
+
+ def test_is_subtype_supertype_classes_metaclasses(self):
+ cls_a = test_utils.extract_node('''
+ class A(type): #@
+ pass
+ ''')
+ builtin_type = self._extract('type')
+ self.assertTrue(helpers.is_supertype(builtin_type, cls_a))
+ self.assertTrue(helpers.is_subtype(cls_a, builtin_type))
+
if __name__ == '__main__':
unittest.main()
diff --git a/pylintrc b/pylintrc
index ac9c835..f3ad435 100644
--- a/pylintrc
+++ b/pylintrc
@@ -98,7 +98,7 @@ disable=invalid-name,protected-access,no-self-use,unused-argument,
too-many-return-statements,redefined-outer-name,undefined-variable,
too-many-locals,method-hidden,duplicate-code,attribute-defined-outside-init,
fixme,missing-docstring,too-many-lines,too-many-statements,undefined-loop-variable,
- unpacking-non-sequence,import-error,no-name-in-module
+ unpacking-non-sequence,import-error,no-name-in-module,bad-builtin
[BASIC]