diff options
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 46 |
1 files changed, 34 insertions, 12 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 28b1c6ddd..d2e644ce2 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1591,7 +1591,9 @@ def _interpret_as_from(element): def _const_expr(element): - if element is None: + if isinstance(element, (Null, False_, True_)): + return element + elif element is None: return null() elif element is False: return false() @@ -2011,18 +2013,33 @@ class _DefaultColumnComparator(operators.ColumnOperators): return op, other_comparator.type def _boolean_compare(self, expr, op, obj, negate=None, reverse=False, - **kwargs - ): - if obj is None or isinstance(obj, Null): - if op in (operators.eq, operators.is_): - return BinaryExpression(expr, null(), operators.is_, - negate=operators.isnot) - elif op in (operators.ne, operators.isnot): - return BinaryExpression(expr, null(), operators.isnot, - negate=operators.is_) + **kwargs): + if isinstance(obj, (util.NoneType, bool, Null, True_, False_)): + + # allow x ==/!= True/False to be treated as a literal. + # this comes out to "== / != true/false" or "1/0" if those + # constants aren't supported and works on all platforms + if op in (operators.eq, operators.ne) and \ + isinstance(obj, (bool, True_, False_)): + return BinaryExpression(expr, + obj, + op, + type_=sqltypes.BOOLEANTYPE, + negate=negate, modifiers=kwargs) else: - raise exc.ArgumentError("Only '='/'!=' operators can " - "be used with NULL") + # all other None/True/False uses IS, IS NOT + if op in (operators.eq, operators.is_): + return BinaryExpression(expr, _const_expr(obj), + operators.is_, + negate=operators.isnot) + elif op in (operators.ne, operators.isnot): + return BinaryExpression(expr, _const_expr(obj), + operators.isnot, + negate=operators.is_) + else: + raise exc.ArgumentError( + "Only '=', '!=', 'is_()', 'isnot()' operators can " + "be used with None/True/False") else: obj = self._check_literal(expr, op, obj) @@ -3253,6 +3270,8 @@ class False_(ColumnElement): def __init__(self): self.type = sqltypes.BOOLEANTYPE + def compare(self, other): + return isinstance(other, False_) class True_(ColumnElement): """Represent the ``true`` keyword in a SQL statement. @@ -3266,6 +3285,9 @@ class True_(ColumnElement): def __init__(self): self.type = sqltypes.BOOLEANTYPE + def compare(self, other): + return isinstance(other, True_) + class ClauseList(ClauseElement): """Describe a list of clauses, separated by an operator. |