summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2011-04-02 19:54:04 +0200
committerPauli Virtanen <pav@iki.fi>2011-04-02 20:04:13 +0200
commit65b77ee94131bf8365d8a6dba6fa19da1269339c (patch)
tree460fc0f11f3bdc6c3fa8300c55ece47d2c7eadf3 /numpy/core
parentb8101c94e26ee21d3bdc49270efe14924ca08078 (diff)
downloadnumpy-65b77ee94131bf8365d8a6dba6fa19da1269339c.tar.gz
BUG: core: make complex division by zero to yield inf properly (#1776)
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/scalarmathmodule.c.src15
-rw-r--r--numpy/core/src/umath/loops.c.src19
-rw-r--r--numpy/core/tests/test_scalarmath.py22
-rw-r--r--numpy/core/tests/test_umath.py17
4 files changed, 64 insertions, 9 deletions
diff --git a/numpy/core/src/scalarmathmodule.c.src b/numpy/core/src/scalarmathmodule.c.src
index 2bcc516b1..56f1bc238 100644
--- a/numpy/core/src/scalarmathmodule.c.src
+++ b/numpy/core/src/scalarmathmodule.c.src
@@ -382,10 +382,17 @@ static npy_half (*_basic_half_fmod)(npy_half, npy_half);
(outp)->real = (a).real * (b).real - (a).imag * (b).imag; \
(outp)->imag = (a).real * (b).imag + (a).imag * (b).real; \
} while(0)
-#define @name@_ctype_divide(a, b, outp) do{ \
- @rtype@ d = (b).real*(b).real + (b).imag*(b).imag; \
- (outp)->real = ((a).real*(b).real + (a).imag*(b).imag)/d; \
- (outp)->imag = ((a).imag*(b).real - (a).real*(b).imag)/d; \
+/* Note: complex division by zero must yield some complex inf */
+#define @name@_ctype_divide(a, b, outp) do{ \
+ @rtype@ d = (b).real*(b).real + (b).imag*(b).imag; \
+ if (d != 0) { \
+ (outp)->real = ((a).real*(b).real + (a).imag*(b).imag)/d; \
+ (outp)->imag = ((a).imag*(b).real - (a).real*(b).imag)/d; \
+ } \
+ else { \
+ (outp)->real = (a).real/d; \
+ (outp)->imag = (a).imag/d; \
+ } \
} while(0)
#define @name@_ctype_true_divide @name@_ctype_divide
#define @name@_ctype_floor_divide(a, b, outp) do { \
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index 5212207da..54e5ac984 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -1804,11 +1804,20 @@ C@TYPE@_divide(char **args, intp *dimensions, intp *steps, void *NPY_UNUSED(func
const @type@ in1i = ((@type@ *)ip1)[1];
const @type@ in2r = ((@type@ *)ip2)[0];
const @type@ in2i = ((@type@ *)ip2)[1];
- if (npy_fabs@c@(in2r) >= npy_fabs@c@(in2i)) {
- const @type@ rat = in2i/in2r;
- const @type@ scl = 1.0@c@/(in2r + in2i*rat);
- ((@type@ *)op1)[0] = (in1r + in1i*rat)*scl;
- ((@type@ *)op1)[1] = (in1i - in1r*rat)*scl;
+ const @type@ in2r_abs = npy_fabs@c@(in2r);
+ const @type@ in2i_abs = npy_fabs@c@(in2i);
+ if (in2r_abs >= in2i_abs) {
+ if (in2r_abs == 0 && in2i_abs == 0) {
+ /* divide by zero should yield a complex inf or nan */
+ ((@type@ *)op1)[0] = in1r/in2r_abs;
+ ((@type@ *)op1)[1] = in1i/in2i_abs;
+ }
+ else {
+ const @type@ rat = in2i/in2r;
+ const @type@ scl = 1.0@c@/(in2r + in2i*rat);
+ ((@type@ *)op1)[0] = (in1r + in1i*rat)*scl;
+ ((@type@ *)op1)[1] = (in1i - in1r*rat)*scl;
+ }
}
else {
const @type@ rat = in2r/in2i;
diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py
index a2b3a232b..a35a9c542 100644
--- a/numpy/core/tests/test_scalarmath.py
+++ b/numpy/core/tests/test_scalarmath.py
@@ -60,6 +60,28 @@ class TestPower(TestCase):
assert_almost_equal(b, 6765201, err_msg=msg)
+class TestComplexDivision(TestCase):
+ def test_zero_division(self):
+ err = np.seterr(over="ignore")
+ try:
+ for t in [np.complex64, np.complex128]:
+ a = t(0.0)
+ b = t(1.0)
+ assert_(np.isinf(b/a))
+ b = t(complex(np.inf, np.inf))
+ assert_(np.isinf(b/a))
+ b = t(complex(np.inf, np.nan))
+ assert_(np.isinf(b/a))
+ b = t(complex(np.nan, np.inf))
+ assert_(np.isinf(b/a))
+ b = t(complex(np.nan, np.nan))
+ assert_(np.isnan(b/a))
+ b = t(0.)
+ assert_(np.isnan(b/a))
+ finally:
+ np.seterr(**err)
+
+
class TestConversion(TestCase):
def test_int_from_long(self):
l = [1e6, 1e12, 1e18, -1e6, -1e12, -1e18]
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index b5f9d5745..5f5d1257e 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -28,6 +28,23 @@ class TestDivision(TestCase):
y = x**2/x
assert_almost_equal(y/x, [1, 1], err_msg=msg)
+ def test_zero_division_complex(self):
+ err = np.seterr(invalid="ignore")
+ try:
+ x = np.array([0.0], dtype=np.complex128)
+ y = 1.0/x
+ assert_(np.isinf(y)[0])
+ y = complex(np.inf, np.nan)/x
+ assert_(np.isinf(y)[0])
+ y = complex(np.nan, np.inf)/x
+ assert_(np.isinf(y)[0])
+ y = complex(np.inf, np.inf)/x
+ assert_(np.isinf(y)[0])
+ y = 0.0/x
+ assert_(np.isnan(y)[0])
+ finally:
+ np.seterr(**err)
+
def test_floor_division_complex(self):
# check that implementation is correct
msg = "Complex floor division implementation check"