diff options
-rw-r--r-- | ChangeLog | 9 | ||||
-rw-r--r-- | astroid/brain/brain_builtin_inference.py | 34 | ||||
-rw-r--r-- | astroid/node_classes.py | 15 | ||||
-rw-r--r-- | astroid/tests/unittest_inference.py | 17 |
4 files changed, 57 insertions, 18 deletions
@@ -6,6 +6,15 @@ What's New in astroid 2.3.0? ============================ Release Date: TBA +* Improved builtin inference for ``tuple``, ``set``, ``frozenset``, ``list`` and ``dict`` + + We were properly inferring these callables *only* if they had consts as + values, but that is not the case most of the time. Instead we try to infer + the values that their arguments can be and use them instead of assuming + Const nodes all the time. + + Close PyCQA/pylint#2841 + * The last except handler wins when inferring variables bound in an except handler. Close PyCQA/pylint#2777 diff --git a/astroid/brain/brain_builtin_inference.py b/astroid/brain/brain_builtin_inference.py index 53ca2f98..c9758c79 100644 --- a/astroid/brain/brain_builtin_inference.py +++ b/astroid/brain/brain_builtin_inference.py @@ -147,7 +147,7 @@ def register_builtin_transform(transform, builtin_name): ) -def _generic_inference(node, context, node_type, transform): +def _container_generic_inference(node, context, node_type, transform): args = node.args if not args: return node_type() @@ -169,14 +169,17 @@ def _generic_inference(node, context, node_type, transform): return transformed -def _generic_transform(arg, klass, iterables, build_elts): +def _container_generic_transform(arg, klass, iterables, build_elts): if isinstance(arg, klass): return arg elif isinstance(arg, iterables): - if not all(isinstance(elt, nodes.Const) for elt in arg.elts): - raise UseInferenceDefault() - elts = [elt.value for elt in arg.elts] + if all(isinstance(elt, nodes.Const) for elt in arg.elts): + elts = [elt.value for elt in arg.elts] + else: + # TODO: Does not handle deduplication for sets. + elts = filter(None, map(helpers.safe_infer, arg.elts)) elif isinstance(arg, nodes.Dict): + # Dicts need to have consts as strings already. if not all(isinstance(elt[0], nodes.Const) for elt in arg.items): raise UseInferenceDefault() elts = [item[0].value for item in arg.items] @@ -186,20 +189,25 @@ def _generic_transform(arg, klass, iterables, build_elts): elts = arg.value else: return - return klass.from_constants(elts=build_elts(elts)) + return klass.from_elements(elts=build_elts(elts)) -def _infer_builtin(node, context, klass=None, iterables=None, build_elts=None): +def _infer_builtin_container( + node, context, klass=None, iterables=None, build_elts=None +): transform_func = partial( - _generic_transform, klass=klass, iterables=iterables, build_elts=build_elts + _container_generic_transform, + klass=klass, + iterables=iterables, + build_elts=build_elts, ) - return _generic_inference(node, context, klass, transform_func) + return _container_generic_inference(node, context, klass, transform_func) # pylint: disable=invalid-name infer_tuple = partial( - _infer_builtin, + _infer_builtin_container, klass=nodes.Tuple, iterables=( nodes.List, @@ -213,7 +221,7 @@ infer_tuple = partial( ) infer_list = partial( - _infer_builtin, + _infer_builtin_container, klass=nodes.List, iterables=( nodes.Tuple, @@ -227,14 +235,14 @@ infer_list = partial( ) infer_set = partial( - _infer_builtin, + _infer_builtin_container, klass=nodes.Set, iterables=(nodes.List, nodes.Tuple, objects.FrozenSet, objects.DictKeys), build_elts=set, ) infer_frozenset = partial( - _infer_builtin, + _infer_builtin_container, klass=objects.FrozenSet, iterables=(nodes.List, nodes.Tuple, nodes.Set, objects.FrozenSet, objects.DictKeys), build_elts=frozenset, diff --git a/astroid/node_classes.py b/astroid/node_classes.py index 204d8023..449470d0 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -48,6 +48,10 @@ BUILTINS = builtins_mod.__name__ MANAGER = manager.AstroidManager() +def _is_const(value): + return isinstance(value, tuple(CONST_CLS)) + + @decorators.raise_if_nothing_inferred def unpack_infer(stmt, context=None): """recursively generate nodes inferred by the given statement. @@ -1017,7 +1021,7 @@ class _BaseContainer( self.elts = elts @classmethod - def from_constants(cls, elts=None): + def from_elements(cls, elts=None): """Create a node of this type from the given list of elements. :param elts: The list of elements that the node should contain. @@ -1030,7 +1034,7 @@ class _BaseContainer( if elts is None: node.elts = [] else: - node.elts = [const_factory(e) for e in elts] + node.elts = [const_factory(e) if _is_const(e) else e for e in elts] return node def itered(self): @@ -2728,7 +2732,7 @@ class Dict(NodeNG, bases.Instance): self.items = items @classmethod - def from_constants(cls, items=None): + def from_elements(cls, items=None): """Create a :class:`Dict` of constants from a live dictionary. :param items: The items to store in the node. @@ -2742,7 +2746,10 @@ class Dict(NodeNG, bases.Instance): node.items = [] else: node.items = [ - (const_factory(k), const_factory(v)) for k, v in items.items() + (const_factory(k), const_factory(v) if _is_const(v) else v) + for k, v in items.items() + # The keys need to be constants + if _is_const(k) ] return node diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py index c8e8885c..2cda59da 100644 --- a/astroid/tests/unittest_inference.py +++ b/astroid/tests/unittest_inference.py @@ -1040,7 +1040,7 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase): def test_binary_op_float_div(self): ast = builder.string_build("a = 1 / 2.", __name__, __file__) - self._test_const_inferred(ast["a"], 1 / 2.) + self._test_const_inferred(ast["a"], 1 / 2.0) def test_binary_op_str_mul(self): ast = builder.string_build('a = "*" * 40', __name__, __file__) @@ -5149,5 +5149,20 @@ def test_exception_lookup_name_bound_in_except_handler(): assert inferred_exc.value == 2 +def test_builtin_inference_list_of_exceptions(): + node = extract_node( + """ + tuple([ValueError, TypeError]) + """ + ) + inferred = next(node.infer()) + assert isinstance(inferred, nodes.Tuple) + assert len(inferred.elts) == 2 + assert isinstance(inferred.elts[0], nodes.ClassDef) + assert inferred.elts[0].name == "ValueError" + assert isinstance(inferred.elts[1], nodes.ClassDef) + assert inferred.elts[1].name == "TypeError" + + if __name__ == "__main__": unittest.main() |