summaryrefslogtreecommitdiff
path: root/astroid
diff options
context:
space:
mode:
Diffstat (limited to 'astroid')
-rw-r--r--astroid/brain/brain_dataclasses.py12
-rw-r--r--astroid/brain/brain_namedtuple_enum.py14
-rw-r--r--astroid/manager.py2
3 files changed, 18 insertions, 10 deletions
diff --git a/astroid/brain/brain_dataclasses.py b/astroid/brain/brain_dataclasses.py
index e010a514..0bd394e3 100644
--- a/astroid/brain/brain_dataclasses.py
+++ b/astroid/brain/brain_dataclasses.py
@@ -12,10 +12,16 @@ from typing import Generator, List, Optional, Tuple
from astroid import context, inference_tip
from astroid.builder import parse
from astroid.const import PY37_PLUS, PY39_PLUS
-from astroid.exceptions import AstroidSyntaxError, InferenceError, MroError
+from astroid.exceptions import (
+ AstroidSyntaxError,
+ InferenceError,
+ MroError,
+ UseInferenceDefault,
+)
from astroid.manager import AstroidManager
from astroid.nodes.node_classes import (
AnnAssign,
+ Assign,
AssignName,
Attribute,
Call,
@@ -231,9 +237,11 @@ def infer_dataclass_attribute(
def infer_dataclass_field_call(
- node: AssignName, ctx: context.InferenceContext = None
+ node: Call, ctx: Optional[context.InferenceContext] = None
) -> Generator:
"""Inference tip for dataclass field calls."""
+ if not isinstance(node.parent, (AnnAssign, Assign)):
+ raise UseInferenceDefault
field_call = node.parent.value
default_type, default = _get_field_default(field_call)
if not default_type:
diff --git a/astroid/brain/brain_namedtuple_enum.py b/astroid/brain/brain_namedtuple_enum.py
index f79e44cc..4ed105ca 100644
--- a/astroid/brain/brain_namedtuple_enum.py
+++ b/astroid/brain/brain_namedtuple_enum.py
@@ -484,13 +484,13 @@ def infer_typing_namedtuple_class(class_node, context=None):
for method in class_node.mymethods():
generated_class_node.locals[method.name] = [method]
- for assign in class_node.body:
- if not isinstance(assign, nodes.Assign):
- continue
-
- for target in assign.targets:
- attr = target.name
- generated_class_node.locals[attr] = class_node.locals[attr]
+ for body_node in class_node.body:
+ if isinstance(body_node, nodes.Assign):
+ for target in body_node.targets:
+ attr = target.name
+ generated_class_node.locals[attr] = class_node.locals[attr]
+ elif isinstance(body_node, nodes.ClassDef):
+ generated_class_node.locals[body_node.name] = [body_node]
return iter((generated_class_node,))
diff --git a/astroid/manager.py b/astroid/manager.py
index 5575151e..89ef2ac6 100644
--- a/astroid/manager.py
+++ b/astroid/manager.py
@@ -46,7 +46,7 @@ from astroid.modutils import (
)
from astroid.transforms import TransformVisitor
-ZIP_IMPORT_EXTS = (".zip", ".egg", ".whl")
+ZIP_IMPORT_EXTS = (".zip", ".egg", ".whl", ".pyz", ".pyzw")
def safe_repr(obj):