diff options
author | Claudiu Popa <cpopa@cloudbasesolutions.com> | 2015-08-21 00:54:39 +0300 |
---|---|---|
committer | Claudiu Popa <cpopa@cloudbasesolutions.com> | 2015-08-21 00:54:39 +0300 |
commit | e5633fa2ad81df237b5a4b64a09904ae7ae3895e (patch) | |
tree | 36e6e6a3fb2d9f4424f985746a6aef827f989ede | |
parent | 267ed99da4f0482d4a964fc98dc6ad14d4bf77cb (diff) | |
download | astroid-git-e5633fa2ad81df237b5a4b64a09904ae7ae3895e.tar.gz |
Understand slices of tuples, lists, strings and instances with support for slices.
Closes issue #137.
-rw-r--r-- | ChangeLog | 5 | ||||
-rw-r--r-- | astroid/inference.py | 88 | ||||
-rw-r--r-- | astroid/node_classes.py | 15 | ||||
-rw-r--r-- | astroid/objects.py | 1 | ||||
-rw-r--r-- | astroid/tests/unittest_inference.py | 115 |
5 files changed, 202 insertions, 22 deletions
@@ -278,6 +278,11 @@ Change log for the astroid package (used to be astng) of a class, which will also do an evaluation of what declared_metaclass returns. + * Understand slices of tuples, lists, strings and instances with support + for slices. + + Closes issue #137. + 2015-03-14 -- 1.3.6 diff --git a/astroid/inference.py b/astroid/inference.py index 7235c13e..b3659a0e 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -22,6 +22,8 @@ import functools import itertools import operator +import six + from astroid import bases from astroid import context as contextmod from astroid import exceptions @@ -115,6 +117,13 @@ nodes.CallFunc._infer = infer_callfunc @bases.path_wrapper +def infer_slice(self, context=None): + yield self + +nodes.Slice._infer = infer_slice + + +@bases.path_wrapper def infer_import(self, context=None, asname=True): """infer an Import node: return the imported module/object""" name = context.lookupname @@ -186,8 +195,43 @@ def infer_global(self, context=None): nodes.Global._infer = infer_global +_SLICE_SENTINEL = object() + +def _slice_value(index, context=None): + """Get the value of the given slice index.""" + if isinstance(index, nodes.Const): + if isinstance(index.value, (int, type(None))): + return index.value + elif index is None: + return None + else: + # Try to infer what the index actually is. + # Since we can't return all the possible values, + # we'll stop at the first possible value. + try: + inferred = next(index.infer(context=context)) + except exceptions.InferenceError: + pass + else: + if isinstance(inferred, nodes.Const): + if isinstance(inferred.value, (int, type(None))): + return inferred.value + + # Use a sentinel, because None can be a valid + # value that this function can return, + # as it is the case for unspecified bounds. + return _SLICE_SENTINEL + + def infer_subscript(self, context=None): - """infer simple subscription such as [1,2,3][0] or (1,2,3)[-1]""" + """Inference for subscripts + + We're understanding if the index is a Const + or a slice, passing the result of inference + to the value's `getitem` method, which should + handle each supported index type accordingly. + """ + value = next(self.value.infer(context)) if value is util.YES: yield util.YES @@ -198,24 +242,35 @@ def infer_subscript(self, context=None): yield util.YES return + index_value = _SLICE_SENTINEL if isinstance(index, nodes.Const): - try: - assigned = value.getitem(index.value, context) - except AttributeError: - raise exceptions.InferenceError() - except (IndexError, TypeError): - yield util.YES - return - - # Prevent inferring if the infered subscript - # is the same as the original subscripted object. - if self is assigned or assigned is util.YES: - yield util.YES - return - for infered in assigned.infer(context): - yield infered + index_value = index.value + elif isinstance(index, nodes.Slice): + # Infer slices from the original object. + lower = _slice_value(index.lower, context) + upper = _slice_value(index.upper, context) + step = _slice_value(index.step, context) + if all(elem is not _SLICE_SENTINEL for elem in (lower, upper, step)): + index_value = slice(lower, upper, step) else: raise exceptions.InferenceError() + + if index_value is _SLICE_SENTINEL: + raise exceptions.InferenceError + + try: + assigned = value.getitem(index_value, context) + except (IndexError, TypeError, AttributeError) as exc: + six.raise_from(exceptions.InferenceError, exc) + + # Prevent inferring if the infered subscript + # is the same as the original subscripted object. + if self is assigned or assigned is util.YES: + yield util.YES + return + for infered in assigned.infer(context): + yield infered + nodes.Subscript._infer = bases.path_wrapper(infer_subscript) nodes.Subscript.infer_lhs = bases.raise_if_nothing_infered(infer_subscript) @@ -638,7 +693,6 @@ nodes.Index._infer = infer_index def instance_getitem(self, index, context=None): # Rewrap index to Const for this case index = nodes.Const(index) - if context: new_context = context.clone() else: diff --git a/astroid/node_classes.py b/astroid/node_classes.py index 944598a6..5ea755e2 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -109,6 +109,17 @@ def are_exclusive(stmt1, stmt2, exceptions=None): return False +def _container_getitem(instance, elts, index): + """Get a slice or an item, using the given *index*, for the given sequence.""" + if isinstance(index, slice): + new_cls = instance.__class__() + new_cls.elts = elts[index] + new_cls.parent = instance.parent + return new_cls + else: + return elts[index] + + @six.add_metaclass(abc.ABCMeta) class _BaseContainer(mixins.ParentAssignTypeMixin, bases.NodeNG, @@ -816,7 +827,7 @@ class List(_BaseContainer): return '%s.list' % BUILTINS def getitem(self, index, context=None): - return self.elts[index] + return _container_getitem(self, self.elts, index) class Nonlocal(bases.Statement): @@ -940,7 +951,7 @@ class Tuple(_BaseContainer): return '%s.tuple' % BUILTINS def getitem(self, index, context=None): - return self.elts[index] + return _container_getitem(self, self.elts, index) class UnaryOp(bases.NodeNG): diff --git a/astroid/objects.py b/astroid/objects.py index aa3848a3..c7d85d82 100644 --- a/astroid/objects.py +++ b/astroid/objects.py @@ -32,7 +32,6 @@ from astroid import bases from astroid import decorators from astroid import exceptions from astroid import MANAGER -from astroid import mixins from astroid import node_classes from astroid import scoped_nodes diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py index 20666f58..02048947 100644 --- a/astroid/tests/unittest_inference.py +++ b/astroid/tests/unittest_inference.py @@ -699,13 +699,17 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase): NoGetitem()[4] #@ InvalidGetitem()[5] #@ InvalidGetitem2()[10] #@ - [1, 2, 3][None] #@ - 'lala'['bala'] #@ ''') for node in ast_nodes[:3]: self.assertRaises(InferenceError, next, node.infer()) for node in ast_nodes[3:]: self.assertEqual(next(node.infer()), util.YES) + ast_nodes = test_utils.extract_node(''' + [1, 2, 3][None] #@ + 'lala'['bala'] #@ + ''') + for node in ast_nodes: + self.assertRaises(InferenceError, next, node.infer()) def test_bytes_subscript(self): node = test_utils.extract_node('''b'a'[0]''') @@ -2692,6 +2696,113 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase): self.assertIsInstance(inferred, nodes.Const) self.assertEqual(inferred.value, 42) + def _slicing_test_helper(self, pairs, cls, get_elts): + for code, expected in pairs: + ast_node = test_utils.extract_node(code) + inferred = next(ast_node.infer()) + self.assertIsInstance(inferred, cls) + self.assertEqual(get_elts(inferred), expected, + ast_node.as_string()) + + def test_slicing_list(self): + pairs = ( + ("[1, 2, 3][:] #@", [1, 2, 3]), + ("[1, 2, 3][0:] #@", [1, 2, 3]), + ("[1, 2, 3][None:] #@", [1, 2, 3]), + ("[1, 2, 3][None:None] #@", [1, 2, 3]), + ("[1, 2, 3][0:-1] #@", [1, 2]), + ("[1, 2, 3][0:2] #@", [1, 2]), + ("[1, 2, 3][0:2:None] #@", [1, 2]), + ("[1, 2, 3][::] #@", [1, 2, 3]), + ("[1, 2, 3][::2] #@", [1, 3]), + ("[1, 2, 3][::-1] #@", [3, 2, 1]), + ("[1, 2, 3][0:2:2] #@", [1]), + ("[1, 2, 3, 4, 5, 6][0:4-1:2+0] #@", [1, 3]), + ) + self._slicing_test_helper( + pairs, nodes.List, + lambda inferred: [elt.value for elt in inferred.elts]) + + def test_slicing_tuple(self): + pairs = ( + ("(1, 2, 3)[:] #@", [1, 2, 3]), + ("(1, 2, 3)[0:] #@", [1, 2, 3]), + ("(1, 2, 3)[None:] #@", [1, 2, 3]), + ("(1, 2, 3)[None:None] #@", [1, 2, 3]), + ("(1, 2, 3)[0:-1] #@", [1, 2]), + ("(1, 2, 3)[0:2] #@", [1, 2]), + ("(1, 2, 3)[0:2:None] #@", [1, 2]), + ("(1, 2, 3)[::] #@", [1, 2, 3]), + ("(1, 2, 3)[::2] #@", [1, 3]), + ("(1, 2, 3)[::-1] #@", [3, 2, 1]), + ("(1, 2, 3)[0:2:2] #@", [1]), + ("(1, 2, 3, 4, 5, 6)[0:4-1:2+0] #@", [1, 3]), + ) + self._slicing_test_helper( + pairs, nodes.Tuple, + lambda inferred: [elt.value for elt in inferred.elts]) + + def test_slicing_str(self): + pairs = ( + ("'123'[:] #@", "123"), + ("'123'[0:] #@", "123"), + ("'123'[None:] #@", "123"), + ("'123'[None:None] #@", "123"), + ("'123'[0:-1] #@", "12"), + ("'123'[0:2] #@", "12"), + ("'123'[0:2:None] #@", "12"), + ("'123'[::] #@", "123"), + ("'123'[::2] #@", "13"), + ("'123'[::-1] #@", "321"), + ("'123'[0:2:2] #@", "1"), + ("'123456'[0:4-1:2+0] #@", "13"), + ) + self._slicing_test_helper( + pairs, nodes.Const, lambda inferred: inferred.value) + + def test_invalid_slicing_primaries(self): + examples = [ + "(lambda x: x)[1:2]", + "1[2]", + "enumerate[2]", + "(1, 2, 3)[a:]", + "(1, 2, 3)[object:object]", + "(1, 2, 3)[1:object]", + ] + for code in examples: + node = test_utils.extract_node(code) + self.assertRaises(InferenceError, next, node.infer()) + + def test_instance_slicing(self): + ast_nodes = test_utils.extract_node(''' + class A(object): + def __getitem__(self, index): + return [1, 2, 3, 4, 5][index] + A()[1:] #@ + A()[:2] #@ + A()[1:4] #@ + ''') + expected_values = [ + [2, 3, 4, 5], + [1, 2], + [2, 3, 4], + ] + for expected, node in zip(expected_values, ast_nodes): + inferred = next(node.infer()) + self.assertIsInstance(inferred, nodes.List) + self.assertEqual([elt.value for elt in inferred.elts], expected) + + def test_instance_slicing_fails(self): + ast_nodes = test_utils.extract_node(''' + class A(object): + def __getitem__(self, index): + return 1[index] + A()[4:5] #@ + A()[2:] #@ + ''') + for node in ast_nodes: + self.assertEqual(next(node.infer()), util.YES) + class GetattrTest(unittest.TestCase): |