summaryrefslogtreecommitdiff
path: root/astroid/brain/brain_dataclasses.py
diff options
context:
space:
mode:
Diffstat (limited to 'astroid/brain/brain_dataclasses.py')
-rw-r--r--astroid/brain/brain_dataclasses.py69
1 files changed, 42 insertions, 27 deletions
diff --git a/astroid/brain/brain_dataclasses.py b/astroid/brain/brain_dataclasses.py
index 264957e0..5d3c3461 100644
--- a/astroid/brain/brain_dataclasses.py
+++ b/astroid/brain/brain_dataclasses.py
@@ -181,9 +181,12 @@ def _find_arguments_from_base_classes(
node: nodes.ClassDef, skippable_names: set[str]
) -> tuple[str, str]:
"""Iterate through all bases and add them to the list of arguments to add to the init."""
- prev_pos_only = ""
- prev_kw_only = ""
- for base in node.mro():
+ pos_only_store: dict[str, tuple[str | None, str | None]] = {}
+ kw_only_store: dict[str, tuple[str | None, str | None]] = {}
+ # See TODO down below
+ # all_have_defaults = True
+
+ for base in reversed(node.mro()):
if not base.is_dataclass:
continue
try:
@@ -191,29 +194,41 @@ def _find_arguments_from_base_classes(
except KeyError:
continue
- # Skip the self argument and check for duplicate arguments
- arguments = base_init.args.format_args(skippable_names=skippable_names)
- try:
- new_prev_pos_only, new_prev_kw_only = arguments.split("*, ")
- except ValueError:
- new_prev_pos_only, new_prev_kw_only = arguments, ""
-
- if new_prev_pos_only:
- # The split on '*, ' can crete a pos_only string that consists only of a comma
- if new_prev_pos_only == ", ":
- new_prev_pos_only = ""
- elif not new_prev_pos_only.endswith(", "):
- new_prev_pos_only += ", "
-
- # Dataclasses put last seen arguments at the front of the init
- prev_pos_only = new_prev_pos_only + prev_pos_only
- prev_kw_only = new_prev_kw_only + prev_kw_only
-
- # Add arguments to skippable arguments
- skippable_names.update(arg.name for arg in base_init.args.args)
- skippable_names.update(arg.name for arg in base_init.args.kwonlyargs)
-
- return prev_pos_only, prev_kw_only
+ pos_only, kw_only = base_init.args._get_arguments_data()
+ for posarg, data in pos_only.items():
+ if posarg in skippable_names:
+ continue
+ # if data[1] is None:
+ # if all_have_defaults and pos_only_store:
+ # # TODO: This should return an Uninferable as this would raise
+ # # a TypeError at runtime. However, transforms can't return
+ # # Uninferables currently.
+ # pass
+ # all_have_defaults = False
+ pos_only_store[posarg] = data
+
+ for kwarg, data in kw_only.items():
+ if kwarg in skippable_names:
+ continue
+ kw_only_store[kwarg] = data
+
+ pos_only, kw_only = "", ""
+ for pos_arg, data in pos_only_store.items():
+ pos_only += pos_arg
+ if data[0]:
+ pos_only += ": " + data[0]
+ if data[1]:
+ pos_only += " = " + data[1]
+ pos_only += ", "
+ for kw_arg, data in kw_only_store.items():
+ kw_only += kw_arg
+ if data[0]:
+ kw_only += ": " + data[0]
+ if data[1]:
+ kw_only += " = " + data[1]
+ kw_only += ", "
+
+ return pos_only, kw_only
def _generate_dataclass_init(
@@ -282,7 +297,7 @@ def _generate_dataclass_init(
params_string += ", "
if prev_kw_only:
- params_string += "*, " + prev_kw_only + ", "
+ params_string += "*, " + prev_kw_only
if kw_only_decorated:
params_string += ", ".join(params) + ", "
elif kw_only_decorated: