diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/blasdot/_dotblas.c | 135 | ||||
-rw-r--r-- | numpy/core/bscript | 16 | ||||
-rw-r--r-- | numpy/core/numeric.py | 18 | ||||
-rw-r--r-- | numpy/core/setup.py | 21 | ||||
-rw-r--r-- | numpy/core/src/multiarray/arraytypes.c.src | 189 | ||||
-rw-r--r-- | numpy/core/tests/test_blasdot.py | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_deprecations.py | 42 |
7 files changed, 226 insertions, 199 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c index 48aa39ff8..c1b4e7b05 100644 --- a/numpy/core/blasdot/_dotblas.c +++ b/numpy/core/blasdot/_dotblas.c @@ -24,121 +24,6 @@ static char module_doc[] = static PyArray_DotFunc *oldFunctions[NPY_NTYPES]; -#define MIN(a, b) ((a) < (b) ? (a) : (b)) - -/* - * Convert NumPy stride to BLAS stride. Returns 0 if conversion cannot be done - * (BLAS won't handle negative or zero strides the way we want). - */ -static NPY_INLINE int -blas_stride(npy_intp stride, unsigned itemsize) -{ - if (stride <= 0 || stride % itemsize != 0) { - return 0; - } - stride /= itemsize; - - if (stride > INT_MAX) { - return 0; - } - return stride; -} - -/* - * The following functions do a "chunked" dot product using BLAS when - * sizeof(npy_intp) > sizeof(int), because BLAS libraries can typically not - * handle more than INT_MAX elements per call. - * - * The chunksize is the greatest power of two less than INT_MAX. - */ -#if NPY_MAX_INTP > INT_MAX -# define CHUNKSIZE (INT_MAX / 2 + 1) -#else -# define CHUNKSIZE NPY_MAX_INTP -#endif - -static void -FLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, - npy_intp n, void *tmp) -{ - int na = blas_stride(stridea, sizeof(float)); - int nb = blas_stride(strideb, sizeof(float)); - - if (na && nb) { - double r = 0.; /* double for stability */ - float *fa = a, *fb = b; - - while (n > 0) { - int chunk = MIN(n, CHUNKSIZE); - - r += cblas_sdot(chunk, fa, na, fb, nb); - fa += chunk * na; - fb += chunk * nb; - n -= chunk; - } - *((float *)res) = r; - } - else { - oldFunctions[NPY_FLOAT](a, stridea, b, strideb, res, n, tmp); - } -} - -static void -DOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, - npy_intp n, void *tmp) -{ - int na = blas_stride(stridea, sizeof(double)); - int nb = blas_stride(strideb, sizeof(double)); - - if (na && nb) { - double r = 0.; - double *da = a, *db = b; - - while (n > 0) { - int chunk = MIN(n, CHUNKSIZE); - - r += cblas_ddot(chunk, da, na, db, nb); - da += chunk * na; - db += chunk * nb; - n -= chunk; - } - *((double *)res) = r; - } - else { - oldFunctions[NPY_DOUBLE](a, stridea, b, strideb, res, n, tmp); - } -} - -static void -CFLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, - npy_intp n, void *tmp) -{ - int na = blas_stride(stridea, sizeof(npy_cfloat)); - int nb = blas_stride(strideb, sizeof(npy_cfloat)); - - if (na && nb) { - cblas_cdotu_sub((int)n, (float *)a, na, (float *)b, nb, (float *)res); - } - else { - oldFunctions[NPY_CFLOAT](a, stridea, b, strideb, res, n, tmp); - } -} - -static void -CDOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res, - npy_intp n, void *tmp) -{ - int na = blas_stride(stridea, sizeof(npy_cdouble)); - int nb = blas_stride(strideb, sizeof(npy_cdouble)); - - if (na && nb) { - cblas_zdotu_sub((int)n, (double *)a, na, (double *)b, nb, - (double *)res); - } - else { - oldFunctions[NPY_CDOUBLE](a, stridea, b, strideb, res, n, tmp); - } -} /* * Helper: call appropriate BLAS dot function for typenum. @@ -149,20 +34,8 @@ blas_dot(int typenum, npy_intp n, void *a, npy_intp stridea, void *b, npy_intp strideb, void *res) { PyArray_DotFunc *dot = NULL; - switch (typenum) { - case NPY_DOUBLE: - dot = DOUBLE_dot; - break; - case NPY_FLOAT: - dot = FLOAT_dot; - break; - case NPY_CDOUBLE: - dot = CDOUBLE_dot; - break; - case NPY_CFLOAT: - dot = CFLOAT_dot; - break; - } + + dot = oldFunctions[typenum]; assert(dot != NULL); dot(a, stridea, b, strideb, res, n, NULL); } @@ -257,19 +130,15 @@ dotblas_alterdot(PyObject *NPY_UNUSED(dummy), PyObject *args) if (!altered) { descr = PyArray_DescrFromType(NPY_FLOAT); oldFunctions[NPY_FLOAT] = descr->f->dotfunc; - descr->f->dotfunc = (PyArray_DotFunc *)FLOAT_dot; descr = PyArray_DescrFromType(NPY_DOUBLE); oldFunctions[NPY_DOUBLE] = descr->f->dotfunc; - descr->f->dotfunc = (PyArray_DotFunc *)DOUBLE_dot; descr = PyArray_DescrFromType(NPY_CFLOAT); oldFunctions[NPY_CFLOAT] = descr->f->dotfunc; - descr->f->dotfunc = (PyArray_DotFunc *)CFLOAT_dot; descr = PyArray_DescrFromType(NPY_CDOUBLE); oldFunctions[NPY_CDOUBLE] = descr->f->dotfunc; - descr->f->dotfunc = (PyArray_DotFunc *)CDOUBLE_dot; altered = NPY_TRUE; } diff --git a/numpy/core/bscript b/numpy/core/bscript index 416e16524..3306c5341 100644 --- a/numpy/core/bscript +++ b/numpy/core/bscript @@ -488,14 +488,22 @@ def pre_build(context): pjoin('src', 'multiarray', 'usertypes.c')] else: sources = extension.sources + + use = 'npysort npymath' + defines = ['_FILE_OFFSET_BITS=64', + '_LARGEFILE_SOURCE=1', + '_LARGEFILE64_SOURCE=1'] + + if bld.env.HAS_CBLAS: + use += ' CBLAS' + defines.append('HAVE_CBLAS') + includes = ["src/multiarray", "src/private"] return context.default_builder(extension, includes=includes, source=sources, - use="npysort npymath", - defines=['_FILE_OFFSET_BITS=64', - '_LARGEFILE_SOURCE=1', - '_LARGEFILE64_SOURCE=1'] + use=use, + defines=defines ) context.register_builder("multiarray", builder_multiarray) diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index eb0c38b0b..4361ba5c1 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1078,31 +1078,33 @@ def outer(a, b, out=None): # try to import blas optimized dot if available envbak = os.environ.copy() try: - # importing this changes the dot function for basic 4 types - # to blas-optimized versions. - # disables openblas affinity setting of the main thread that limits # python threads or processes to one core if 'OPENBLAS_MAIN_FREE' not in os.environ: os.environ['OPENBLAS_MAIN_FREE'] = '1' if 'GOTOBLAS_MAIN_FREE' not in os.environ: os.environ['GOTOBLAS_MAIN_FREE'] = '1' - from ._dotblas import dot, vdot, inner, alterdot, restoredot + from ._dotblas import dot, vdot, inner except ImportError: # docstrings are in add_newdocs.py inner = multiarray.inner dot = multiarray.dot def vdot(a, b): return dot(asarray(a).ravel().conj(), asarray(b).ravel()) - def alterdot(): - pass - def restoredot(): - pass finally: os.environ.clear() os.environ.update(envbak) del envbak + +def alterdot(): + warnings.warn("alterdot no longer does anything.", DeprecationWarning) + + +def restoredot(): + warnings.warn("restoredot no longer does anything.", DeprecationWarning) + + def tensordot(a, b, axes=2): """ Compute tensor dot product along specified axes for arrays >= 1-D. diff --git a/numpy/core/setup.py b/numpy/core/setup.py index b81374d14..4c2af5e62 100644 --- a/numpy/core/setup.py +++ b/numpy/core/setup.py @@ -846,15 +846,22 @@ def configuration(parent_package='',top_path=None): multiarray_src = [join('src', 'multiarray', 'multiarraymodule_onefile.c')] multiarray_src.append(generate_multiarray_templated_sources) + blas_info = get_info('blas_opt', 0) + if blas_info and ('HAVE_CBLAS', None) in blas_info.get('define_macros', []): + extra_info = blas_info + else: + extra_info = {} + config.add_extension('multiarray', - sources = multiarray_src + + sources=multiarray_src + [generate_config_h, - generate_numpyconfig_h, - generate_numpy_api, - join(codegen_dir, 'generate_numpy_api.py'), - join('*.py')], - depends = deps + multiarray_deps, - libraries = ['npymath', 'npysort']) + generate_numpyconfig_h, + generate_numpy_api, + join(codegen_dir, 'generate_numpy_api.py'), + join('*.py')], + depends=deps + multiarray_deps, + libraries=['npymath', 'npysort'], + extra_info=extra_info) ####################################################################### # umath module # diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index e4904acfc..1b7bbde63 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -3,6 +3,11 @@ #include "Python.h" #include "structmember.h" +#include <limits.h> +#if defined(HAVE_CBLAS) +#include <cblas.h> +#endif + #define NPY_NO_DEPRECATED_API NPY_API_VERSION #define _MULTIARRAYMODULE #include "numpy/npy_common.h" @@ -3094,6 +3099,149 @@ static int * dot means inner product */ +/************************** MAYBE USE CBLAS *********************************/ + +/* + * Convert NumPy stride to BLAS stride. Returns 0 if conversion cannot be done + * (BLAS won't handle negative or zero strides the way we want). + */ +#if defined(HAVE_CBLAS) +static NPY_INLINE int +blas_stride(npy_intp stride, unsigned itemsize) +{ + /* + * Should probably check pointer alignment also, but this may cause + * problems if we require complex to be 16 byte aligned. + */ + if (stride > 0 && npy_is_aligned((void *)stride, itemsize)) { + stride /= itemsize; + if (stride <= INT_MAX) { + return stride; + } + } + return 0; +} +#endif + + +/* + * The following functions do a "chunked" dot product using BLAS when + * sizeof(npy_intp) > sizeof(int), because BLAS libraries can typically not + * handle more than INT_MAX elements per call. + * + * The chunksize is the greatest power of two less than INT_MAX. + */ +#if NPY_MAX_INTP > INT_MAX +# define CHUNKSIZE (INT_MAX / 2 + 1) +#else +# define CHUNKSIZE NPY_MAX_INTP +#endif + +/**begin repeat + * + * #name = FLOAT, DOUBLE# + * #type = npy_float, npy_double# + * #prefix = s, d# + */ +static void +@name@_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, + npy_intp n, void *NPY_UNUSED(ignore)) +{ +#if defined(HAVE_CBLAS) + int is1b = blas_stride(is1, sizeof(@type@)); + int is2b = blas_stride(is2, sizeof(@type@)); + + if (is1b && is2b) + { + double sum = 0.; /* double for stability */ + + while (n > 0) { + int chunk = n < CHUNKSIZE ? n : CHUNKSIZE; + + sum += cblas_@prefix@dot(chunk, + (@type@ *) ip1, is1b, + (@type@ *) ip2, is2b); + /* use char strides here */ + ip1 += chunk * is1; + ip2 += chunk * is2; + n -= chunk; + } + *((@type@ *)op) = (@type@)sum; + } + else +#endif + { + @type@ sum = (@type@)0; /* could make this double */ + npy_intp i; + + for (i = 0; i < n; i++, ip1 += is1, ip2 += is2) { + const @type@ ip1r = *((@type@ *)ip1); + const @type@ ip2r = *((@type@ *)ip2); + + sum += ip1r * ip2r; + } + *((@type@ *)op) = sum; + } +} +/**end repeat**/ + +/**begin repeat + * + * #name = CFLOAT, CDOUBLE# + * #ctype = npy_cfloat, npy_cdouble# + * #type = npy_float, npy_double# + * #prefix = c, z# + */ +static void +@name@_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, + char *op, npy_intp n, void *NPY_UNUSED(ignore)) +{ +#if defined(HAVE_CBLAS) + int is1b = blas_stride(is1, sizeof(@ctype@)); + int is2b = blas_stride(is2, sizeof(@ctype@)); + + if (is1b && is2b) { + double sum[2] = {0., 0.}; /* double for stability */ + + while (n > 0) { + int chunk = n < CHUNKSIZE ? n : CHUNKSIZE; + @type@ tmp[2]; + + cblas_@prefix@dotu_sub((int)n, ip1, is1b, ip2, is2b, tmp); + sum[0] += (double)tmp[0]; + sum[1] += (double)tmp[1]; + /* use char strides here */ + ip1 += chunk * is1; + ip2 += chunk * is2; + n -= chunk; + } + ((@type@ *)op)[0] = (@type@)sum[0]; + ((@type@ *)op)[1] = (@type@)sum[1]; + } + else +#endif + { + @type@ sumr = (@type@)0.0; + @type@ sumi = (@type@)0.0; + npy_intp i; + + for (i = 0; i < n; i++, ip1 += is1, ip2 += is2) { + const @type@ ip1r = ((@type@ *)ip1)[0]; + const @type@ ip1i = ((@type@ *)ip1)[1]; + const @type@ ip2r = ((@type@ *)ip2)[0]; + const @type@ ip2i = ((@type@ *)ip2)[1]; + + sumr += ip1r * ip2r - ip1i * ip2i; + sumi += ip1i * ip2r + ip1r * ip2i; + } + ((@type@ *)op)[0] = sumr; + ((@type@ *)op)[1] = sumi; + } +} +/**end repeat**/ + +/**************************** NO CBLAS VERSIONS *****************************/ + static void BOOL_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n, void *NPY_UNUSED(ignore)) @@ -3114,16 +3262,13 @@ BOOL_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n, * * #name = BYTE, UBYTE, SHORT, USHORT, INT, UINT, * LONG, ULONG, LONGLONG, ULONGLONG, - * FLOAT, DOUBLE, LONGDOUBLE, - * DATETIME, TIMEDELTA# + * LONGDOUBLE, DATETIME, TIMEDELTA# * #type = npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint, * npy_long, npy_ulong, npy_longlong, npy_ulonglong, - * npy_float, npy_double, npy_longdouble, - * npy_datetime, npy_timedelta# + * npy_longdouble, npy_datetime, npy_timedelta# * #out = npy_long, npy_ulong, npy_long, npy_ulong, npy_long, npy_ulong, * npy_long, npy_ulong, npy_longlong, npy_ulonglong, - * npy_float, npy_double, npy_longdouble, - * npy_datetime, npy_timedelta# + * npy_longdouble, npy_datetime, npy_timedelta# */ static void @name@_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n, @@ -3141,8 +3286,8 @@ static void /**end repeat**/ static void -HALF_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n, - void *NPY_UNUSED(ignore)) +HALF_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, + npy_intp n, void *NPY_UNUSED(ignore)) { float tmp = 0.0f; npy_intp i; @@ -3154,28 +3299,26 @@ HALF_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n, *((npy_half *)op) = npy_float_to_half(tmp); } -/**begin repeat - * - * #name = CFLOAT, CDOUBLE, CLONGDOUBLE# - * #type = npy_float, npy_double, npy_longdouble# - */ -static void @name@_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, - char *op, npy_intp n, void *NPY_UNUSED(ignore)) +static void CLONGDOUBLE_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, + char *op, npy_intp n, void *NPY_UNUSED(ignore)) { - @type@ tmpr = (@type@)0.0, tmpi=(@type@)0.0; + npy_longdouble tmpr = 0.0L; + npy_longdouble tmpi = 0.0L; npy_intp i; for (i = 0; i < n; i++, ip1 += is1, ip2 += is2) { - tmpr += ((@type@ *)ip1)[0] * ((@type@ *)ip2)[0] - - ((@type@ *)ip1)[1] * ((@type@ *)ip2)[1]; - tmpi += ((@type@ *)ip1)[1] * ((@type@ *)ip2)[0] - + ((@type@ *)ip1)[0] * ((@type@ *)ip2)[1]; + const npy_longdouble ip1r = ((npy_longdouble *)ip1)[0]; + const npy_longdouble ip1i = ((npy_longdouble *)ip1)[1]; + const npy_longdouble ip2r = ((npy_longdouble *)ip2)[0]; + const npy_longdouble ip2i = ((npy_longdouble *)ip2)[1]; + + tmpr += ip1r * ip2r - ip1i * ip2i; + tmpi += ip1i * ip2r + ip1r * ip2i; } - ((@type@ *)op)[0] = tmpr; ((@type@ *)op)[1] = tmpi; + ((npy_longdouble *)op)[0] = tmpr; + ((npy_longdouble *)op)[1] = tmpi; } -/**end repeat**/ - static void OBJECT_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n, void *NPY_UNUSED(ignore)) diff --git a/numpy/core/tests/test_blasdot.py b/numpy/core/tests/test_blasdot.py index 17f77d2f5..c38dab187 100644 --- a/numpy/core/tests/test_blasdot.py +++ b/numpy/core/tests/test_blasdot.py @@ -26,12 +26,10 @@ except ImportError: @dec.skipif(_dotblas is None, "Numpy is not compiled with _dotblas") def test_blasdot_used(): - from numpy.core import dot, vdot, inner, alterdot, restoredot + from numpy.core import dot, vdot, inner assert_(dot is _dotblas.dot) assert_(vdot is _dotblas.vdot) assert_(inner is _dotblas.inner) - assert_(alterdot is _dotblas.alterdot) - assert_(restoredot is _dotblas.restoredot) def test_dot_2args(): diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py index ef56766f5..9e2248205 100644 --- a/numpy/core/tests/test_deprecations.py +++ b/numpy/core/tests/test_deprecations.py @@ -5,13 +5,11 @@ to document how deprecations should eventually be turned into errors. """ from __future__ import division, absolute_import, print_function -import sys import operator import warnings -from nose.plugins.skip import SkipTest import numpy as np -from numpy.testing import (dec, run_module_suite, assert_raises, +from numpy.testing import (run_module_suite, assert_raises, assert_warns, assert_array_equal, assert_) @@ -34,11 +32,9 @@ class _DeprecationTestCase(object): warnings.filterwarnings("always", message=self.message, category=DeprecationWarning) - def tearDown(self): self.warn_ctx.__exit__() - def assert_deprecated(self, function, num=1, ignore_others=False, function_fails=False, exceptions=(DeprecationWarning,), args=(), kwargs={}): @@ -102,7 +98,6 @@ class _DeprecationTestCase(object): if exceptions == tuple(): raise AssertionError("Error raised during function call") - def assert_not_deprecated(self, function, args=(), kwargs={}): """Test if DeprecationWarnings are given and raised. @@ -143,6 +138,7 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase): def test_indexing(self): a = np.array([[[5]]]) + def assert_deprecated(*args, **kwargs): self.assert_deprecated(*args, exceptions=(IndexError,), **kwargs) @@ -172,7 +168,6 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase): assert_deprecated(lambda: a[0.0:, 0.0], num=2) assert_deprecated(lambda: a[0.0:, 0.0,:], num=2) - def test_valid_indexing(self): a = np.array([[[5]]]) assert_not_deprecated = self.assert_not_deprecated @@ -183,9 +178,9 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase): assert_not_deprecated(lambda: a[:, 0,:]) assert_not_deprecated(lambda: a[:,:,:]) - def test_slicing(self): a = np.array([[5]]) + def assert_deprecated(*args, **kwargs): self.assert_deprecated(*args, exceptions=(IndexError,), **kwargs) @@ -217,7 +212,6 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase): # should still get the DeprecationWarning if step = 0. assert_deprecated(lambda: a[::0.0], function_fails=True) - def test_valid_slicing(self): a = np.array([[[5]]]) assert_not_deprecated = self.assert_not_deprecated @@ -231,7 +225,6 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase): assert_not_deprecated(lambda: a[:2:2]) assert_not_deprecated(lambda: a[1:2:2]) - def test_non_integer_argument_deprecations(self): a = np.array([[5]]) @@ -240,7 +233,6 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase): self.assert_deprecated(np.take, args=(a, [0], 1.)) self.assert_deprecated(np.take, args=(a, [0], np.float64(1.))) - def test_non_integer_sequence_multiplication(self): # Numpy scalar sequence multiply should not work with non-integers def mult(a, b): @@ -248,7 +240,6 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase): self.assert_deprecated(mult, args=([1], np.float_(3))) self.assert_not_deprecated(mult, args=([1], np.int_(3))) - def test_reduce_axis_float_index(self): d = np.zeros((3,3,3)) self.assert_deprecated(np.min, args=(d, 0.5)) @@ -303,7 +294,6 @@ class TestArrayToIndexDeprecation(_DeprecationTestCase): # Check slicing. Normal indexing checks arrays specifically. self.assert_deprecated(lambda: a[a:a:a], exceptions=(), num=3) - class TestNonIntegerArrayLike(_DeprecationTestCase): """Tests that array likes, i.e. lists give a deprecation warning when they cannot be safely cast to an integer. @@ -320,7 +310,6 @@ class TestNonIntegerArrayLike(_DeprecationTestCase): self.assert_not_deprecated(a.__getitem__, ([],)) - def test_boolean_futurewarning(self): a = np.arange(10) with warnings.catch_warnings(): @@ -378,12 +367,13 @@ class TestRankDeprecation(_DeprecationTestCase): """Test that np.rank is deprecated. The function should simply be removed. The VisibleDeprecationWarning may become unnecessary. """ + def test(self): a = np.arange(10) assert_warns(np.VisibleDeprecationWarning, np.rank, a) -class TestComparisonDepreactions(_DeprecationTestCase): +class TestComparisonDeprecations(_DeprecationTestCase): """This tests the deprecation, for non-elementwise comparison logic. This used to mean that when an error occured during element-wise comparison (i.e. broadcasting) NotImplemented was returned, but also in the comparison @@ -408,7 +398,6 @@ class TestComparisonDepreactions(_DeprecationTestCase): b = np.array([1, np.array([1,2,3])], dtype=object) self.assert_deprecated(op, args=(a, b), num=None) - def test_string(self): # For two string arrays, strings always raised the broadcasting error: a = np.array(['a', 'b']) @@ -420,7 +409,6 @@ class TestComparisonDepreactions(_DeprecationTestCase): # following works (and returns False) due to dtype mismatch: a == [] - def test_none_comparison(self): # Test comparison of None, which should result in elementwise # comparison in the future. [1, 2] == None should be [False, False]. @@ -455,14 +443,14 @@ class TestComparisonDepreactions(_DeprecationTestCase): assert_(np.equal(np.datetime64('NaT'), None)) -class TestIdentityComparisonDepreactions(_DeprecationTestCase): +class TestIdentityComparisonDeprecations(_DeprecationTestCase): """This tests the equal and not_equal object ufuncs identity check deprecation. This was due to the usage of PyObject_RichCompareBool. This tests that for example for `a = np.array([np.nan], dtype=object)` `a == a` it is warned that False and not `np.nan is np.nan` is returned. - Should be kept in sync with TestComparisonDepreactions and new tests + Should be kept in sync with TestComparisonDeprecations and new tests added when the deprecation is over. Requires only removing of @identity@ (and blocks) from the ufunc loops.c.src of the OBJECT comparisons. """ @@ -488,11 +476,11 @@ class TestIdentityComparisonDepreactions(_DeprecationTestCase): np.less_equal(a, a) np.greater_equal(a, a) - def test_comparison_error(self): class FunkyType(object): def __eq__(self, other): raise TypeError("I won't compare") + def __ne__(self, other): raise TypeError("I won't compare") @@ -500,7 +488,6 @@ class TestIdentityComparisonDepreactions(_DeprecationTestCase): self.assert_deprecated(np.equal, args=(a, a)) self.assert_deprecated(np.not_equal, args=(a, a)) - def test_bool_error(self): # The comparison result cannot be interpreted as a bool a = np.array([np.array([1, 2, 3]), None], dtype=object) @@ -508,5 +495,18 @@ class TestIdentityComparisonDepreactions(_DeprecationTestCase): self.assert_deprecated(np.not_equal, args=(a, a)) +class TestAlterdotRestoredotDeprecations(_DeprecationTestCase): + """The alterdot/restoredot functions are deprecated. + + These functions no longer do anything in numpy 1.10, so should not be + used. + + """ + + def test_alterdot_restoredot_deprecation(self): + self.assert_deprecated(np.alterdot) + self.assert_deprecated(np.restoredot) + + if __name__ == "__main__": run_module_suite() |