summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Behnel <stefan_ml@behnel.de>2012-11-09 22:52:09 +0100
committerStefan Behnel <stefan_ml@behnel.de>2012-11-09 22:52:09 +0100
commit59ee0280d0c5f8bb7970dc2dc8ef14e7acf00428 (patch)
treefef39c10b9e319bc1ab5f6aeac6ee1ef392480ec
parentbb18427cdd9e4796f48457b04d1365fbfabb2ad1 (diff)
downloadcython-59ee0280d0c5f8bb7970dc2dc8ef14e7acf00428.tar.gz
fix type coercion in cascaded comparisons
--HG-- extra : transplant_source : %06m%CA%82%F2%FDq%E4%88%DDJ%C8%FDe%00KG%D55%CB
-rwxr-xr-xCython/Compiler/ExprNodes.py31
-rw-r--r--tests/run/inop.pyx26
2 files changed, 48 insertions, 9 deletions
diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py
index 2ef54e8be..59bbd5891 100755
--- a/Cython/Compiler/ExprNodes.py
+++ b/Cython/Compiler/ExprNodes.py
@@ -8850,9 +8850,10 @@ class PrimaryCmpNode(ExprNode, CmpNode):
# Instead, we override all the framework methods
# which use it.
- child_attrs = ['operand1', 'operand2', 'cascade']
+ child_attrs = ['operand1', 'operand2', 'coerced_operand2', 'cascade']
cascade = None
+ coerced_operand2 = None
is_memslice_nonecheck = False
def infer_type(self, env):
@@ -8930,9 +8931,11 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.coerce_operands_to(common_type, env)
if self.cascade:
- self.operand2 = self.cascade.optimise_comparison(
- self.operand2.coerce_to_simple(env), env)
+ self.operand2 = self.operand2.coerce_to_simple(env)
self.cascade.coerce_cascaded_operands_to_temp(env)
+ operand2 = self.cascade.optimise_comparison(self.operand2, env)
+ if operand2 is not self.operand2:
+ self.coerced_operand2 = operand2
if self.is_python_result():
self.type = PyrexTypes.py_object_type
else:
@@ -9036,8 +9039,9 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.generate_operation_code(code, self.result(),
self.operand1, self.operator, self.operand2)
if self.cascade:
- self.cascade.generate_evaluation_code(code,
- self.result(), self.operand2)
+ self.cascade.generate_evaluation_code(
+ code, self.result(), self.coerced_operand2 or self.operand2,
+ needs_evaluation=self.coerced_operand2 is not None)
self.operand1.generate_disposal_code(code)
self.operand1.free_temps(code)
self.operand2.generate_disposal_code(code)
@@ -9072,9 +9076,10 @@ class CascadedCmpNode(Node, CmpNode):
# operand2 ExprNode
# cascade CascadedCmpNode
- child_attrs = ['operand2', 'cascade']
+ child_attrs = ['operand2', 'coerced_operand2', 'cascade']
cascade = None
+ coerced_operand2 = None
constant_result = constant_value_not_set # FIXME: where to calculate this?
def infer_type(self, env):
@@ -9101,7 +9106,9 @@ class CascadedCmpNode(Node, CmpNode):
if not operand1.type.is_pyobject:
operand1 = operand1.coerce_to_pyobject(env)
if self.cascade:
- self.operand2 = self.cascade.optimise_comparison(self.operand2, env)
+ operand2 = self.cascade.optimise_comparison(self.operand2, env)
+ if operand2 is not self.operand2:
+ self.coerced_operand2 = operand2
return operand1
def coerce_operands_to_pyobjects(self, env):
@@ -9117,18 +9124,24 @@ class CascadedCmpNode(Node, CmpNode):
self.operand2 = self.operand2.coerce_to_simple(env)
self.cascade.coerce_cascaded_operands_to_temp(env)
- def generate_evaluation_code(self, code, result, operand1):
+ def generate_evaluation_code(self, code, result, operand1, needs_evaluation=False):
if self.type.is_pyobject:
code.putln("if (__Pyx_PyObject_IsTrue(%s)) {" % result)
code.put_decref(result, self.type)
else:
code.putln("if (%s) {" % result)
+ if needs_evaluation:
+ operand1.generate_evaluation_code(code)
self.operand2.generate_evaluation_code(code)
self.generate_operation_code(code, result,
operand1, self.operator, self.operand2)
if self.cascade:
self.cascade.generate_evaluation_code(
- code, result, self.operand2)
+ code, result, self.coerced_operand2 or self.operand2,
+ needs_evaluation=self.coerced_operand2 is not None)
+ if needs_evaluation:
+ operand1.generate_disposal_code(code)
+ operand1.free_temps(code)
# Cascaded cmp result is always temp
self.operand2.generate_disposal_code(code)
self.operand2.free_temps(code)
diff --git a/tests/run/inop.pyx b/tests/run/inop.pyx
index da3ac799d..481c60c9d 100644
--- a/tests/run/inop.pyx
+++ b/tests/run/inop.pyx
@@ -376,3 +376,29 @@ def test_inop_cascaded(x):
False
"""
return 1 != x in [2]
+
+def test_inop_cascaded_one():
+ """
+ >>> test_inop_cascaded_one()
+ False
+ """
+ # copied from CPython's test_grammar.py
+ return 1 < 1 > 1 == 1 >= 1 <= 1 != 1 in 1 not in 1 is 1 is not 1
+
+def test_inop_cascaded_int_orig(int x):
+ """
+ >>> test_inop_cascaded_int_orig(1)
+ False
+ """
+ return 1 < 1 > 1 == 1 >= 1 <= 1 != x in 1 not in 1 is 1 is not 1
+
+def test_inop_cascaded_int(int x):
+ """
+ >>> test_inop_cascaded_int(1)
+ False
+ >>> test_inop_cascaded_int(2)
+ True
+ >>> test_inop_cascaded_int(3)
+ False
+ """
+ return 1 != x in [1,2]