summaryrefslogtreecommitdiff
path: root/astroid
diff options
context:
space:
mode:
authorClaudiu Popa <pcmanticore@gmail.com>2015-10-22 21:40:12 +0100
committerClaudiu Popa <pcmanticore@gmail.com>2015-10-22 21:40:12 +0100
commit4272e1326efcbca5b26eecd4d8fd43bcc8cf93f2 (patch)
tree6d100f62fb11a360674f281ee801ec3252ee5fda /astroid
parent6c250a70175892cdd6ad1e84d071197e1ba4e56b (diff)
downloadastroid-4272e1326efcbca5b26eecd4d8fd43bcc8cf93f2.tar.gz
Add support for indexing containers with instances which provides an __index__ returning-int method.
This patch moves _class_as_index to helpers, where it becames class_instance_as_index. Also, it instantiates its own call context, which makes certain idioms with lambdas to work.
Diffstat (limited to 'astroid')
-rw-r--r--astroid/helpers.py24
-rw-r--r--astroid/inference.py4
-rw-r--r--astroid/protocols.py22
-rw-r--r--astroid/tests/unittest_inference.py23
4 files changed, 53 insertions, 20 deletions
diff --git a/astroid/helpers.py b/astroid/helpers.py
index 2b31841..feb0985 100644
--- a/astroid/helpers.py
+++ b/astroid/helpers.py
@@ -26,6 +26,7 @@ from astroid import bases
from astroid import context as contextmod
from astroid import exceptions
from astroid import manager
+from astroid import nodes
from astroid import raw_building
from astroid import scoped_nodes
from astroid import util
@@ -156,3 +157,26 @@ def is_subtype(type1, type2):
def is_supertype(type1, type2):
"""Check if *type2* is a supertype of *type1*."""
return _type_check(type1, type2)
+
+
+def class_instance_as_index(node):
+ """Get the value as an index for the given instance.
+
+ If an instance provides an __index__ method, then it can
+ be used in some scenarios where an integer is expected,
+ for instance when multiplying or subscripting a list.
+ """
+ context = contextmod.InferenceContext()
+ context.callcontext = contextmod.CallContext(args=[node])
+
+ try:
+ for inferred in node.igetattr('__index__', context=context):
+ if not isinstance(inferred, bases.BoundMethod):
+ continue
+
+ for result in inferred.infer_call_result(node, context=context):
+ if (isinstance(result, nodes.Const)
+ and isinstance(result.value, int)):
+ return result
+ except exceptions.InferenceError:
+ pass
diff --git a/astroid/inference.py b/astroid/inference.py
index b1c81da..3253acb 100644
--- a/astroid/inference.py
+++ b/astroid/inference.py
@@ -252,6 +252,10 @@ def infer_subscript(self, context=None):
step = _slice_value(index.step, context)
if all(elem is not _SLICE_SENTINEL for elem in (lower, upper, step)):
index_value = slice(lower, upper, step)
+ elif isinstance(index, bases.Instance):
+ index = helpers.class_instance_as_index(index)
+ if index:
+ index_value = index.value
else:
raise exceptions.InferenceError()
diff --git a/astroid/protocols.py b/astroid/protocols.py
index 76e404d..23a4f2e 100644
--- a/astroid/protocols.py
+++ b/astroid/protocols.py
@@ -31,6 +31,7 @@ from astroid import context as contextmod
from astroid import exceptions
from astroid import decorators
from astroid import node_classes
+from astroid import helpers
from astroid import nodes
from astroid import util
@@ -164,7 +165,7 @@ def tl_infer_binary_op(self, operator, other, context, method):
yield _multiply_seq_by_int(self, other, context)
elif isinstance(other, bases.Instance) and operator == '*':
# Verify if the instance supports __index__.
- as_index = _class_as_index(other, context)
+ as_index = helpers.class_instance_as_index(other)
if not as_index:
yield util.YES
else:
@@ -516,22 +517,3 @@ def starred_assigned_stmts(self, node=None, context=None, asspath=None):
break
nodes.Starred.assigned_stmts = starred_assigned_stmts
-
-
-def _class_as_index(node, context):
- """Get the value as an index for the given node
-
- It is expected that the node is an Instance. If it provides
- an *__index__* method, we'll try to return its int value.
- """
- try:
- for inferred in node.igetattr('__index__', context=context):
- if not isinstance(inferred, bases.BoundMethod):
- continue
-
- for result in inferred.infer_call_result(node, context=context):
- if (isinstance(result, nodes.Const)
- and isinstance(result.value, int)):
- return result
- except exceptions.InferenceError:
- pass
diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py
index 02e9fcd..0e18bc5 100644
--- a/astroid/tests/unittest_inference.py
+++ b/astroid/tests/unittest_inference.py
@@ -2723,6 +2723,29 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
inferred = next(rest.infer())
self.assertEqual(inferred, util.YES)
+ def test_subscript_supports__index__(self):
+ ast_nodes = test_utils.extract_node('''
+ class Index(object):
+ def __index__(self): return 2
+ class LambdaIndex(object):
+ __index__ = lambda self: self.foo
+ @property
+ def foo(self): return 1
+ class NonIndex(object):
+ __index__ = lambda self: None
+ a = [1, 2, 3, 4]
+ a[Index()] #@
+ a[LambdaIndex()] #@
+ a[NonIndex()] #@
+ ''')
+ first = next(ast_nodes[0].infer())
+ self.assertIsInstance(first, nodes.Const)
+ self.assertEqual(first.value, 3)
+ second = next(ast_nodes[1].infer())
+ self.assertIsInstance(second, nodes.Const)
+ self.assertEqual(second.value, 2)
+ self.assertRaises(InferenceError, next, ast_nodes[2].infer())
+
def test_special_method_masquerading_as_another(self):
ast_node = test_utils.extract_node('''
class Info(object):