diff options
author | Daniƫl van Noord <13665637+DanielNoord@users.noreply.github.com> | 2021-12-29 21:23:18 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-29 21:23:18 +0100 |
commit | 0d1211558670cfefd95b39984b8d5f7f34837f32 (patch) | |
tree | a2e34269397547f2bdc5a9463538c87d7a95791d /astroid/brain/brain_dataclasses.py | |
parent | aa4f5bed58dc6521a6c2c0927ca0e0da48fd5ea5 (diff) | |
download | astroid-git-0d1211558670cfefd95b39984b8d5f7f34837f32.tar.gz |
Add typing to ``brain_dataclasses`` (#1292)
Co-authored-by: Pierre Sassoulas <pierre.sassoulas@gmail.com>
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
Diffstat (limited to 'astroid/brain/brain_dataclasses.py')
-rw-r--r-- | astroid/brain/brain_dataclasses.py | 72 |
1 files changed, 44 insertions, 28 deletions
diff --git a/astroid/brain/brain_dataclasses.py b/astroid/brain/brain_dataclasses.py index a05c46f9..bfdbbe09 100644 --- a/astroid/brain/brain_dataclasses.py +++ b/astroid/brain/brain_dataclasses.py @@ -10,7 +10,8 @@ dataclasses. References: - https://lovasoa.github.io/marshmallow_dataclass/ """ -from typing import FrozenSet, Generator, List, Optional, Tuple +import sys +from typing import FrozenSet, Generator, List, Optional, Tuple, Union from astroid import context, inference_tip from astroid.builder import parse @@ -36,6 +37,15 @@ from astroid.nodes.node_classes import ( from astroid.nodes.scoped_nodes import ClassDef, FunctionDef from astroid.util import Uninferable +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + +_FieldDefaultReturn = Union[ + None, Tuple[Literal["default"], NodeNG], Tuple[Literal["default_factory"], Call] +] + DATACLASSES_DECORATORS = frozenset(("dataclass",)) FIELD_NAME = "field" DATACLASS_MODULES = frozenset( @@ -115,7 +125,7 @@ def _get_dataclass_attributes(node: ClassDef, init: bool = False) -> Generator: ): continue - if _is_class_var(assign_node.annotation): + if _is_class_var(assign_node.annotation): # type: ignore[arg-type] # annotation is never None continue if init: @@ -124,12 +134,13 @@ def _get_dataclass_attributes(node: ClassDef, init: bool = False) -> Generator: isinstance(value, Call) and _looks_like_dataclass_field_call(value, check_scope=False) and any( - keyword.arg == "init" and not keyword.value.bool_value() + keyword.arg == "init" + and not keyword.value.bool_value() # type: ignore[union-attr] # value is never None for keyword in value.keywords ) ): continue - elif _is_init_var(assign_node.annotation): + elif _is_init_var(assign_node.annotation): # type: ignore[arg-type] # annotation is never None continue yield assign_node @@ -159,7 +170,8 @@ def _check_generate_dataclass_init(node: ClassDef) -> bool: # Check for keyword arguments of the form init=False return all( - keyword.arg != "init" or keyword.value.bool_value() + keyword.arg != "init" + and keyword.value.bool_value() # type: ignore[union-attr] # value is never None for keyword in found.keywords ) @@ -174,7 +186,7 @@ def _generate_dataclass_init(assigns: List[AnnAssign]) -> str: name, annotation, value = assign.target.name, assign.annotation, assign.value target_names.append(name) - if _is_init_var(annotation): + if _is_init_var(annotation): # type: ignore[arg-type] # annotation is never None init_var = True if isinstance(annotation, Subscript): annotation = annotation.slice @@ -196,16 +208,16 @@ def _generate_dataclass_init(assigns: List[AnnAssign]) -> str: value, check_scope=False ): result = _get_field_default(value) - - default_type, default_node = result - if default_type == "default": - param_str += f" = {default_node.as_string()}" - elif default_type == "default_factory": - param_str += f" = {DEFAULT_FACTORY}" - assignment_str = ( - f"self.{name} = {default_node.as_string()} " - f"if {name} is {DEFAULT_FACTORY} else {name}" - ) + if result: + default_type, default_node = result + if default_type == "default": + param_str += f" = {default_node.as_string()}" + elif default_type == "default_factory": + param_str += f" = {DEFAULT_FACTORY}" + assignment_str = ( + f"self.{name} = {default_node.as_string()} " + f"if {name} is {DEFAULT_FACTORY} else {name}" + ) else: param_str += f" = {value.as_string()}" @@ -219,7 +231,7 @@ def _generate_dataclass_init(assigns: List[AnnAssign]) -> str: def infer_dataclass_attribute( - node: Unknown, ctx: context.InferenceContext = None + node: Unknown, ctx: Optional[context.InferenceContext] = None ) -> Generator: """Inference tip for an Unknown node that was dynamically generated to represent a dataclass attribute. @@ -247,16 +259,17 @@ def infer_dataclass_field_call( """Inference tip for dataclass field calls.""" if not isinstance(node.parent, (AnnAssign, Assign)): raise UseInferenceDefault - field_call = node.parent.value - default_type, default = _get_field_default(field_call) - if not default_type: + result = _get_field_default(node) + if not result: yield Uninferable - elif default_type == "default": - yield from default.infer(context=ctx) else: - new_call = parse(default.as_string()).body[0].value - new_call.parent = field_call.parent - yield from new_call.infer(context=ctx) + default_type, default = result + if default_type == "default": + yield from default.infer(context=ctx) + else: + new_call = parse(default.as_string()).body[0].value + new_call.parent = node.parent + yield from new_call.infer(context=ctx) def _looks_like_dataclass_decorator( @@ -294,6 +307,9 @@ def _looks_like_dataclass_attribute(node: Unknown) -> bool: statement. """ parent = node.parent + if not parent: + return False + scope = parent.scope() return ( isinstance(parent, AnnAssign) @@ -330,7 +346,7 @@ def _looks_like_dataclass_field_call(node: Call, check_scope: bool = True) -> bo return inferred.name == FIELD_NAME and inferred.root().name in DATACLASS_MODULES -def _get_field_default(field_call: Call) -> Tuple[str, Optional[NodeNG]]: +def _get_field_default(field_call: Call) -> _FieldDefaultReturn: """Return a the default value of a field call, and the corresponding keyword argument name. field(default=...) results in the ... node @@ -358,7 +374,7 @@ def _get_field_default(field_call: Call) -> Tuple[str, Optional[NodeNG]]: new_call.postinit(func=default_factory) return "default_factory", new_call - return "", None + return None def _is_class_var(node: NodeNG) -> bool: @@ -404,7 +420,7 @@ _INFERABLE_TYPING_TYPES = frozenset( def _infer_instance_from_annotation( - node: NodeNG, ctx: context.InferenceContext = None + node: NodeNG, ctx: Optional[context.InferenceContext] = None ) -> Generator: """Infer an instance corresponding to the type annotation represented by node. |