summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudiu Popa <cpopa@cloudbasesolutions.com>2015-08-21 00:54:39 +0300
committerClaudiu Popa <cpopa@cloudbasesolutions.com>2015-08-21 00:54:39 +0300
commite5633fa2ad81df237b5a4b64a09904ae7ae3895e (patch)
tree36e6e6a3fb2d9f4424f985746a6aef827f989ede
parent267ed99da4f0482d4a964fc98dc6ad14d4bf77cb (diff)
downloadastroid-git-e5633fa2ad81df237b5a4b64a09904ae7ae3895e.tar.gz
Understand slices of tuples, lists, strings and instances with support for slices.
Closes issue #137.
-rw-r--r--ChangeLog5
-rw-r--r--astroid/inference.py88
-rw-r--r--astroid/node_classes.py15
-rw-r--r--astroid/objects.py1
-rw-r--r--astroid/tests/unittest_inference.py115
5 files changed, 202 insertions, 22 deletions
diff --git a/ChangeLog b/ChangeLog
index dc5445e3..2794e548 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -278,6 +278,11 @@ Change log for the astroid package (used to be astng)
of a class, which will also do an evaluation of what declared_metaclass
returns.
+ * Understand slices of tuples, lists, strings and instances with support
+ for slices.
+
+ Closes issue #137.
+
2015-03-14 -- 1.3.6
diff --git a/astroid/inference.py b/astroid/inference.py
index 7235c13e..b3659a0e 100644
--- a/astroid/inference.py
+++ b/astroid/inference.py
@@ -22,6 +22,8 @@ import functools
import itertools
import operator
+import six
+
from astroid import bases
from astroid import context as contextmod
from astroid import exceptions
@@ -115,6 +117,13 @@ nodes.CallFunc._infer = infer_callfunc
@bases.path_wrapper
+def infer_slice(self, context=None):
+ yield self
+
+nodes.Slice._infer = infer_slice
+
+
+@bases.path_wrapper
def infer_import(self, context=None, asname=True):
"""infer an Import node: return the imported module/object"""
name = context.lookupname
@@ -186,8 +195,43 @@ def infer_global(self, context=None):
nodes.Global._infer = infer_global
+_SLICE_SENTINEL = object()
+
+def _slice_value(index, context=None):
+ """Get the value of the given slice index."""
+ if isinstance(index, nodes.Const):
+ if isinstance(index.value, (int, type(None))):
+ return index.value
+ elif index is None:
+ return None
+ else:
+ # Try to infer what the index actually is.
+ # Since we can't return all the possible values,
+ # we'll stop at the first possible value.
+ try:
+ inferred = next(index.infer(context=context))
+ except exceptions.InferenceError:
+ pass
+ else:
+ if isinstance(inferred, nodes.Const):
+ if isinstance(inferred.value, (int, type(None))):
+ return inferred.value
+
+ # Use a sentinel, because None can be a valid
+ # value that this function can return,
+ # as it is the case for unspecified bounds.
+ return _SLICE_SENTINEL
+
+
def infer_subscript(self, context=None):
- """infer simple subscription such as [1,2,3][0] or (1,2,3)[-1]"""
+ """Inference for subscripts
+
+ We're understanding if the index is a Const
+ or a slice, passing the result of inference
+ to the value's `getitem` method, which should
+ handle each supported index type accordingly.
+ """
+
value = next(self.value.infer(context))
if value is util.YES:
yield util.YES
@@ -198,24 +242,35 @@ def infer_subscript(self, context=None):
yield util.YES
return
+ index_value = _SLICE_SENTINEL
if isinstance(index, nodes.Const):
- try:
- assigned = value.getitem(index.value, context)
- except AttributeError:
- raise exceptions.InferenceError()
- except (IndexError, TypeError):
- yield util.YES
- return
-
- # Prevent inferring if the infered subscript
- # is the same as the original subscripted object.
- if self is assigned or assigned is util.YES:
- yield util.YES
- return
- for infered in assigned.infer(context):
- yield infered
+ index_value = index.value
+ elif isinstance(index, nodes.Slice):
+ # Infer slices from the original object.
+ lower = _slice_value(index.lower, context)
+ upper = _slice_value(index.upper, context)
+ step = _slice_value(index.step, context)
+ if all(elem is not _SLICE_SENTINEL for elem in (lower, upper, step)):
+ index_value = slice(lower, upper, step)
else:
raise exceptions.InferenceError()
+
+ if index_value is _SLICE_SENTINEL:
+ raise exceptions.InferenceError
+
+ try:
+ assigned = value.getitem(index_value, context)
+ except (IndexError, TypeError, AttributeError) as exc:
+ six.raise_from(exceptions.InferenceError, exc)
+
+ # Prevent inferring if the infered subscript
+ # is the same as the original subscripted object.
+ if self is assigned or assigned is util.YES:
+ yield util.YES
+ return
+ for infered in assigned.infer(context):
+ yield infered
+
nodes.Subscript._infer = bases.path_wrapper(infer_subscript)
nodes.Subscript.infer_lhs = bases.raise_if_nothing_infered(infer_subscript)
@@ -638,7 +693,6 @@ nodes.Index._infer = infer_index
def instance_getitem(self, index, context=None):
# Rewrap index to Const for this case
index = nodes.Const(index)
-
if context:
new_context = context.clone()
else:
diff --git a/astroid/node_classes.py b/astroid/node_classes.py
index 944598a6..5ea755e2 100644
--- a/astroid/node_classes.py
+++ b/astroid/node_classes.py
@@ -109,6 +109,17 @@ def are_exclusive(stmt1, stmt2, exceptions=None):
return False
+def _container_getitem(instance, elts, index):
+ """Get a slice or an item, using the given *index*, for the given sequence."""
+ if isinstance(index, slice):
+ new_cls = instance.__class__()
+ new_cls.elts = elts[index]
+ new_cls.parent = instance.parent
+ return new_cls
+ else:
+ return elts[index]
+
+
@six.add_metaclass(abc.ABCMeta)
class _BaseContainer(mixins.ParentAssignTypeMixin,
bases.NodeNG,
@@ -816,7 +827,7 @@ class List(_BaseContainer):
return '%s.list' % BUILTINS
def getitem(self, index, context=None):
- return self.elts[index]
+ return _container_getitem(self, self.elts, index)
class Nonlocal(bases.Statement):
@@ -940,7 +951,7 @@ class Tuple(_BaseContainer):
return '%s.tuple' % BUILTINS
def getitem(self, index, context=None):
- return self.elts[index]
+ return _container_getitem(self, self.elts, index)
class UnaryOp(bases.NodeNG):
diff --git a/astroid/objects.py b/astroid/objects.py
index aa3848a3..c7d85d82 100644
--- a/astroid/objects.py
+++ b/astroid/objects.py
@@ -32,7 +32,6 @@ from astroid import bases
from astroid import decorators
from astroid import exceptions
from astroid import MANAGER
-from astroid import mixins
from astroid import node_classes
from astroid import scoped_nodes
diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py
index 20666f58..02048947 100644
--- a/astroid/tests/unittest_inference.py
+++ b/astroid/tests/unittest_inference.py
@@ -699,13 +699,17 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
NoGetitem()[4] #@
InvalidGetitem()[5] #@
InvalidGetitem2()[10] #@
- [1, 2, 3][None] #@
- 'lala'['bala'] #@
''')
for node in ast_nodes[:3]:
self.assertRaises(InferenceError, next, node.infer())
for node in ast_nodes[3:]:
self.assertEqual(next(node.infer()), util.YES)
+ ast_nodes = test_utils.extract_node('''
+ [1, 2, 3][None] #@
+ 'lala'['bala'] #@
+ ''')
+ for node in ast_nodes:
+ self.assertRaises(InferenceError, next, node.infer())
def test_bytes_subscript(self):
node = test_utils.extract_node('''b'a'[0]''')
@@ -2692,6 +2696,113 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
self.assertIsInstance(inferred, nodes.Const)
self.assertEqual(inferred.value, 42)
+ def _slicing_test_helper(self, pairs, cls, get_elts):
+ for code, expected in pairs:
+ ast_node = test_utils.extract_node(code)
+ inferred = next(ast_node.infer())
+ self.assertIsInstance(inferred, cls)
+ self.assertEqual(get_elts(inferred), expected,
+ ast_node.as_string())
+
+ def test_slicing_list(self):
+ pairs = (
+ ("[1, 2, 3][:] #@", [1, 2, 3]),
+ ("[1, 2, 3][0:] #@", [1, 2, 3]),
+ ("[1, 2, 3][None:] #@", [1, 2, 3]),
+ ("[1, 2, 3][None:None] #@", [1, 2, 3]),
+ ("[1, 2, 3][0:-1] #@", [1, 2]),
+ ("[1, 2, 3][0:2] #@", [1, 2]),
+ ("[1, 2, 3][0:2:None] #@", [1, 2]),
+ ("[1, 2, 3][::] #@", [1, 2, 3]),
+ ("[1, 2, 3][::2] #@", [1, 3]),
+ ("[1, 2, 3][::-1] #@", [3, 2, 1]),
+ ("[1, 2, 3][0:2:2] #@", [1]),
+ ("[1, 2, 3, 4, 5, 6][0:4-1:2+0] #@", [1, 3]),
+ )
+ self._slicing_test_helper(
+ pairs, nodes.List,
+ lambda inferred: [elt.value for elt in inferred.elts])
+
+ def test_slicing_tuple(self):
+ pairs = (
+ ("(1, 2, 3)[:] #@", [1, 2, 3]),
+ ("(1, 2, 3)[0:] #@", [1, 2, 3]),
+ ("(1, 2, 3)[None:] #@", [1, 2, 3]),
+ ("(1, 2, 3)[None:None] #@", [1, 2, 3]),
+ ("(1, 2, 3)[0:-1] #@", [1, 2]),
+ ("(1, 2, 3)[0:2] #@", [1, 2]),
+ ("(1, 2, 3)[0:2:None] #@", [1, 2]),
+ ("(1, 2, 3)[::] #@", [1, 2, 3]),
+ ("(1, 2, 3)[::2] #@", [1, 3]),
+ ("(1, 2, 3)[::-1] #@", [3, 2, 1]),
+ ("(1, 2, 3)[0:2:2] #@", [1]),
+ ("(1, 2, 3, 4, 5, 6)[0:4-1:2+0] #@", [1, 3]),
+ )
+ self._slicing_test_helper(
+ pairs, nodes.Tuple,
+ lambda inferred: [elt.value for elt in inferred.elts])
+
+ def test_slicing_str(self):
+ pairs = (
+ ("'123'[:] #@", "123"),
+ ("'123'[0:] #@", "123"),
+ ("'123'[None:] #@", "123"),
+ ("'123'[None:None] #@", "123"),
+ ("'123'[0:-1] #@", "12"),
+ ("'123'[0:2] #@", "12"),
+ ("'123'[0:2:None] #@", "12"),
+ ("'123'[::] #@", "123"),
+ ("'123'[::2] #@", "13"),
+ ("'123'[::-1] #@", "321"),
+ ("'123'[0:2:2] #@", "1"),
+ ("'123456'[0:4-1:2+0] #@", "13"),
+ )
+ self._slicing_test_helper(
+ pairs, nodes.Const, lambda inferred: inferred.value)
+
+ def test_invalid_slicing_primaries(self):
+ examples = [
+ "(lambda x: x)[1:2]",
+ "1[2]",
+ "enumerate[2]",
+ "(1, 2, 3)[a:]",
+ "(1, 2, 3)[object:object]",
+ "(1, 2, 3)[1:object]",
+ ]
+ for code in examples:
+ node = test_utils.extract_node(code)
+ self.assertRaises(InferenceError, next, node.infer())
+
+ def test_instance_slicing(self):
+ ast_nodes = test_utils.extract_node('''
+ class A(object):
+ def __getitem__(self, index):
+ return [1, 2, 3, 4, 5][index]
+ A()[1:] #@
+ A()[:2] #@
+ A()[1:4] #@
+ ''')
+ expected_values = [
+ [2, 3, 4, 5],
+ [1, 2],
+ [2, 3, 4],
+ ]
+ for expected, node in zip(expected_values, ast_nodes):
+ inferred = next(node.infer())
+ self.assertIsInstance(inferred, nodes.List)
+ self.assertEqual([elt.value for elt in inferred.elts], expected)
+
+ def test_instance_slicing_fails(self):
+ ast_nodes = test_utils.extract_node('''
+ class A(object):
+ def __getitem__(self, index):
+ return 1[index]
+ A()[4:5] #@
+ A()[2:] #@
+ ''')
+ for node in ast_nodes:
+ self.assertEqual(next(node.infer()), util.YES)
+
class GetattrTest(unittest.TestCase):