summaryrefslogtreecommitdiff
path: root/tests/brain/numpy/test_ndarray.py
blob: 9ccadf56739c8a4bd948ce1a8271bfa798cb90a1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt

import unittest

try:
    import numpy  # pylint: disable=unused-import

    HAS_NUMPY = True
except ImportError:
    HAS_NUMPY = False

from astroid import builder, nodes
from astroid.brain.brain_numpy_utils import (
    NUMPY_VERSION_TYPE_HINTS_SUPPORT,
    numpy_supports_type_hints,
)


@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class NumpyBrainNdarrayTest(unittest.TestCase):
    """Test that numpy.ndarray methods and attributes are inferred as ndarray.

    Each check builds ``np.ndarray((2, 2))`` in a small snippet, infers the
    expression of interest with astroid, and expects exactly one inferred
    value whose ``pytype()`` is one of ``LICIT_ARRAY_TYPES``.
    """

    # Accepted ``pytype()`` values for an inferred ndarray instance. This is a
    # tuple so that membership is an exact match: with a plain string, ``in``
    # degrades to a substring test that would also accept "" or "array".
    LICIT_ARRAY_TYPES = (".ndarray",)

    # ndarray methods whose call result is expected to infer as an ndarray.
    ndarray_returning_ndarray_methods = (
        "__abs__",
        "__add__",
        "__and__",
        "__array__",
        "__array_wrap__",
        "__copy__",
        "__deepcopy__",
        "__eq__",
        "__floordiv__",
        "__ge__",
        "__gt__",
        "__iadd__",
        "__iand__",
        "__ifloordiv__",
        "__ilshift__",
        "__imod__",
        "__imul__",
        "__invert__",
        "__ior__",
        "__ipow__",
        "__irshift__",
        "__isub__",
        "__itruediv__",
        "__ixor__",
        "__le__",
        "__lshift__",
        "__lt__",
        "__matmul__",
        "__mod__",
        "__mul__",
        "__ne__",
        "__neg__",
        "__or__",
        "__pos__",
        "__pow__",
        "__rshift__",
        "__sub__",
        "__truediv__",
        "__xor__",
        "all",
        "any",
        "argmax",
        "argmin",
        "argpartition",
        "argsort",
        "astype",
        "byteswap",
        "choose",
        "clip",
        "compress",
        "conj",
        "conjugate",
        "copy",
        "cumprod",
        "cumsum",
        "diagonal",
        "dot",
        "flatten",
        "getfield",
        "max",
        "mean",
        "min",
        "newbyteorder",
        "prod",
        "ptp",
        "ravel",
        "repeat",
        "reshape",
        "round",
        "searchsorted",
        "squeeze",
        "std",
        "sum",
        "swapaxes",
        "take",
        "trace",
        "transpose",
        "var",
        "view",
    )

    @staticmethod
    def _infer_on_ndarray(expression):
        """Infer ``expression`` in a snippet where ``test_array`` is an ndarray.

        ``expression`` is written as the final line of the snippet. This relies
        on ``builder.extract_node`` selecting that last statement, since the
        snippet contains no ``#@`` markers.

        :param expression: source text that may refer to ``test_array``.
        :returns: the iterator produced by ``node.infer()``.
        """
        node = builder.extract_node(
            f"""
        import numpy as np
        test_array = np.ndarray((2, 2))
        {expression:s}
        """
        )
        return node.infer()

    def _inferred_ndarray_method_call(self, func_name):
        """Return the inference results for ``test_array.<func_name>()``."""
        return self._infer_on_ndarray(f"test_array.{func_name:s}()")

    def _inferred_ndarray_attribute(self, attr_name):
        """Return the inference results for ``test_array.<attr_name>``."""
        return self._infer_on_ndarray(f"test_array.{attr_name:s}")

    def _assert_single_ndarray(self, inferred, label):
        """Assert that ``inferred`` yields exactly one ndarray-typed value.

        :param inferred: iterable of inferred astroid values.
        :param label: name of the method/attribute under test, used in messages.
        """
        inferred_values = list(inferred)
        self.assertEqual(
            len(inferred_values),
            1,
            msg=f"Too much inferred value for {label:s}",
        )
        pytype = inferred_values[0].pytype()
        self.assertIn(
            pytype,
            self.LICIT_ARRAY_TYPES,
            msg=f"Illicit type for {label:s} ({pytype})",
        )

    def test_numpy_function_calls_inferred_as_ndarray(self):
        """Test that calls to some ndarray methods are inferred as numpy.ndarray."""
        for func_ in self.ndarray_returning_ndarray_methods:
            with self.subTest(typ=func_):
                self._assert_single_ndarray(
                    self._inferred_ndarray_method_call(func_), func_
                )

    def test_numpy_ndarray_attribute_inferred_as_ndarray(self):
        """Test that some numpy ndarray attributes are inferred as numpy.ndarray."""
        for attr_ in ("real", "imag", "shape", "T"):
            with self.subTest(typ=attr_):
                self._assert_single_ndarray(
                    self._inferred_ndarray_attribute(attr_), attr_
                )

    @unittest.skipUnless(
        HAS_NUMPY and numpy_supports_type_hints(),
        f"This test requires the numpy library with a version above {NUMPY_VERSION_TYPE_HINTS_SUPPORT}",
    )
    def test_numpy_ndarray_class_support_type_indexing(self):
        """Test that numpy ndarray class can be subscripted (type hints)."""
        src = """
        import numpy as np
        np.ndarray[int]
        """
        node = builder.extract_node(src)
        cls_node = node.inferred()[0]
        self.assertIsInstance(cls_node, nodes.ClassDef)
        self.assertEqual(cls_node.name, "ndarray")