summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorda-woods <dw-git@d-woods.co.uk>2023-04-13 08:41:33 +0100
committerGitHub <noreply@github.com>2023-04-13 09:41:33 +0200
commit0dd38bf8acbc0822986e7f4f196321e661abed53 (patch)
treee95ce7244f29611e658970eaf2829da21b315f74
parent5f050d675ad6f63cf330ede553a97f03439d2c26 (diff)
downloadcython-0dd38bf8acbc0822986e7f4f196321e661abed53.tar.gz
Fix issues with partially optimized cascaded comparisons (GH-5357)
If a cascaded comparison is partially optimized (i.e. only some of the comparisons are optimized) then the result types must end up consistent all the way through. At the moment we select PyObject which probably isn't the most efficient option, but is the easiest to implement We do not require the whole cascaded optimization to succeed. Instead, we can just get Python comparisons as bool, and just ensure that the entire cascade has the same type
-rw-r--r--Cython/Compiler/ExprNodes.py18
-rw-r--r--tests/run/cascmp.pyx66
2 files changed, 78 insertions, 6 deletions
diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py
index 249a40f46..6097297b8 100644
--- a/Cython/Compiler/ExprNodes.py
+++ b/Cython/Compiler/ExprNodes.py
@@ -13093,7 +13093,7 @@ class CmpNode(object):
"Eq" if self.operator == "==" else "Ne",
self,
PyrexTypes.c_bint_type,
- self.operand1,
+ operand1,
self.operand2
)
if result:
@@ -13168,8 +13168,10 @@ class CmpNode(object):
elif operand1.type.is_pyobject and op not in ('is', 'is_not'):
assert op not in ('in', 'not_in'), op
- code.putln("%s = PyObject_RichCompare(%s, %s, %s); %s%s" % (
+ assert self.type.is_pyobject or self.type is PyrexTypes.c_bint_type
+ code.putln("%s = PyObject_RichCompare%s(%s, %s, %s); %s%s" % (
result_code,
+ "" if self.type.is_pyobject else "Bool",
operand1.py_result(),
operand2.py_result(),
richcmp_constants[op],
@@ -13265,6 +13267,12 @@ class PrimaryCmpNode(ExprNode, CmpNode):
operand1 = self.operand1.compile_time_value(denv)
return self.cascaded_compile_time_value(operand1, denv)
+ def unify_cascade_type(self):
+ cdr = self.cascade
+ while cdr:
+ cdr.type = self.type
+ cdr = cdr.cascade
+
def analyse_types(self, env):
self.operand1 = self.operand1.analyse_types(env)
self.operand2 = self.operand2.analyse_types(env)
@@ -13343,10 +13351,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.type = PyrexTypes.py_object_type
else:
self.type = PyrexTypes.c_bint_type
- cdr = self.cascade
- while cdr:
- cdr.type = self.type
- cdr = cdr.cascade
+ self.unify_cascade_type()
if self.is_pycmp or self.cascade or self.special_bool_cmp_function:
# 1) owned reference, 2) reused value, 3) potential function error return value
self.is_temp = 1
@@ -13405,6 +13410,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.operand2, env, result_is_bool=True)
if operand2 is not self.operand2:
self.coerced_operand2 = operand2
+ self.unify_cascade_type()
return self
# TODO: check if we can optimise parts of the cascade here
return ExprNode.coerce_to_boolean(self, env)
diff --git a/tests/run/cascmp.pyx b/tests/run/cascmp.pyx
index becae3fc5..600cf3d85 100644
--- a/tests/run/cascmp.pyx
+++ b/tests/run/cascmp.pyx
@@ -36,3 +36,69 @@ def const_cascade(x):
1 <= 1 <= x <= 2 <= 3 > x <= 2 <= 2,
1 <= 1 <= x <= 1 <= 1 <= x <= 2,
)
+
+def eq_if_statement(a, b, c):
+ """
+ >>> eq_if_statement(1, 2, 3)
+ False
+ >>> eq_if_statement(2, 3, 4)
+ False
+ >>> eq_if_statement(1, 1, 2)
+ False
+ >>> eq_if_statement(1, "not an int", 2)
+ False
+ >>> eq_if_statement(2, 1, 1)
+ False
+ >>> eq_if_statement(1, 1, 1)
+ True
+ """
+ if 1 == a == b == c:
+ return True
+ else:
+ return False
+
+def eq_if_statement_semi_optimized(a, int b, int c):
+ """
+ Some but not all of the cascade ends up optimized
+ (probably not as much as should be). The test is mostly
+ that it keeps the types consistent throughout
+
+ >>> eq_if_statement_semi_optimized(1, 2, 3)
+ False
+ >>> eq_if_statement_semi_optimized(2, 3, 4)
+ False
+ >>> eq_if_statement_semi_optimized(1, 1, 2)
+ False
+ >>> eq_if_statement_semi_optimized("not an int", 1, 2)
+ False
+ >>> eq_if_statement_semi_optimized(2, 1, 1)
+ False
+ >>> eq_if_statement_semi_optimized(1, 1, 1)
+ True
+ """
+ if 1 == a == b == c == 1:
+ return True
+ else:
+ return False
+
+def eq_if_statement_semi_optimized2(a, b, c):
+ """
+ Here only "b==c" fails to optimize
+
+ >>> eq_if_statement_semi_optimized2(1, 2, 3)
+ False
+ >>> eq_if_statement_semi_optimized2(2, 3, 4)
+ False
+ >>> eq_if_statement_semi_optimized2(1, 1, 2)
+ False
+ >>> eq_if_statement_semi_optimized2(1, "not an int", 2)
+ False
+ >>> eq_if_statement_semi_optimized2(2, 1, 1)
+ False
+ >>> eq_if_statement_semi_optimized2(1, 1, 1)
+ True
+ """
+ if 1 == a == 1 == b == c:
+ return True
+ else:
+ return False