diff options
author | Stefan Behnel <stefan_ml@behnel.de> | 2017-07-17 10:25:35 +0200 |
---|---|---|
committer | Stefan Behnel <stefan_ml@behnel.de> | 2017-07-17 10:25:35 +0200 |
commit | ac3ffe1de6776d225efb994ab046872757a147ea (patch) | |
tree | 238bd427a3f69362115d27694dcb48a245bad950 | |
parent | c82c4942fd557d3087ed7f21159f77303dd53e41 (diff) | |
download | cython-ac3ffe1de6776d225efb994ab046872757a147ea.tar.gz |
repair comparison between C complex types and extension types: previously generated an invalid type check
-rw-r--r-- | Cython/Compiler/ExprNodes.py | 4 | ||||
-rw-r--r-- | tests/run/complex_numbers_T305.pyx | 55 |
2 files changed, 48 insertions, 11 deletions
diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 25e13790f..991319990 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -11879,9 +11879,9 @@ class CmpNode(object): error(self.pos, "complex types are unordered") new_common_type = error_type elif type1.is_pyobject: - new_common_type = type1 + new_common_type = Builtin.complex_type if type1.subtype_of(Builtin.complex_type) else py_object_type elif type2.is_pyobject: - new_common_type = type2 + new_common_type = Builtin.complex_type if type2.subtype_of(Builtin.complex_type) else py_object_type else: new_common_type = PyrexTypes.widest_numeric_type(type1, type2) elif type1.is_numeric and type2.is_numeric: diff --git a/tests/run/complex_numbers_T305.pyx b/tests/run/complex_numbers_T305.pyx index a5decf581..9d719f609 100644 --- a/tests/run/complex_numbers_T305.pyx +++ b/tests/run/complex_numbers_T305.pyx @@ -1,7 +1,33 @@ # ticket: 305 +from cpython.object cimport Py_EQ, Py_NE + cimport cython + +cdef class Complex3j: + """ + >>> Complex3j() == 3j + True + >>> Complex3j() == Complex3j() + True + >>> Complex3j() != 3j + False + >>> Complex3j() != 3 + True + >>> Complex3j() != Complex3j() + False + """ + def __richcmp__(a, b, int op): + if op == Py_EQ or op == Py_NE: + if isinstance(a, Complex3j): + eq = isinstance(b, Complex3j) or b == 3j + else: + eq = isinstance(b, Complex3j) and a == 3j + return eq if op == Py_EQ else not eq + return NotImplemented + + def test_object_conversion(o): """ >>> test_object_conversion(2) @@ -13,6 +39,7 @@ def test_object_conversion(o): cdef double complex b = o return (a, b) + def test_arithmetic(double complex z, double complex w): """ >>> test_arithmetic(2j, 4j) @@ -24,6 +51,7 @@ def test_arithmetic(double complex z, double complex w): """ return +z, -z+0, z+w, z-w, z*w, z/w + def test_div(double complex a, double complex b, expected): """ >>> big = 2.0**1023 @@ -34,6 +62,7 @@ def test_div(double complex a, double complex b, expected): if '_c99_' not in __name__: assert a / b == expected, (a / b, expected) + def test_pow(double complex z, double complex w, tol=None): """ Various implementations produce slightly different results... @@ -55,6 +84,7 @@ def test_pow(double complex z, double complex w, tol=None): else: return abs(z**w / <object>z ** <object>w - 1) < tol + def test_int_pow(double complex z, int n, tol=None): """ >>> [test_int_pow(complex(0, 1), k, 1e-15) for k in range(-4, 5)] @@ -71,6 +101,7 @@ def test_int_pow(double complex z, int n, tol=None): else: return abs(z**n / <object>z ** <object>n - 1) < tol + @cython.cdivision(False) def test_div_by_zero(double complex z): """ @@ -83,6 +114,7 @@ def test_div_by_zero(double complex z): """ return 1/z + def test_coercion(int a, float b, double c, float complex d, double complex e): """ >>> test_coercion(1, 1.5, 2.5, 4+1j, 10j) @@ -101,29 +133,34 @@ def test_coercion(int a, float b, double c, float complex d, double complex e): z = e; print z return z + a + b + c + d + e + def test_compare(double complex a, double complex b): """ >>> test_compare(3, 3) - (True, False) + (True, False, False, False, False, True) >>> test_compare(3j, 3j) - (True, False) + (True, False, True, True, True, False) >>> test_compare(3j, 4j) - (False, True) + (False, True, True, False, True, True) >>> test_compare(3, 4) - (False, True) + (False, True, False, False, False, True) """ - return a == b, a != b + return a == b, a != b, a == 3j, 3j == b, a == Complex3j(), Complex3j() != b + def test_compare_coerce(double complex a, int b): """ >>> test_compare_coerce(3, 4) - (False, True) + (False, True, False, False, False, True) >>> test_compare_coerce(4+1j, 4) - (False, True) + (False, True, False, True, False, True) >>> test_compare_coerce(4, 4) - (True, False) + (True, False, False, False, False, True) + >>> test_compare_coerce(3j, 4) + (False, True, True, False, True, False) """ - return a == b, a != b + return a == b, a != b, a == 3j, 4+1j == a, a == Complex3j(), Complex3j() != a + def test_literal(): """ |