diff options
author | Claudiu Popa <pcmanticore@gmail.com> | 2015-10-22 21:40:12 +0100 |
---|---|---|
committer | Claudiu Popa <pcmanticore@gmail.com> | 2015-10-22 21:40:12 +0100 |
commit | 4272e1326efcbca5b26eecd4d8fd43bcc8cf93f2 (patch) | |
tree | 6d100f62fb11a360674f281ee801ec3252ee5fda /astroid | |
parent | 6c250a70175892cdd6ad1e84d071197e1ba4e56b (diff) | |
download | astroid-4272e1326efcbca5b26eecd4d8fd43bcc8cf93f2.tar.gz |
Add support for indexing containers with instances which provides an __index__ returning-int method.
This patch moves _class_as_index to helpers, where it becames class_instance_as_index.
Also, it instantiates its own call context, which makes certain idioms with lambdas
to work.
Diffstat (limited to 'astroid')
-rw-r--r-- | astroid/helpers.py | 24 | ||||
-rw-r--r-- | astroid/inference.py | 4 | ||||
-rw-r--r-- | astroid/protocols.py | 22 | ||||
-rw-r--r-- | astroid/tests/unittest_inference.py | 23 |
4 files changed, 53 insertions, 20 deletions
diff --git a/astroid/helpers.py b/astroid/helpers.py index 2b31841..feb0985 100644 --- a/astroid/helpers.py +++ b/astroid/helpers.py @@ -26,6 +26,7 @@ from astroid import bases from astroid import context as contextmod
from astroid import exceptions
from astroid import manager
+from astroid import nodes
from astroid import raw_building
from astroid import scoped_nodes
from astroid import util
@@ -156,3 +157,26 @@ def is_subtype(type1, type2): def is_supertype(type1, type2):
"""Check if *type2* is a supertype of *type1*."""
return _type_check(type1, type2)
+
+
+def class_instance_as_index(node):
+ """Get the value as an index for the given instance.
+
+ If an instance provides an __index__ method, then it can
+ be used in some scenarios where an integer is expected,
+ for instance when multiplying or subscripting a list.
+ """
+ context = contextmod.InferenceContext()
+ context.callcontext = contextmod.CallContext(args=[node])
+
+ try:
+ for inferred in node.igetattr('__index__', context=context):
+ if not isinstance(inferred, bases.BoundMethod):
+ continue
+
+ for result in inferred.infer_call_result(node, context=context):
+ if (isinstance(result, nodes.Const)
+ and isinstance(result.value, int)):
+ return result
+ except exceptions.InferenceError:
+ pass
diff --git a/astroid/inference.py b/astroid/inference.py index b1c81da..3253acb 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -252,6 +252,10 @@ def infer_subscript(self, context=None): step = _slice_value(index.step, context) if all(elem is not _SLICE_SENTINEL for elem in (lower, upper, step)): index_value = slice(lower, upper, step) + elif isinstance(index, bases.Instance): + index = helpers.class_instance_as_index(index) + if index: + index_value = index.value else: raise exceptions.InferenceError() diff --git a/astroid/protocols.py b/astroid/protocols.py index 76e404d..23a4f2e 100644 --- a/astroid/protocols.py +++ b/astroid/protocols.py @@ -31,6 +31,7 @@ from astroid import context as contextmod from astroid import exceptions from astroid import decorators from astroid import node_classes +from astroid import helpers from astroid import nodes from astroid import util @@ -164,7 +165,7 @@ def tl_infer_binary_op(self, operator, other, context, method): yield _multiply_seq_by_int(self, other, context) elif isinstance(other, bases.Instance) and operator == '*': # Verify if the instance supports __index__. - as_index = _class_as_index(other, context) + as_index = helpers.class_instance_as_index(other) if not as_index: yield util.YES else: @@ -516,22 +517,3 @@ def starred_assigned_stmts(self, node=None, context=None, asspath=None): break nodes.Starred.assigned_stmts = starred_assigned_stmts - - -def _class_as_index(node, context): - """Get the value as an index for the given node - - It is expected that the node is an Instance. If it provides - an *__index__* method, we'll try to return its int value. - """ - try: - for inferred in node.igetattr('__index__', context=context): - if not isinstance(inferred, bases.BoundMethod): - continue - - for result in inferred.infer_call_result(node, context=context): - if (isinstance(result, nodes.Const) - and isinstance(result.value, int)): - return result - except exceptions.InferenceError: - pass diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py index 02e9fcd..0e18bc5 100644 --- a/astroid/tests/unittest_inference.py +++ b/astroid/tests/unittest_inference.py @@ -2723,6 +2723,29 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase): inferred = next(rest.infer()) self.assertEqual(inferred, util.YES) + def test_subscript_supports__index__(self): + ast_nodes = test_utils.extract_node(''' + class Index(object): + def __index__(self): return 2 + class LambdaIndex(object): + __index__ = lambda self: self.foo + @property + def foo(self): return 1 + class NonIndex(object): + __index__ = lambda self: None + a = [1, 2, 3, 4] + a[Index()] #@ + a[LambdaIndex()] #@ + a[NonIndex()] #@ + ''') + first = next(ast_nodes[0].infer()) + self.assertIsInstance(first, nodes.Const) + self.assertEqual(first.value, 3) + second = next(ast_nodes[1].infer()) + self.assertIsInstance(second, nodes.Const) + self.assertEqual(second.value, 2) + self.assertRaises(InferenceError, next, ast_nodes[2].infer()) + def test_special_method_masquerading_as_another(self): ast_node = test_utils.extract_node(''' class Info(object): |