summaryrefslogtreecommitdiff
path: root/tests/pyreverse/test_utils.py
blob: d64bf4fa70e5eba9e9a7d2246bded08074a3a7ab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
# For details: https://github.com/PyCQA/pylint/blob/main/LICENSE
# Copyright (c) https://github.com/PyCQA/pylint/blob/main/CONTRIBUTORS.txt

"""Tests for pylint.pyreverse.utils."""

from __future__ import annotations

from typing import Any
from unittest.mock import patch

import astroid
import pytest
from astroid import nodes

from pylint.pyreverse.utils import (
    get_annotation,
    get_annotation_label,
    get_visibility,
    infer_node,
)


@pytest.mark.parametrize(
    ("names", "expected"),
    [
        (["__reduce_ex__", "__setattr__"], "special"),
        (["__g_", "____dsf", "__23_9"], "private"),
        (["simple"], "public"),
        (
            ["_", "__", "___", "____", "_____", "___e__", "_nextsimple", "_filter_it_"],
            "protected",
        ),
    ],
)
def test_get_visibility(names: list[str], expected: str) -> None:
    """Every name in a group must be classified with that group's visibility.

    Names are checked in order; the first mismatch fails the test and the
    message reports the offending name.
    """
    for candidate in names:
        visibility = get_visibility(candidate)
        assert (
            visibility == expected
        ), f"got {visibility} instead of {expected} for value {candidate}"


@pytest.mark.parametrize(
    "assign, label",
    [
        ("a: str = None", "Optional[str]"),
        ("a: str = 'mystr'", "str"),
        ("a: Optional[str] = 'str'", "Optional[str]"),
        ("a: Optional[str] = None", "Optional[str]"),
    ],
)
def test_get_annotation_annassign(assign: str, label: str) -> None:
    """Check the annotation label computed for an annotated assignment (AnnAssign).

    Per the cases above, a ``None`` default on a non-Optional annotation is
    expected to be labelled ``Optional[...]``, and an explicit ``Optional``
    annotation keeps its label regardless of the default.
    """
    node = astroid.extract_node(assign)
    # Check the node type before dereferencing ``node.value`` so a bad
    # extraction fails with a clear assertion rather than an AttributeError.
    assert isinstance(node, nodes.AnnAssign)
    # NOTE(review): the assigned *value* node (not ``node.target``) is what is
    # passed to ``get_annotation`` here; confirm this is intentional.
    annotation = get_annotation(node.value)
    assert annotation is not None
    got = annotation.name
    assert got == label, f"got {got} instead of {label} for value {node}"


@pytest.mark.parametrize(
    "init_method, label",
    [
        ("def __init__(self, x: str):                   self.x = x", "str"),
        ("def __init__(self, x: str = 'str'):           self.x = x", "str"),
        ("def __init__(self, x: str = None):            self.x = x", "Optional[str]"),
        ("def __init__(self, x: Optional[str]):         self.x = x", "Optional[str]"),
        ("def __init__(self, x: Optional[str] = None):  self.x = x", "Optional[str]"),
        ("def __init__(self, x: Optional[str] = 'str'): self.x = x", "Optional[str]"),
    ],
)
def test_get_annotation_assignattr(init_method: str, label: str) -> None:
    """Check the annotation label computed for an instance attribute (AssignAttr).

    Each case defines ``__init__`` with an annotated parameter ``x`` that is
    stored as ``self.x``; the attribute's annotation label must equal ``label``.
    """
    assign = f"""
        class A:
            {init_method}
    """
    node = astroid.extract_node(assign)
    assert isinstance(node, nodes.ClassDef)
    instance_attrs = node.instance_attrs
    # Guard against a vacuous pass: with an empty mapping the loop below would
    # run no assertions at all.
    assert instance_attrs, f"no instance attributes found for value {node}"
    for assign_attrs in instance_attrs.values():
        for assign_attr in assign_attrs:
            # Validate the node type before handing it to ``get_annotation``.
            assert isinstance(assign_attr, nodes.AssignAttr)
            annotation = get_annotation(assign_attr)
            assert annotation is not None
            got = annotation.name
            assert got == label, f"got {got} instead of {label} for value {node}"


