diff options
Diffstat (limited to 'astroid/brain/brain_dataclasses.py')
-rw-r--r-- | astroid/brain/brain_dataclasses.py | 69 |
1 files changed, 42 insertions, 27 deletions
diff --git a/astroid/brain/brain_dataclasses.py b/astroid/brain/brain_dataclasses.py index 264957e0..5d3c3461 100644 --- a/astroid/brain/brain_dataclasses.py +++ b/astroid/brain/brain_dataclasses.py @@ -181,9 +181,12 @@ def _find_arguments_from_base_classes( node: nodes.ClassDef, skippable_names: set[str] ) -> tuple[str, str]: """Iterate through all bases and add them to the list of arguments to add to the init.""" - prev_pos_only = "" - prev_kw_only = "" - for base in node.mro(): + pos_only_store: dict[str, tuple[str | None, str | None]] = {} + kw_only_store: dict[str, tuple[str | None, str | None]] = {} + # See TODO down below + # all_have_defaults = True + + for base in reversed(node.mro()): if not base.is_dataclass: continue try: @@ -191,29 +194,41 @@ def _find_arguments_from_base_classes( except KeyError: continue - # Skip the self argument and check for duplicate arguments - arguments = base_init.args.format_args(skippable_names=skippable_names) - try: - new_prev_pos_only, new_prev_kw_only = arguments.split("*, ") - except ValueError: - new_prev_pos_only, new_prev_kw_only = arguments, "" - - if new_prev_pos_only: - # The split on '*, ' can crete a pos_only string that consists only of a comma - if new_prev_pos_only == ", ": - new_prev_pos_only = "" - elif not new_prev_pos_only.endswith(", "): - new_prev_pos_only += ", " - - # Dataclasses put last seen arguments at the front of the init - prev_pos_only = new_prev_pos_only + prev_pos_only - prev_kw_only = new_prev_kw_only + prev_kw_only - - # Add arguments to skippable arguments - skippable_names.update(arg.name for arg in base_init.args.args) - skippable_names.update(arg.name for arg in base_init.args.kwonlyargs) - - return prev_pos_only, prev_kw_only + pos_only, kw_only = base_init.args._get_arguments_data() + for posarg, data in pos_only.items(): + if posarg in skippable_names: + continue + # if data[1] is None: + # if all_have_defaults and pos_only_store: + # # TODO: This should return an Uninferable as this would raise + # # a TypeError at runtime. However, transforms can't return + # # Uninferables currently. + # pass + # all_have_defaults = False + pos_only_store[posarg] = data + + for kwarg, data in kw_only.items(): + if kwarg in skippable_names: + continue + kw_only_store[kwarg] = data + + pos_only, kw_only = "", "" + for pos_arg, data in pos_only_store.items(): + pos_only += pos_arg + if data[0]: + pos_only += ": " + data[0] + if data[1]: + pos_only += " = " + data[1] + pos_only += ", " + for kw_arg, data in kw_only_store.items(): + kw_only += kw_arg + if data[0]: + kw_only += ": " + data[0] + if data[1]: + kw_only += " = " + data[1] + kw_only += ", " + + return pos_only, kw_only def _generate_dataclass_init( @@ -282,7 +297,7 @@ def _generate_dataclass_init( params_string += ", " if prev_kw_only: - params_string += "*, " + prev_kw_only + ", " + params_string += "*, " + prev_kw_only if kw_only_decorated: params_string += ", ".join(params) + ", " elif kw_only_decorated: |