summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ChangeLog9
-rw-r--r--astroid/brain/brain_builtin_inference.py34
-rw-r--r--astroid/node_classes.py15
-rw-r--r--astroid/tests/unittest_inference.py17
4 files changed, 57 insertions, 18 deletions
diff --git a/ChangeLog b/ChangeLog
index c1e75d74..02156112 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -6,6 +6,15 @@ What's New in astroid 2.3.0?
============================
Release Date: TBA
+* Improved builtin inference for ``tuple``, ``set``, ``frozenset``, ``list`` and ``dict``
+
+ We were properly inferring these callables *only* if they had consts as
+ values, but that is not the case most of the time. Instead we try to infer
+ the values that their arguments can be and use them instead of assuming
+ Const nodes all the time.
+
+ Close PyCQA/pylint#2841
+
* The last except handler wins when inferring variables bound in an except handler.
Close PyCQA/pylint#2777
diff --git a/astroid/brain/brain_builtin_inference.py b/astroid/brain/brain_builtin_inference.py
index 53ca2f98..c9758c79 100644
--- a/astroid/brain/brain_builtin_inference.py
+++ b/astroid/brain/brain_builtin_inference.py
@@ -147,7 +147,7 @@ def register_builtin_transform(transform, builtin_name):
)
-def _generic_inference(node, context, node_type, transform):
+def _container_generic_inference(node, context, node_type, transform):
args = node.args
if not args:
return node_type()
@@ -169,14 +169,17 @@ def _generic_inference(node, context, node_type, transform):
return transformed
-def _generic_transform(arg, klass, iterables, build_elts):
+def _container_generic_transform(arg, klass, iterables, build_elts):
if isinstance(arg, klass):
return arg
elif isinstance(arg, iterables):
- if not all(isinstance(elt, nodes.Const) for elt in arg.elts):
- raise UseInferenceDefault()
- elts = [elt.value for elt in arg.elts]
+ if all(isinstance(elt, nodes.Const) for elt in arg.elts):
+ elts = [elt.value for elt in arg.elts]
+ else:
+ # TODO: Does not handle deduplication for sets.
+ elts = filter(None, map(helpers.safe_infer, arg.elts))
elif isinstance(arg, nodes.Dict):
+ # Dicts need to have consts as strings already.
if not all(isinstance(elt[0], nodes.Const) for elt in arg.items):
raise UseInferenceDefault()
elts = [item[0].value for item in arg.items]
@@ -186,20 +189,25 @@ def _generic_transform(arg, klass, iterables, build_elts):
elts = arg.value
else:
return
- return klass.from_constants(elts=build_elts(elts))
+ return klass.from_elements(elts=build_elts(elts))
-def _infer_builtin(node, context, klass=None, iterables=None, build_elts=None):
+def _infer_builtin_container(
+ node, context, klass=None, iterables=None, build_elts=None
+):
transform_func = partial(
- _generic_transform, klass=klass, iterables=iterables, build_elts=build_elts
+ _container_generic_transform,
+ klass=klass,
+ iterables=iterables,
+ build_elts=build_elts,
)
- return _generic_inference(node, context, klass, transform_func)
+ return _container_generic_inference(node, context, klass, transform_func)
# pylint: disable=invalid-name
infer_tuple = partial(
- _infer_builtin,
+ _infer_builtin_container,
klass=nodes.Tuple,
iterables=(
nodes.List,
@@ -213,7 +221,7 @@ infer_tuple = partial(
)
infer_list = partial(
- _infer_builtin,
+ _infer_builtin_container,
klass=nodes.List,
iterables=(
nodes.Tuple,
@@ -227,14 +235,14 @@ infer_list = partial(
)
infer_set = partial(
- _infer_builtin,
+ _infer_builtin_container,
klass=nodes.Set,
iterables=(nodes.List, nodes.Tuple, objects.FrozenSet, objects.DictKeys),
build_elts=set,
)
infer_frozenset = partial(
- _infer_builtin,
+ _infer_builtin_container,
klass=objects.FrozenSet,
iterables=(nodes.List, nodes.Tuple, nodes.Set, objects.FrozenSet, objects.DictKeys),
build_elts=frozenset,
diff --git a/astroid/node_classes.py b/astroid/node_classes.py
index 204d8023..449470d0 100644
--- a/astroid/node_classes.py
+++ b/astroid/node_classes.py
@@ -48,6 +48,10 @@ BUILTINS = builtins_mod.__name__
MANAGER = manager.AstroidManager()
+def _is_const(value):
+ return isinstance(value, tuple(CONST_CLS))
+
+
@decorators.raise_if_nothing_inferred
def unpack_infer(stmt, context=None):
"""recursively generate nodes inferred by the given statement.
@@ -1017,7 +1021,7 @@ class _BaseContainer(
self.elts = elts
@classmethod
- def from_constants(cls, elts=None):
+ def from_elements(cls, elts=None):
"""Create a node of this type from the given list of elements.
:param elts: The list of elements that the node should contain.
@@ -1030,7 +1034,7 @@ class _BaseContainer(
if elts is None:
node.elts = []
else:
- node.elts = [const_factory(e) for e in elts]
+ node.elts = [const_factory(e) if _is_const(e) else e for e in elts]
return node
def itered(self):
@@ -2728,7 +2732,7 @@ class Dict(NodeNG, bases.Instance):
self.items = items
@classmethod
- def from_constants(cls, items=None):
+ def from_elements(cls, items=None):
"""Create a :class:`Dict` of constants from a live dictionary.
:param items: The items to store in the node.
@@ -2742,7 +2746,10 @@ class Dict(NodeNG, bases.Instance):
node.items = []
else:
node.items = [
- (const_factory(k), const_factory(v)) for k, v in items.items()
+ (const_factory(k), const_factory(v) if _is_const(v) else v)
+ for k, v in items.items()
+ # The keys need to be constants
+ if _is_const(k)
]
return node
diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py
index c8e8885c..2cda59da 100644
--- a/astroid/tests/unittest_inference.py
+++ b/astroid/tests/unittest_inference.py
@@ -1040,7 +1040,7 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
def test_binary_op_float_div(self):
ast = builder.string_build("a = 1 / 2.", __name__, __file__)
- self._test_const_inferred(ast["a"], 1 / 2.)
+ self._test_const_inferred(ast["a"], 1 / 2.0)
def test_binary_op_str_mul(self):
ast = builder.string_build('a = "*" * 40', __name__, __file__)
@@ -5149,5 +5149,20 @@ def test_exception_lookup_name_bound_in_except_handler():
assert inferred_exc.value == 2
+def test_builtin_inference_list_of_exceptions():
+ node = extract_node(
+ """
+ tuple([ValueError, TypeError])
+ """
+ )
+ inferred = next(node.infer())
+ assert isinstance(inferred, nodes.Tuple)
+ assert len(inferred.elts) == 2
+ assert isinstance(inferred.elts[0], nodes.ClassDef)
+ assert inferred.elts[0].name == "ValueError"
+ assert isinstance(inferred.elts[1], nodes.ClassDef)
+ assert inferred.elts[1].name == "TypeError"
+
+
if __name__ == "__main__":
unittest.main()