summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Behnel <stefan_ml@behnel.de>2017-07-17 10:25:35 +0200
committerStefan Behnel <stefan_ml@behnel.de>2017-07-17 10:25:35 +0200
commitac3ffe1de6776d225efb994ab046872757a147ea (patch)
tree238bd427a3f69362115d27694dcb48a245bad950
parentc82c4942fd557d3087ed7f21159f77303dd53e41 (diff)
downloadcython-ac3ffe1de6776d225efb994ab046872757a147ea.tar.gz
repair comparison between C complex types and extension types: previously generated an invalid type check
-rw-r--r--Cython/Compiler/ExprNodes.py4
-rw-r--r--tests/run/complex_numbers_T305.pyx55
2 files changed, 48 insertions, 11 deletions
diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py
index 25e13790f..991319990 100644
--- a/Cython/Compiler/ExprNodes.py
+++ b/Cython/Compiler/ExprNodes.py
@@ -11879,9 +11879,9 @@ class CmpNode(object):
error(self.pos, "complex types are unordered")
new_common_type = error_type
elif type1.is_pyobject:
- new_common_type = type1
+ new_common_type = Builtin.complex_type if type1.subtype_of(Builtin.complex_type) else py_object_type
elif type2.is_pyobject:
- new_common_type = type2
+ new_common_type = Builtin.complex_type if type2.subtype_of(Builtin.complex_type) else py_object_type
else:
new_common_type = PyrexTypes.widest_numeric_type(type1, type2)
elif type1.is_numeric and type2.is_numeric:
diff --git a/tests/run/complex_numbers_T305.pyx b/tests/run/complex_numbers_T305.pyx
index a5decf581..9d719f609 100644
--- a/tests/run/complex_numbers_T305.pyx
+++ b/tests/run/complex_numbers_T305.pyx
@@ -1,7 +1,33 @@
# ticket: 305
+from cpython.object cimport Py_EQ, Py_NE
+
cimport cython
+
+cdef class Complex3j:
+ """
+ >>> Complex3j() == 3j
+ True
+ >>> Complex3j() == Complex3j()
+ True
+ >>> Complex3j() != 3j
+ False
+ >>> Complex3j() != 3
+ True
+ >>> Complex3j() != Complex3j()
+ False
+ """
+ def __richcmp__(a, b, int op):
+ if op == Py_EQ or op == Py_NE:
+ if isinstance(a, Complex3j):
+ eq = isinstance(b, Complex3j) or b == 3j
+ else:
+ eq = isinstance(b, Complex3j) and a == 3j
+ return eq if op == Py_EQ else not eq
+ return NotImplemented
+
+
def test_object_conversion(o):
"""
>>> test_object_conversion(2)
@@ -13,6 +39,7 @@ def test_object_conversion(o):
cdef double complex b = o
return (a, b)
+
def test_arithmetic(double complex z, double complex w):
"""
>>> test_arithmetic(2j, 4j)
@@ -24,6 +51,7 @@ def test_arithmetic(double complex z, double complex w):
"""
return +z, -z+0, z+w, z-w, z*w, z/w
+
def test_div(double complex a, double complex b, expected):
"""
>>> big = 2.0**1023
@@ -34,6 +62,7 @@ def test_div(double complex a, double complex b, expected):
if '_c99_' not in __name__:
assert a / b == expected, (a / b, expected)
+
def test_pow(double complex z, double complex w, tol=None):
"""
Various implementations produce slightly different results...
@@ -55,6 +84,7 @@ def test_pow(double complex z, double complex w, tol=None):
else:
return abs(z**w / <object>z ** <object>w - 1) < tol
+
def test_int_pow(double complex z, int n, tol=None):
"""
>>> [test_int_pow(complex(0, 1), k, 1e-15) for k in range(-4, 5)]
@@ -71,6 +101,7 @@ def test_int_pow(double complex z, int n, tol=None):
else:
return abs(z**n / <object>z ** <object>n - 1) < tol
+
@cython.cdivision(False)
def test_div_by_zero(double complex z):
"""
@@ -83,6 +114,7 @@ def test_div_by_zero(double complex z):
"""
return 1/z
+
def test_coercion(int a, float b, double c, float complex d, double complex e):
"""
>>> test_coercion(1, 1.5, 2.5, 4+1j, 10j)
@@ -101,29 +133,34 @@ def test_coercion(int a, float b, double c, float complex d, double complex e):
z = e; print z
return z + a + b + c + d + e
+
def test_compare(double complex a, double complex b):
"""
>>> test_compare(3, 3)
- (True, False)
+ (True, False, False, False, False, True)
>>> test_compare(3j, 3j)
- (True, False)
+ (True, False, True, True, True, False)
>>> test_compare(3j, 4j)
- (False, True)
+ (False, True, True, False, True, True)
>>> test_compare(3, 4)
- (False, True)
+ (False, True, False, False, False, True)
"""
- return a == b, a != b
+ return a == b, a != b, a == 3j, 3j == b, a == Complex3j(), Complex3j() != b
+
def test_compare_coerce(double complex a, int b):
"""
>>> test_compare_coerce(3, 4)
- (False, True)
+ (False, True, False, False, False, True)
>>> test_compare_coerce(4+1j, 4)
- (False, True)
+ (False, True, False, True, False, True)
>>> test_compare_coerce(4, 4)
- (True, False)
+ (True, False, False, False, False, True)
+ >>> test_compare_coerce(3j, 4)
+ (False, True, True, False, True, False)
"""
- return a == b, a != b
+ return a == b, a != b, a == 3j, 4+1j == a, a == Complex3j(), Complex3j() != a
+
def test_literal():
"""