diff options
author | da-woods <dw-git@d-woods.co.uk> | 2023-04-13 08:41:33 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-13 09:41:33 +0200 |
commit | 0dd38bf8acbc0822986e7f4f196321e661abed53 (patch) | |
tree | e95ce7244f29611e658970eaf2829da21b315f74 | |
parent | 5f050d675ad6f63cf330ede553a97f03439d2c26 (diff) | |
download | cython-0dd38bf8acbc0822986e7f4f196321e661abed53.tar.gz |
Fix issues with partially optimized cascaded comparisons (GH-5357)
If a cascaded comparison is partially optimized (i.e. only some
of the comparisons are optimized) then the result types must end up
consistent all the way through. At the moment we select PyObject
which probably isn't the most efficient option, but is the
easiest to implement
We do not require the whole cascaded optimization to succeed.
Instead, we can just get Python comparisons as bool, and just ensure
that the entire cascade has the same type
-rw-r--r-- | Cython/Compiler/ExprNodes.py | 18 | ||||
-rw-r--r-- | tests/run/cascmp.pyx | 66 |
2 files changed, 78 insertions, 6 deletions
diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 249a40f46..6097297b8 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -13093,7 +13093,7 @@ class CmpNode(object): "Eq" if self.operator == "==" else "Ne", self, PyrexTypes.c_bint_type, - self.operand1, + operand1, self.operand2 ) if result: @@ -13168,8 +13168,10 @@ class CmpNode(object): elif operand1.type.is_pyobject and op not in ('is', 'is_not'): assert op not in ('in', 'not_in'), op - code.putln("%s = PyObject_RichCompare(%s, %s, %s); %s%s" % ( + assert self.type.is_pyobject or self.type is PyrexTypes.c_bint_type + code.putln("%s = PyObject_RichCompare%s(%s, %s, %s); %s%s" % ( result_code, + "" if self.type.is_pyobject else "Bool", operand1.py_result(), operand2.py_result(), richcmp_constants[op], @@ -13265,6 +13267,12 @@ class PrimaryCmpNode(ExprNode, CmpNode): operand1 = self.operand1.compile_time_value(denv) return self.cascaded_compile_time_value(operand1, denv) + def unify_cascade_type(self): + cdr = self.cascade + while cdr: + cdr.type = self.type + cdr = cdr.cascade + def analyse_types(self, env): self.operand1 = self.operand1.analyse_types(env) self.operand2 = self.operand2.analyse_types(env) @@ -13343,10 +13351,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): self.type = PyrexTypes.py_object_type else: self.type = PyrexTypes.c_bint_type - cdr = self.cascade - while cdr: - cdr.type = self.type - cdr = cdr.cascade + self.unify_cascade_type() if self.is_pycmp or self.cascade or self.special_bool_cmp_function: # 1) owned reference, 2) reused value, 3) potential function error return value self.is_temp = 1 @@ -13405,6 +13410,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): self.operand2, env, result_is_bool=True) if operand2 is not self.operand2: self.coerced_operand2 = operand2 + self.unify_cascade_type() return self # TODO: check if we can optimise parts of the cascade here return ExprNode.coerce_to_boolean(self, env) diff --git a/tests/run/cascmp.pyx b/tests/run/cascmp.pyx index becae3fc5..600cf3d85 100644 --- a/tests/run/cascmp.pyx +++ b/tests/run/cascmp.pyx @@ -36,3 +36,69 @@ def const_cascade(x): 1 <= 1 <= x <= 2 <= 3 > x <= 2 <= 2, 1 <= 1 <= x <= 1 <= 1 <= x <= 2, ) + +def eq_if_statement(a, b, c): + """ + >>> eq_if_statement(1, 2, 3) + False + >>> eq_if_statement(2, 3, 4) + False + >>> eq_if_statement(1, 1, 2) + False + >>> eq_if_statement(1, "not an int", 2) + False + >>> eq_if_statement(2, 1, 1) + False + >>> eq_if_statement(1, 1, 1) + True + """ + if 1 == a == b == c: + return True + else: + return False + +def eq_if_statement_semi_optimized(a, int b, int c): + """ + Some but not all of the cascade ends up optimized + (probably not as much as should be). The test is mostly + that it keeps the types consistent throughout + + >>> eq_if_statement_semi_optimized(1, 2, 3) + False + >>> eq_if_statement_semi_optimized(2, 3, 4) + False + >>> eq_if_statement_semi_optimized(1, 1, 2) + False + >>> eq_if_statement_semi_optimized("not an int", 1, 2) + False + >>> eq_if_statement_semi_optimized(2, 1, 1) + False + >>> eq_if_statement_semi_optimized(1, 1, 1) + True + """ + if 1 == a == b == c == 1: + return True + else: + return False + +def eq_if_statement_semi_optimized2(a, b, c): + """ + Here only "b==c" fails to optimize + + >>> eq_if_statement_semi_optimized2(1, 2, 3) + False + >>> eq_if_statement_semi_optimized2(2, 3, 4) + False + >>> eq_if_statement_semi_optimized2(1, 1, 2) + False + >>> eq_if_statement_semi_optimized2(1, "not an int", 2) + False + >>> eq_if_statement_semi_optimized2(2, 1, 1) + False + >>> eq_if_statement_semi_optimized2(1, 1, 1) + True + """ + if 1 == a == 1 == b == c: + return True + else: + return False |