diff options
-rw-r--r-- | bases.py | 55 | ||||
-rw-r--r-- | inference.py | 46 | ||||
-rw-r--r-- | protocols.py | 13 | ||||
-rw-r--r-- | scoped_nodes.py | 45 | ||||
-rw-r--r-- | test/unittest_nodes.py | 8 | ||||
-rw-r--r-- | test/unittest_scoped_nodes.py | 3 |
6 files changed, 75 insertions, 95 deletions
@@ -64,17 +64,16 @@ class InferenceContextPathContext(object): Can't be a @contextmanager because it raises StopIteration. """ - def __init__(self, context, node): + def __init__(self, context, key): self.original_path = context.path.copy() self.context = context - self.node = node + self.key = key def __enter__(self): - name = self.context.lookupname - if (self.node, name) in self.context.path: + if self.key in self.context.path: raise StopIteration - self.context.path.add((self.node, name)) + self.context.path.add(self.key) return self def __exit__(self, *exc_info): @@ -82,33 +81,30 @@ class InferenceContextPathContext(object): class InferenceContext(object): - __slots__ = ('path', 'lookupname', 'callcontext', 'boundnode') + __slots__ = ('path', 'callcontext', 'boundnode') def __init__(self, path=None): if path is None: self.path = set() else: self.path = path - self.lookupname = None self.callcontext = None self.boundnode = None - def push(self, node): - return InferenceContextPathContext(self, node) + def push(self, key): + return InferenceContextPathContext(self, key) @contextmanager - def scope(self, lookupname=MISSING, callcontext=MISSING, boundnode=MISSING): + def scope(self, callcontext=MISSING, boundnode=MISSING): try: - orig = self.lookupname, self.callcontext, self.boundnode - if lookupname is not MISSING: - self.lookupname = lookupname + orig = self.callcontext, self.boundnode if callcontext is not MISSING: self.callcontext = callcontext if boundnode is not MISSING: self.boundnode = boundnode yield finally: - self.lookupname, self.callcontext, self.boundnode = orig + self.callcontext, self.boundnode = orig @contextmanager def restore_path(self): @@ -117,29 +113,34 @@ class InferenceContext(object): self.path = path -def _infer_stmts(stmts, context, frame=None): +def _infer_stmts(stmts, context, frame=None, lookupname=None): """return an iterator on statements inferred by each statement in <stmts> """ stmt = None infered = False if context is None: context = InferenceContext() - name = context.lookupname for stmt in stmts: if stmt is YES: yield stmt infered = True continue - with context.scope(lookupname=stmt._infer_name(frame, name)): - try: - for infered in stmt.infer(context): - yield infered - infered = True - except UnresolvableName: - continue - except InferenceError: - yield YES + + kw = {} + infered_name = stmt._infer_name(frame, lookupname) + if infered_name is not None: + # only returns not None if .infer() accepts a lookupname kwarg + kw['lookupname'] = infered_name + + try: + for infered in stmt.infer(context, **kw): + yield infered infered = True + except UnresolvableName: + continue + except InferenceError: + yield YES + infered = True if not infered: raise InferenceError(str(stmt)) @@ -297,7 +298,7 @@ class BoundMethod(UnboundMethod): return True def infer_call_result(self, caller, context): - with context.scope(boundnode=self.bound, lookupname=None): + with context.scope(boundnode=self.bound): for infered in self._proxied.infer_call_result(caller, context): yield infered @@ -331,7 +332,7 @@ def path_wrapper(func): """wrapper function handling context""" if context is None: context = InferenceContext() - with context.push(node): + with context.push((node, kwargs.get('lookupname'))): yielded = set() for res in _func(node, context, **kwargs): # unproxy only true instance, not const, tuple, dict... diff --git a/inference.py b/inference.py index 7bc2313b..6fcb02d3 100644 --- a/inference.py +++ b/inference.py @@ -146,9 +146,7 @@ def infer_name(self, context=None): frame, stmts = self.lookup(self.name) if not stmts: raise UnresolvableName(self.name) - with context.scope(lookupname=self.name): - for infered in _infer_stmts(stmts, context, frame): - yield infered + return _infer_stmts(stmts, context, frame, self.name) nodes.Name._infer = path_wrapper(infer_name) nodes.AssName.infer_lhs = infer_name # won't work with a path wrapper @@ -161,7 +159,6 @@ def infer_callfunc(self, context=None): with context.scope( callcontext=CallContext(self.args, self.starargs, self.kwargs), boundnode=None, - lookupname=None, ): if callee is YES: yield callee @@ -176,39 +173,33 @@ def infer_callfunc(self, context=None): nodes.CallFunc._infer = path_wrapper(raise_if_nothing_infered(infer_callfunc)) -def infer_import(self, context=None, asname=True): +def infer_import(self, context=None, asname=True, lookupname=None): """infer an Import node: return the imported module/object""" - name = context.lookupname - if name is None: + if lookupname is None: raise InferenceError() if asname: - yield self.do_import_module(self.real_name(name)) + yield self.do_import_module(self.real_name(lookupname)) else: - yield self.do_import_module(name) + yield self.do_import_module(lookupname) nodes.Import._infer = path_wrapper(infer_import) def infer_name_module(self, name): context = InferenceContext() - with context.scope(lookupname=name): - for infered in self.infer(context, asname=False): - yield infered + return self.infer(context, asname=False, lookupname=name) nodes.Import.infer_name_module = infer_name_module -def infer_from(self, context=None, asname=True): +def infer_from(self, context=None, asname=True, lookupname=None): """infer a From nodes: return the imported module/object""" - name = context.lookupname - if name is None: + if lookupname is None: raise InferenceError() if asname: - name = self.real_name(name) + lookupname = self.real_name(lookupname) module = self.do_import_module(self.modname) try: - with context.scope(lookupname=name): - for infered in _infer_stmts(module.getattr(name, ignore_locals=module is self.root()), context): - yield infered + return _infer_stmts(module.getattr(lookupname, ignore_locals=module is self.root()), context, lookupname=lookupname) except NotFoundError: - raise InferenceError(name) + raise InferenceError(lookupname) nodes.From._infer = path_wrapper(infer_from) @@ -221,7 +212,7 @@ def infer_getattr(self, context=None): yield owner continue try: - with context.scope(boundnode=owner, lookupname=None): + with context.scope(boundnode=owner): for obj in owner.igetattr(self.attrname, context): yield obj except (NotFoundError, InferenceError): @@ -233,11 +224,11 @@ nodes.Getattr._infer = path_wrapper(raise_if_nothing_infered(infer_getattr)) nodes.AssAttr.infer_lhs = raise_if_nothing_infered(infer_getattr) # # won't work with a path wrapper -def infer_global(self, context=None): - if context.lookupname is None: +def infer_global(self, context=None, lookupname=None): + if lookupname is None: raise InferenceError() try: - return _infer_stmts(self.root().getattr(context.lookupname), context) + return _infer_stmts(self.root().getattr(lookupname), context) except NotFoundError: raise InferenceError() nodes.Global._infer = path_wrapper(infer_global) @@ -345,11 +336,10 @@ def infer_binop(self, context=None): nodes.BinOp._infer = path_wrapper(infer_binop) -def infer_arguments(self, context=None): - name = context.lookupname - if name is None: +def infer_arguments(self, context=None, lookupname=None): + if lookupname is None: raise InferenceError() - return _arguments_infer_argname(self, name, context) + return _arguments_infer_argname(self, lookupname, context) nodes.Arguments._infer = infer_arguments diff --git a/protocols.py b/protocols.py index 616340c9..f486e7cc 100644 --- a/protocols.py +++ b/protocols.py @@ -241,9 +241,8 @@ def _arguments_infer_argname(self, name, context): try: if context is None: context = InferenceContext() - with context.scope(lookupname=None): - for infered in self.default_value(name).infer(context): - yield infered + for infered in self.default_value(name).infer(context): + yield infered yield YES except NoDefault: yield YES @@ -253,12 +252,8 @@ def arguments_assigned_stmts(self, node, context, asspath=None): if context.callcontext: # reset call context/name callcontext = context.callcontext - with context.scope(callcontext=None, lookupname=None): - for infered in callcontext.infer_argument(self.parent, node.name, context): - yield infered - return - for infered in _arguments_infer_argname(self, node.name, context): - yield infered + return callcontext.infer_argument(self.parent, node.name, context) + return _arguments_infer_argname(self, node.name, context) nodes.Arguments.assigned_stmts = arguments_assigned_stmts diff --git a/scoped_nodes.py b/scoped_nodes.py index a5bc37e9..389ebe71 100644 --- a/scoped_nodes.py +++ b/scoped_nodes.py @@ -310,12 +310,10 @@ class Module(LocalsDictNodeNG): # instance if not context: context = InferenceContext() - with context.scope(lookupname=name): - try: - for infered in _infer_stmts(self.getattr(name, context), context, frame=self): - yield infered - except NotFoundError: - raise InferenceError(name) + try: + return _infer_stmts(self.getattr(name, context), context, frame=self, lookupname=name) + except NotFoundError: + raise InferenceError(name) def fully_defined(self): """return True if this module has been built from a .py file @@ -997,26 +995,25 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin): # instance if not context: context = InferenceContext() - with context.scope(lookupname=name): - try: - for infered in _infer_stmts(self.getattr(name, context), context, - frame=self): - # yield YES object instead of descriptors when necessary - if not isinstance(infered, Const) and isinstance(infered, Instance): - try: - infered._proxied.getattr('__get__', context) - except NotFoundError: - yield infered - else: - yield YES + try: + for infered in _infer_stmts(self.getattr(name, context), context, + frame=self, lookupname=name): + # yield YES object instead of descriptors when necessary + if not isinstance(infered, Const) and isinstance(infered, Instance): + try: + infered._proxied.getattr('__get__', context) + except NotFoundError: + yield infered else: - yield function_to_method(infered, self) - except NotFoundError: - if not name.startswith('__') and self.has_dynamic_getattr(context): - # class handle some dynamic attributes, return a YES object - yield YES + yield YES else: - raise InferenceError(name) + yield function_to_method(infered, self) + except NotFoundError: + if not name.startswith('__') and self.has_dynamic_getattr(context): + # class handle some dynamic attributes, return a YES object + yield YES + else: + raise InferenceError(name) def has_dynamic_getattr(self, context=None): """return True if the class has a custom __getattr__ or diff --git a/test/unittest_nodes.py b/test/unittest_nodes.py index b5245caa..37300176 100644 --- a/test/unittest_nodes.py +++ b/test/unittest_nodes.py @@ -308,11 +308,9 @@ except PickleError: def test_absolute_import(self): astroid = abuilder.file_build(self.datapath('absimport.py')) ctx = InferenceContext() - with ctx.scope(lookupname='message'): - # will fail if absolute import failed - astroid['message'].infer(ctx).next() - with ctx.scope(lookupname='email'): - m = astroid['email'].infer(ctx).next() + # will fail if absolute import failed + astroid['message'].infer(ctx, lookupname='message').next() + m = astroid['email'].infer(ctx, lookupname='email').next() self.assertFalse(m.file.startswith(self.datapath('email.py'))) diff --git a/test/unittest_scoped_nodes.py b/test/unittest_scoped_nodes.py index de0b395f..b6f3434b 100644 --- a/test/unittest_scoped_nodes.py +++ b/test/unittest_scoped_nodes.py @@ -96,8 +96,7 @@ class ModuleNodeTC(TestCase): del sys.path[1] self.assertEqual(len(NONREGR.getattr('enumerate')), 2) # raise ResolveError - gen = MODULE.igetattr('YOAA') - self.assertRaises(InferenceError, list, gen) + self.assertRaises(InferenceError, MODULE.igetattr, 'YOAA') def test_wildard_import_names(self): m = abuilder.file_build(join(DATA, 'all.py'), 'all') |