diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/umath/loops.c.src | 15 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 24 |
2 files changed, 30 insertions, 9 deletions
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src index 3b1b70a68..1a29e6904 100644 --- a/numpy/core/src/umath/loops.c.src +++ b/numpy/core/src/umath/loops.c.src @@ -1131,9 +1131,18 @@ C@TYPE@_floor_divide(char **args, intp *dimensions, intp *steps, void *NPY_UNUSE const @type@ in1i = ((@type@ *)ip1)[1]; const @type@ in2r = ((@type@ *)ip2)[0]; const @type@ in2i = ((@type@ *)ip2)[1]; - @type@ d = in2r*in2r + in2i*in2i; - ((@type@ *)op1)[0] = npy_floor@c@((in1r*in2r + in1i*in2i)/d); - ((@type@ *)op1)[1] = 0; + if (fabs@c@(in2r) >= fabs@c@(in2i)) { + const @type@ rat = in2i/in2r; + const @type@ scl = 1/(in2r + in2i*rat); + ((@type@ *)op1)[0] = npy_floor@c@((in1r + in1i*rat)*scl); + ((@type@ *)op1)[1] = 0; + } + else { + const @type@ rat = in2r/in2i; + const @type@ scl = 1/(in2i + in2r*rat); + ((@type@ *)op1)[0] = npy_floor@c@((in1r*rat + in1i)*scl); + ((@type@ *)op1)[1] = 0; + } } } diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index b3fd24bd6..212e3b597 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -11,15 +11,27 @@ class TestDivision(TestCase): assert_equal(x % 100, [5, 10, 90, 0, 95, 90, 10, 0, 80]) def test_division_complex(self): - # check that division is correct + # check that implementation is correct msg = "Complex division implementation check" - a = np.array([1. + 1.*1j, 1. + .5*1j, 1. + 2.*1j], dtype=np.complex128) - assert_almost_equal(a**2/a, a, err_msg=msg) + x = np.array([1. + 1.*1j, 1. + .5*1j, 1. + 2.*1j], dtype=np.complex128) + assert_almost_equal(x**2/x, x, err_msg=msg) # check overflow, underflow msg = "Complex division overflow/underflow check" - a = np.array([1.e+110, 1.e-110], dtype=np.complex128) - b = a**2/a - assert_almost_equal(b/a, [1, 1], err_msg=msg) + x = np.array([1.e+110, 1.e-110], dtype=np.complex128) + y = x**2/x + assert_almost_equal(y/x, [1, 1], err_msg=msg) + + def test_floor_division_complex(self): + # check that implementation is correct + msg = "Complex floor division implementation check" + x = np.array([.9 + 1j, -.1 + 1j, .9 + .5*1j, .9 + 2.*1j], dtype=np.complex128) + y = np.array([0., -1., 0., 0.], dtype=np.complex128) + assert_equal(np.floor_divide(x**2,x), y, err_msg=msg) + # check overflow, underflow + msg = "Complex floor division overflow/underflow check" + x = np.array([1.e+110, 1.e-110], dtype=np.complex128) + y = np.floor_divide(x**2, x) + assert_equal(y, [1.e+110, 0], err_msg=msg) class TestPower(TestCase): def test_power_float(self): |