diff options
author | Ashley Whetter <AWhetter@users.noreply.github.com> | 2019-10-15 01:49:26 -0700 |
---|---|---|
committer | Claudiu Popa <pcmanticore@gmail.com> | 2019-10-15 10:49:26 +0200 |
commit | 2f288598de485c6af25788fc917139b48c31c474 (patch) | |
tree | 3b52b2994c90018a2db2854adca0928c4bfe1162 /tests/unittest_brain_numpy_ndarray.py | |
parent | 73babe3d536ffc4da94e59c705eb6a8c3e5822ef (diff) | |
download | astroid-git-2f288598de485c6af25788fc917139b48c31c474.tar.gz |
Moved tests out of package directory (#704)
Diffstat (limited to 'tests/unittest_brain_numpy_ndarray.py')
-rw-r--r-- | tests/unittest_brain_numpy_ndarray.py | 141 |
1 files changed, 141 insertions, 0 deletions
diff --git a/tests/unittest_brain_numpy_ndarray.py b/tests/unittest_brain_numpy_ndarray.py new file mode 100644 index 00000000..d982f7f6 --- /dev/null +++ b/tests/unittest_brain_numpy_ndarray.py @@ -0,0 +1,141 @@ +# -*- encoding=utf-8 -*- +# Copyright (c) 2017-2018 hippo91 <guillaume.peillex@gmail.com> +# Copyright (c) 2017 Claudiu Popa <pcmanticore@gmail.com> +# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com> + +# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html +# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER +import unittest + +try: + import numpy # pylint: disable=unused-import + + HAS_NUMPY = True +except ImportError: + HAS_NUMPY = False + +from astroid import builder + + +@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.") +class NumpyBrainNdarrayTest(unittest.TestCase): + """ + Test that calls to numpy functions returning arrays are correctly inferred + """ + + ndarray_returning_ndarray_methods = ( + "__abs__", + "__add__", + "__and__", + "__array__", + "__array_wrap__", + "__copy__", + "__deepcopy__", + "__eq__", + "__floordiv__", + "__ge__", + "__gt__", + "__iadd__", + "__iand__", + "__ifloordiv__", + "__ilshift__", + "__imod__", + "__imul__", + "__invert__", + "__ior__", + "__ipow__", + "__irshift__", + "__isub__", + "__itruediv__", + "__ixor__", + "__le__", + "__lshift__", + "__lt__", + "__matmul__", + "__mod__", + "__mul__", + "__ne__", + "__neg__", + "__or__", + "__pos__", + "__pow__", + "__rshift__", + "__sub__", + "__truediv__", + "__xor__", + "all", + "any", + "argmax", + "argmin", + "argpartition", + "argsort", + "astype", + "byteswap", + "choose", + "clip", + "compress", + "conj", + "conjugate", + "copy", + "cumprod", + "cumsum", + "diagonal", + "dot", + "flatten", + "getfield", + "max", + "mean", + "min", + "newbyteorder", + "prod", + "ptp", + "ravel", + "repeat", + "reshape", + "round", + "searchsorted", + "squeeze", + "std", + "sum", + "swapaxes", + "take", + "trace", + "transpose", + "var", + "view", + ) + + def _inferred_ndarray_method_call(self, func_name): + node = builder.extract_node( + """ + import numpy as np + test_array = np.ndarray((2, 2)) + test_array.{:s}() + """.format( + func_name + ) + ) + return node.infer() + + def test_numpy_function_calls_inferred_as_ndarray(self): + """ + Test that some calls to numpy functions are inferred as numpy.ndarray + """ + licit_array_types = ".ndarray" + for func_ in self.ndarray_returning_ndarray_methods: + with self.subTest(typ=func_): + inferred_values = list(self._inferred_ndarray_method_call(func_)) + self.assertTrue( + len(inferred_values) == 1, + msg="Too much inferred value for {:s}".format(func_), + ) + self.assertTrue( + inferred_values[-1].pytype() in licit_array_types, + msg="Illicit type for {:s} ({})".format( + func_, inferred_values[-1].pytype() + ), + ) + + +if __name__ == "__main__": + unittest.main() |