summaryrefslogtreecommitdiff
path: root/pylint/testutils/pyreverse.py
blob: 24fddad770e3b56347c99ee06e29ddd4ec872541 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
# For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt

from __future__ import annotations

import argparse
import configparser
import shlex
import sys
from pathlib import Path
from typing import NamedTuple

from pylint.pyreverse.main import DEFAULT_COLOR_PALETTE

if sys.version_info >= (3, 8):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict


# This class could and should be replaced with a simple dataclass when support for Python < 3.7 is dropped.
# A NamedTuple is not possible as some tests need to modify attributes during the test.
class PyreverseConfig(
    argparse.Namespace
):  # pylint: disable=too-many-instance-attributes, too-many-arguments
    """Namespace holding Pyreverse's configuration options.

    Every constructor argument is stored as an attribute of the same name, and
    the default values mirror the defaults of the options' parser. A ``None``
    or empty ``classes`` argument is replaced by a fresh empty list.
    """

    def __init__(
        self,
        mode: str = "PUB_ONLY",
        classes: list[str] | None = None,
        show_ancestors: int | None = None,
        all_ancestors: bool | None = None,
        show_associated: int | None = None,
        all_associated: bool | None = None,
        no_standalone: bool = False,
        show_builtin: bool = False,
        show_stdlib: bool = False,
        module_names: bool | None = None,
        only_classnames: bool = False,
        output_format: str = "dot",
        colorized: bool = False,
        max_color_depth: int = 2,
        color_palette: tuple[str, ...] = DEFAULT_COLOR_PALETTE,
        ignore_list: tuple[str, ...] = tuple(),
        project: str = "",
        output_directory: str = "",
    ) -> None:
        # argparse.Namespace.__init__ calls setattr once per keyword, in the
        # order the keywords are given, so the attributes end up in exactly
        # the same order as the parameters above.
        super().__init__(
            mode=mode,
            classes=classes if classes else [],
            show_ancestors=show_ancestors,
            all_ancestors=all_ancestors,
            show_associated=show_associated,
            all_associated=all_associated,
            no_standalone=no_standalone,
            show_builtin=show_builtin,
            show_stdlib=show_stdlib,
            module_names=module_names,
            only_classnames=only_classnames,
            output_format=output_format,
            colorized=colorized,
            max_color_depth=max_color_depth,
            color_palette=color_palette,
            ignore_list=ignore_list,
            project=project,
            output_directory=output_directory,
        )


class TestFileOptions(TypedDict):
    """Options attached to a single functional pyreverse test case.

    The values come either from the ``[testoptions]`` section of the test's
    ``.rc`` file (parsed by ``_read_config``) or from the defaults used by
    ``get_functional_test_files`` when no ``.rc`` file exists.
    """

    # Source-root paths, kept as plain strings (default: empty list).
    source_roots: list[str]
    # Output format names for this test case (default: ``["mmd"]``).
    output_formats: list[str]
    # Extra command-line arguments; when read from an ``.rc`` file they have
    # already been split with ``shlex.split`` (default: empty list).
    command_line_args: list[str]


class FunctionalPyreverseTestfile(NamedTuple):
    """Named tuple pairing a functional test source file with its options."""

    # The ``*.py`` file this test case is built from.
    source: Path
    # The test case's options; ``get_functional_test_files`` reads them from the
    # ``.rc`` file next to ``source`` or falls back to defaults.
    options: TestFileOptions


def get_functional_test_files(
    root_directory: Path,
) -> list[FunctionalPyreverseTestfile]:
    """Get all functional test files from the given directory.

    Every ``*.py`` file in ``root_directory`` or any of its subdirectories
    becomes one test case. Files whose name starts with an underscore (for
    example ``__init__.py``) are skipped.

    A test case's options are read from the ``.rc`` file that sits next to the
    source file with the same stem. When there is no such file, the defaults
    are used: no source roots, ``mmd`` output only, and no extra command-line
    arguments.

    :param root_directory: Directory to search recursively.
    :returns: One entry per test source file, sorted by source path.
    """
    test_files: list[FunctionalPyreverseTestfile] = []
    # rglob yields entries in directory-listing order, which the file system
    # does not guarantee; sorting keeps the collected test order deterministic.
    for path in sorted(root_directory.rglob("*.py")):
        if path.stem.startswith("_"):
            continue
        config_file = path.with_suffix(".rc")
        options: TestFileOptions
        if config_file.exists():
            options = _read_config(config_file)
        else:
            # A fresh dict (and fresh lists) per file, so test cases never
            # share mutable option values.
            options = {
                "source_roots": [],
                "output_formats": ["mmd"],
                "command_line_args": [],
            }
        test_files.append(
            FunctionalPyreverseTestfile(source=path, options=options)
        )
    return test_files


def _split_comma_separated(value: str | None) -> list[str]:
    """Split a comma-separated option value into its non-empty, stripped items.

    ``"a, b"`` and a value continued over several lines (``"a,\\nb"``) both
    yield ``["a", "b"]``. Empty items, such as those produced by a trailing
    comma, are dropped. ``None`` or an empty string yields ``[]``.
    """
    if not value:
        return []
    return [item.strip() for item in value.split(",") if item.strip()]


def _read_config(config_file: Path) -> TestFileOptions:
    """Read the ``[testoptions]`` section of a functional test's ``.rc`` file.

    All keys are optional, and a missing section behaves like an empty one:

    * ``source_roots``: comma-separated list; defaults to ``[]``.
    * ``output_formats``: comma-separated list; defaults to ``["mmd"]``.
    * ``command_line_args``: split with shell-like quoting rules
      (``shlex.split``); defaults to ``[]``.

    For the comma-separated keys, whitespace around each item is stripped and
    empty items are dropped.

    :param config_file: Path of the ``.rc`` file to read.
    :returns: The parsed test options.
    """
    config = configparser.ConfigParser()
    # Decode explicitly so the result does not depend on the locale encoding.
    config.read(str(config_file), encoding="utf-8")
    source_roots = config.get("testoptions", "source_roots", fallback=None)
    output_formats = config.get("testoptions", "output_formats", fallback="mmd")
    command_line_args = config.get(
        "testoptions", "command_line_args", fallback=""
    )
    return {
        "source_roots": _split_comma_separated(source_roots),
        "output_formats": _split_comma_separated(output_formats),
        "command_line_args": shlex.split(command_line_args),
    }