summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bases.py55
-rw-r--r--inference.py46
-rw-r--r--protocols.py13
-rw-r--r--scoped_nodes.py45
-rw-r--r--test/unittest_nodes.py8
-rw-r--r--test/unittest_scoped_nodes.py3
6 files changed, 75 insertions, 95 deletions
diff --git a/bases.py b/bases.py
index bbdd3da3..60fd30d2 100644
--- a/bases.py
+++ b/bases.py
@@ -64,17 +64,16 @@ class InferenceContextPathContext(object):
Can't be a @contextmanager because it raises StopIteration.
"""
- def __init__(self, context, node):
+ def __init__(self, context, key):
self.original_path = context.path.copy()
self.context = context
- self.node = node
+ self.key = key
def __enter__(self):
- name = self.context.lookupname
- if (self.node, name) in self.context.path:
+ if self.key in self.context.path:
raise StopIteration
- self.context.path.add((self.node, name))
+ self.context.path.add(self.key)
return self
def __exit__(self, *exc_info):
@@ -82,33 +81,30 @@ class InferenceContextPathContext(object):
class InferenceContext(object):
- __slots__ = ('path', 'lookupname', 'callcontext', 'boundnode')
+ __slots__ = ('path', 'callcontext', 'boundnode')
def __init__(self, path=None):
if path is None:
self.path = set()
else:
self.path = path
- self.lookupname = None
self.callcontext = None
self.boundnode = None
- def push(self, node):
- return InferenceContextPathContext(self, node)
+ def push(self, key):
+ return InferenceContextPathContext(self, key)
@contextmanager
- def scope(self, lookupname=MISSING, callcontext=MISSING, boundnode=MISSING):
+ def scope(self, callcontext=MISSING, boundnode=MISSING):
try:
- orig = self.lookupname, self.callcontext, self.boundnode
- if lookupname is not MISSING:
- self.lookupname = lookupname
+ orig = self.callcontext, self.boundnode
if callcontext is not MISSING:
self.callcontext = callcontext
if boundnode is not MISSING:
self.boundnode = boundnode
yield
finally:
- self.lookupname, self.callcontext, self.boundnode = orig
+ self.callcontext, self.boundnode = orig
@contextmanager
def restore_path(self):
@@ -117,29 +113,34 @@ class InferenceContext(object):
self.path = path
-def _infer_stmts(stmts, context, frame=None):
+def _infer_stmts(stmts, context, frame=None, lookupname=None):
"""return an iterator on statements inferred by each statement in <stmts>
"""
stmt = None
infered = False
if context is None:
context = InferenceContext()
- name = context.lookupname
for stmt in stmts:
if stmt is YES:
yield stmt
infered = True
continue
- with context.scope(lookupname=stmt._infer_name(frame, name)):
- try:
- for infered in stmt.infer(context):
- yield infered
- infered = True
- except UnresolvableName:
- continue
- except InferenceError:
- yield YES
+
+ kw = {}
+ infered_name = stmt._infer_name(frame, lookupname)
+ if infered_name is not None:
+ # only returns not None if .infer() accepts a lookupname kwarg
+ kw['lookupname'] = infered_name
+
+ try:
+ for infered in stmt.infer(context, **kw):
+ yield infered
infered = True
+ except UnresolvableName:
+ continue
+ except InferenceError:
+ yield YES
+ infered = True
if not infered:
raise InferenceError(str(stmt))
@@ -297,7 +298,7 @@ class BoundMethod(UnboundMethod):
return True
def infer_call_result(self, caller, context):
- with context.scope(boundnode=self.bound, lookupname=None):
+ with context.scope(boundnode=self.bound):
for infered in self._proxied.infer_call_result(caller, context):
yield infered
@@ -331,7 +332,7 @@ def path_wrapper(func):
"""wrapper function handling context"""
if context is None:
context = InferenceContext()
- with context.push(node):
+ with context.push((node, kwargs.get('lookupname'))):
yielded = set()
for res in _func(node, context, **kwargs):
# unproxy only true instance, not const, tuple, dict...
diff --git a/inference.py b/inference.py
index 7bc2313b..6fcb02d3 100644
--- a/inference.py
+++ b/inference.py
@@ -146,9 +146,7 @@ def infer_name(self, context=None):
frame, stmts = self.lookup(self.name)
if not stmts:
raise UnresolvableName(self.name)
- with context.scope(lookupname=self.name):
- for infered in _infer_stmts(stmts, context, frame):
- yield infered
+ return _infer_stmts(stmts, context, frame, self.name)
nodes.Name._infer = path_wrapper(infer_name)
nodes.AssName.infer_lhs = infer_name # won't work with a path wrapper
@@ -161,7 +159,6 @@ def infer_callfunc(self, context=None):
with context.scope(
callcontext=CallContext(self.args, self.starargs, self.kwargs),
boundnode=None,
- lookupname=None,
):
if callee is YES:
yield callee
@@ -176,39 +173,33 @@ def infer_callfunc(self, context=None):
nodes.CallFunc._infer = path_wrapper(raise_if_nothing_infered(infer_callfunc))
-def infer_import(self, context=None, asname=True):
+def infer_import(self, context=None, asname=True, lookupname=None):
"""infer an Import node: return the imported module/object"""
- name = context.lookupname
- if name is None:
+ if lookupname is None:
raise InferenceError()
if asname:
- yield self.do_import_module(self.real_name(name))
+ yield self.do_import_module(self.real_name(lookupname))
else:
- yield self.do_import_module(name)
+ yield self.do_import_module(lookupname)
nodes.Import._infer = path_wrapper(infer_import)
def infer_name_module(self, name):
context = InferenceContext()
- with context.scope(lookupname=name):
- for infered in self.infer(context, asname=False):
- yield infered
+ return self.infer(context, asname=False, lookupname=name)
nodes.Import.infer_name_module = infer_name_module
-def infer_from(self, context=None, asname=True):
+def infer_from(self, context=None, asname=True, lookupname=None):
"""infer a From nodes: return the imported module/object"""
- name = context.lookupname
- if name is None:
+ if lookupname is None:
raise InferenceError()
if asname:
- name = self.real_name(name)
+ lookupname = self.real_name(lookupname)
module = self.do_import_module(self.modname)
try:
- with context.scope(lookupname=name):
- for infered in _infer_stmts(module.getattr(name, ignore_locals=module is self.root()), context):
- yield infered
+ return _infer_stmts(module.getattr(lookupname, ignore_locals=module is self.root()), context, lookupname=lookupname)
except NotFoundError:
- raise InferenceError(name)
+ raise InferenceError(lookupname)
nodes.From._infer = path_wrapper(infer_from)
@@ -221,7 +212,7 @@ def infer_getattr(self, context=None):
yield owner
continue
try:
- with context.scope(boundnode=owner, lookupname=None):
+ with context.scope(boundnode=owner):
for obj in owner.igetattr(self.attrname, context):
yield obj
except (NotFoundError, InferenceError):
@@ -233,11 +224,11 @@ nodes.Getattr._infer = path_wrapper(raise_if_nothing_infered(infer_getattr))
nodes.AssAttr.infer_lhs = raise_if_nothing_infered(infer_getattr) # # won't work with a path wrapper
-def infer_global(self, context=None):
- if context.lookupname is None:
+def infer_global(self, context=None, lookupname=None):
+ if lookupname is None:
raise InferenceError()
try:
- return _infer_stmts(self.root().getattr(context.lookupname), context)
+ return _infer_stmts(self.root().getattr(lookupname), context)
except NotFoundError:
raise InferenceError()
nodes.Global._infer = path_wrapper(infer_global)
@@ -345,11 +336,10 @@ def infer_binop(self, context=None):
nodes.BinOp._infer = path_wrapper(infer_binop)
-def infer_arguments(self, context=None):
- name = context.lookupname
- if name is None:
+def infer_arguments(self, context=None, lookupname=None):
+ if lookupname is None:
raise InferenceError()
- return _arguments_infer_argname(self, name, context)
+ return _arguments_infer_argname(self, lookupname, context)
nodes.Arguments._infer = infer_arguments
diff --git a/protocols.py b/protocols.py
index 616340c9..f486e7cc 100644
--- a/protocols.py
+++ b/protocols.py
@@ -241,9 +241,8 @@ def _arguments_infer_argname(self, name, context):
try:
if context is None:
context = InferenceContext()
- with context.scope(lookupname=None):
- for infered in self.default_value(name).infer(context):
- yield infered
+ for infered in self.default_value(name).infer(context):
+ yield infered
yield YES
except NoDefault:
yield YES
@@ -253,12 +252,8 @@ def arguments_assigned_stmts(self, node, context, asspath=None):
if context.callcontext:
# reset call context/name
callcontext = context.callcontext
- with context.scope(callcontext=None, lookupname=None):
- for infered in callcontext.infer_argument(self.parent, node.name, context):
- yield infered
- return
- for infered in _arguments_infer_argname(self, node.name, context):
- yield infered
+ return callcontext.infer_argument(self.parent, node.name, context)
+ return _arguments_infer_argname(self, node.name, context)
nodes.Arguments.assigned_stmts = arguments_assigned_stmts
diff --git a/scoped_nodes.py b/scoped_nodes.py
index a5bc37e9..389ebe71 100644
--- a/scoped_nodes.py
+++ b/scoped_nodes.py
@@ -310,12 +310,10 @@ class Module(LocalsDictNodeNG):
# instance
if not context:
context = InferenceContext()
- with context.scope(lookupname=name):
- try:
- for infered in _infer_stmts(self.getattr(name, context), context, frame=self):
- yield infered
- except NotFoundError:
- raise InferenceError(name)
+ try:
+ return _infer_stmts(self.getattr(name, context), context, frame=self, lookupname=name)
+ except NotFoundError:
+ raise InferenceError(name)
def fully_defined(self):
"""return True if this module has been built from a .py file
@@ -997,26 +995,25 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin):
# instance
if not context:
context = InferenceContext()
- with context.scope(lookupname=name):
- try:
- for infered in _infer_stmts(self.getattr(name, context), context,
- frame=self):
- # yield YES object instead of descriptors when necessary
- if not isinstance(infered, Const) and isinstance(infered, Instance):
- try:
- infered._proxied.getattr('__get__', context)
- except NotFoundError:
- yield infered
- else:
- yield YES
+ try:
+ for infered in _infer_stmts(self.getattr(name, context), context,
+ frame=self, lookupname=name):
+ # yield YES object instead of descriptors when necessary
+ if not isinstance(infered, Const) and isinstance(infered, Instance):
+ try:
+ infered._proxied.getattr('__get__', context)
+ except NotFoundError:
+ yield infered
else:
- yield function_to_method(infered, self)
- except NotFoundError:
- if not name.startswith('__') and self.has_dynamic_getattr(context):
- # class handle some dynamic attributes, return a YES object
- yield YES
+ yield YES
else:
- raise InferenceError(name)
+ yield function_to_method(infered, self)
+ except NotFoundError:
+ if not name.startswith('__') and self.has_dynamic_getattr(context):
+ # class handle some dynamic attributes, return a YES object
+ yield YES
+ else:
+ raise InferenceError(name)
def has_dynamic_getattr(self, context=None):
"""return True if the class has a custom __getattr__ or
diff --git a/test/unittest_nodes.py b/test/unittest_nodes.py
index b5245caa..37300176 100644
--- a/test/unittest_nodes.py
+++ b/test/unittest_nodes.py
@@ -308,11 +308,9 @@ except PickleError:
def test_absolute_import(self):
astroid = abuilder.file_build(self.datapath('absimport.py'))
ctx = InferenceContext()
- with ctx.scope(lookupname='message'):
- # will fail if absolute import failed
- astroid['message'].infer(ctx).next()
- with ctx.scope(lookupname='email'):
- m = astroid['email'].infer(ctx).next()
+ # will fail if absolute import failed
+ astroid['message'].infer(ctx, lookupname='message').next()
+ m = astroid['email'].infer(ctx, lookupname='email').next()
self.assertFalse(m.file.startswith(self.datapath('email.py')))
diff --git a/test/unittest_scoped_nodes.py b/test/unittest_scoped_nodes.py
index de0b395f..b6f3434b 100644
--- a/test/unittest_scoped_nodes.py
+++ b/test/unittest_scoped_nodes.py
@@ -96,8 +96,7 @@ class ModuleNodeTC(TestCase):
del sys.path[1]
self.assertEqual(len(NONREGR.getattr('enumerate')), 2)
# raise ResolveError
- gen = MODULE.igetattr('YOAA')
- self.assertRaises(InferenceError, list, gen)
+ self.assertRaises(InferenceError, MODULE.igetattr, 'YOAA')
def test_wildard_import_names(self):
m = abuilder.file_build(join(DATA, 'all.py'), 'all')