summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/umath/loops.c.src15
-rw-r--r--numpy/core/tests/test_umath.py24
2 files changed, 30 insertions, 9 deletions
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index 3b1b70a68..1a29e6904 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -1131,9 +1131,18 @@ C@TYPE@_floor_divide(char **args, intp *dimensions, intp *steps, void *NPY_UNUSE
const @type@ in1i = ((@type@ *)ip1)[1];
const @type@ in2r = ((@type@ *)ip2)[0];
const @type@ in2i = ((@type@ *)ip2)[1];
- @type@ d = in2r*in2r + in2i*in2i;
- ((@type@ *)op1)[0] = npy_floor@c@((in1r*in2r + in1i*in2i)/d);
- ((@type@ *)op1)[1] = 0;
+ if (fabs@c@(in2r) >= fabs@c@(in2i)) {
+ const @type@ rat = in2i/in2r;
+ const @type@ scl = 1/(in2r + in2i*rat);
+ ((@type@ *)op1)[0] = npy_floor@c@((in1r + in1i*rat)*scl);
+ ((@type@ *)op1)[1] = 0;
+ }
+ else {
+ const @type@ rat = in2r/in2i;
+ const @type@ scl = 1/(in2i + in2r*rat);
+ ((@type@ *)op1)[0] = npy_floor@c@((in1r*rat + in1i)*scl);
+ ((@type@ *)op1)[1] = 0;
+ }
}
}
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index b3fd24bd6..212e3b597 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -11,15 +11,27 @@ class TestDivision(TestCase):
assert_equal(x % 100, [5, 10, 90, 0, 95, 90, 10, 0, 80])
def test_division_complex(self):
- # check that division is correct
+ # check that implementation is correct
msg = "Complex division implementation check"
- a = np.array([1. + 1.*1j, 1. + .5*1j, 1. + 2.*1j], dtype=np.complex128)
- assert_almost_equal(a**2/a, a, err_msg=msg)
+ x = np.array([1. + 1.*1j, 1. + .5*1j, 1. + 2.*1j], dtype=np.complex128)
+ assert_almost_equal(x**2/x, x, err_msg=msg)
# check overflow, underflow
msg = "Complex division overflow/underflow check"
- a = np.array([1.e+110, 1.e-110], dtype=np.complex128)
- b = a**2/a
- assert_almost_equal(b/a, [1, 1], err_msg=msg)
+ x = np.array([1.e+110, 1.e-110], dtype=np.complex128)
+ y = x**2/x
+ assert_almost_equal(y/x, [1, 1], err_msg=msg)
+
+ def test_floor_division_complex(self):
+ # check that implementation is correct
+ msg = "Complex floor division implementation check"
+ x = np.array([.9 + 1j, -.1 + 1j, .9 + .5*1j, .9 + 2.*1j], dtype=np.complex128)
+ y = np.array([0., -1., 0., 0.], dtype=np.complex128)
+ assert_equal(np.floor_divide(x**2,x), y, err_msg=msg)
+ # check overflow, underflow
+ msg = "Complex floor division overflow/underflow check"
+ x = np.array([1.e+110, 1.e-110], dtype=np.complex128)
+ y = np.floor_divide(x**2, x)
+ assert_equal(y, [1.e+110, 0], err_msg=msg)
class TestPower(TestCase):
def test_power_float(self):