diff options
-rw-r--r-- | astroid/brain/brain_namedtuple_enum.py | 35 | ||||
-rw-r--r-- | astroid/tests/unittest_brain.py | 12 |
2 files changed, 34 insertions, 13 deletions
diff --git a/astroid/brain/brain_namedtuple_enum.py b/astroid/brain/brain_namedtuple_enum.py index 56cfb247..519a75f4 100644 --- a/astroid/brain/brain_namedtuple_enum.py +++ b/astroid/brain/brain_namedtuple_enum.py @@ -39,22 +39,31 @@ def _infer_first(node, context): def _find_func_form_arguments(node, context): + + def _extract_namedtuple_arg_or_keyword(position, key_name=None): + + if len(args) > position: + return _infer_first(args[position], context) + if key_name and key_name in found_keywords: + return _infer_first(found_keywords[key_name], context) + args = node.args keywords = node.keywords + found_keywords = { + keyword.arg: keyword.value for keyword in keywords + } if keywords else {} + + name = _extract_namedtuple_arg_or_keyword( + position=0, + key_name='typename' + ) + names = _extract_namedtuple_arg_or_keyword( + position=1, + key_name='field_names' + ) + if name and names: + return name.value, names - if args and len(args) == 2: - name = _infer_first(node.args[0], context).value - names = _infer_first(node.args[1], context) - - return name, names - elif keywords: - found_keywords = { - keyword.arg: keyword.value for keyword in keywords - } - if {'field_names', 'typename'}.issubset(found_keywords.keys()): - name = _infer_first(found_keywords['typename'], context).value - names = _infer_first(found_keywords['field_names'], context) - return name, names raise UseInferenceDefault() diff --git a/astroid/tests/unittest_brain.py b/astroid/tests/unittest_brain.py index 72fc139a..0fb5e93f 100644 --- a/astroid/tests/unittest_brain.py +++ b/astroid/tests/unittest_brain.py @@ -233,6 +233,18 @@ class NamedTupleTest(unittest.TestCase): self.assertIn('b', inferred.locals) self.assertIn('c', inferred.locals) + def test_namedtuple_func_form_args_and_kwargs(self): + node = builder.extract_node(""" + from collections import namedtuple + Tuple = namedtuple("Tuple", field_names="a b c", rename=UNINFERABLE) + Tuple #@ + """) + inferred = next(node.infer()) + self.assertEqual(inferred.name, 'Tuple') + self.assertIn('a', inferred.locals) + self.assertIn('b', inferred.locals) + self.assertIn('c', inferred.locals) + class DefaultDictTest(unittest.TestCase): |