@pytest.mark.parametrize(
    ("node_text", "expected_label"),
    [
        ("def f() -> None: pass", "None"),
        ("def f() -> int: pass", "int"),
        ("def f(a) -> Optional[int]: return 1 if a else None", "Optional[int]"),
        ("def f() -> 'MyType': pass", "'MyType'"),
    ],
)
def test_get_annotation_label_of_return_type(
    node_text: str, expected_label: str
) -> None:
    """The label of a function's return annotation matches ``expected_label``.

    For every case above, that is the annotation exactly as written in the
    source, including the quotes of a string (forward-reference) annotation.
    """
    function_node = astroid.extract_node(node_text)
    assert isinstance(function_node, nodes.FunctionDef)
    label = get_annotation_label(function_node.returns)
    assert label == expected_label


@patch("pylint.pyreverse.utils.get_annotation")
@patch("astroid.nodes.NodeNG.infer", side_effect=astroid.InferenceError)
def test_infer_node_1(mock_infer: Any, mock_get_annotation: Any) -> None:
    """Return set() when astroid.InferenceError is raised and an annotation has
    not been returned.
    """
    mock_get_annotation.return_value = None
    node = astroid.extract_node("a: str = 'mystr'")
    # No ``mock_infer.return_value`` is configured: the ``side_effect``
    # exception is raised on every call, so a return value would never be used.
    assert infer_node(node) == set()
    mock_infer.assert_called()


@patch("pylint.pyreverse.utils.get_annotation")
@patch("astroid.nodes.NodeNG.infer")
def test_infer_node_2(mock_infer: Any, mock_get_annotation: Any) -> None:
    """Return set(node.infer()) when InferenceError is not raised and an
    annotation has not been returned.
    """
    mock_get_annotation.return_value = None
    node = astroid.extract_node("a: str = 'mystr'")
    # ``infer()`` produces a sequence of inference results, so model it as a
    # list; a bare string would be iterated character by character by ``set()``.
    mock_infer.return_value = ["x"]
    assert infer_node(node) == {"x"}
    mock_infer.assert_called()


def test_infer_node_3() -> None:
    """Return a set containing a nodes.ClassDef object when the attribute
    has a type annotation.
    """
    node = astroid.extract_node(
        """
        class Component:
            pass

        class Composite:
            def __init__(self, component: Component):
                self.component = component
    """
    )
    # Index directly so a missing attribute fails with a KeyError naming it,
    # instead of a TypeError from subscripting the ``None`` that ``.get`` returns.
    instance_attr = node.instance_attrs["component"][0]
    assert isinstance(instance_attr, nodes.AssignAttr)

    # Infer once and reuse the result rather than calling ``infer_node`` twice.
    inferred = infer_node(instance_attr)
    assert isinstance(inferred, set)
    inferred_class = inferred.pop()
    assert isinstance(inferred_class, nodes.ClassDef)
    assert inferred_class.name == "Component"


def test_infer_node_4() -> None:
    """An attribute assigned from an ``Optional[int]``-annotated parameter is
    inferred as that ``nodes.Subscript`` annotation, labelled ``Optional[int]``.
    """
    code = """
        class MyClass:
            def __init__(self, my_int: Optional[int] = None):
                self.my_test_int = my_int
    """
    class_node = astroid.extract_node(code)

    assign_attr = class_node.instance_attrs.get("my_test_int")[0]
    assert isinstance(assign_attr, nodes.AssignAttr)

    results = infer_node(assign_attr)
    subscript = results.pop()
    assert isinstance(subscript, nodes.Subscript)
    assert subscript.name == "Optional[int]"