diff options
Diffstat (limited to 'lib/sqlalchemy/sql/default_comparator.py')
-rw-r--r-- | lib/sqlalchemy/sql/default_comparator.py | 146 |
1 files changed, 40 insertions, 106 deletions
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 9a12b84cd..918f7524e 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -8,32 +8,21 @@ """Default implementation of SQL comparison operations. """ + +from . import coercions from . import operators +from . import roles from . import type_api -from .elements import _clause_element_as_expr -from .elements import _const_expr -from .elements import _is_literal -from .elements import _literal_as_text from .elements import and_ from .elements import BinaryExpression -from .elements import BindParameter -from .elements import ClauseElement from .elements import ClauseList from .elements import collate from .elements import CollectionAggregate -from .elements import ColumnElement from .elements import False_ from .elements import Null from .elements import or_ -from .elements import TextClause from .elements import True_ -from .elements import Tuple from .elements import UnaryExpression -from .elements import Visitable -from .selectable import Alias -from .selectable import ScalarSelect -from .selectable import Selectable -from .selectable import SelectBase from .. import exc from .. import util @@ -62,7 +51,7 @@ def _boolean_compare( ): return BinaryExpression( expr, - _literal_as_text(obj), + coercions.expect(roles.ConstExprRole, obj), op, type_=result_type, negate=negate, @@ -71,7 +60,7 @@ def _boolean_compare( elif op in (operators.is_distinct_from, operators.isnot_distinct_from): return BinaryExpression( expr, - _literal_as_text(obj), + coercions.expect(roles.ConstExprRole, obj), op, type_=result_type, negate=negate, @@ -82,7 +71,7 @@ def _boolean_compare( if op in (operators.eq, operators.is_): return BinaryExpression( expr, - _const_expr(obj), + coercions.expect(roles.ConstExprRole, obj), operators.is_, negate=operators.isnot, type_=result_type, @@ -90,7 +79,7 @@ def _boolean_compare( elif op in (operators.ne, operators.isnot): return BinaryExpression( expr, - _const_expr(obj), + coercions.expect(roles.ConstExprRole, obj), operators.isnot, negate=operators.is_, type_=result_type, @@ -102,7 +91,9 @@ def _boolean_compare( "operators can be used with None/True/False" ) else: - obj = _check_literal(expr, op, obj) + obj = coercions.expect( + roles.BinaryElementRole, element=obj, operator=op, expr=expr + ) if reverse: return BinaryExpression( @@ -127,7 +118,9 @@ def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw): def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw): - obj = _check_literal(expr, op, obj) + obj = coercions.expect( + roles.BinaryElementRole, obj, expr=expr, operator=op + ) if reverse: left, right = obj, expr @@ -156,77 +149,22 @@ def _scalar(expr, op, fn, **kw): def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): - seq_or_selectable = _clause_element_as_expr(seq_or_selectable) - - if isinstance(seq_or_selectable, ScalarSelect): - return _boolean_compare(expr, op, seq_or_selectable, negate=negate_op) - elif isinstance(seq_or_selectable, SelectBase): - - # TODO: if we ever want to support (x, y, z) IN (select x, - # y, z from table), we would need a multi-column version of - # as_scalar() to produce a multi- column selectable that - # does not export itself as a FROM clause - - return _boolean_compare( - expr, op, seq_or_selectable.as_scalar(), negate=negate_op, **kw - ) - elif isinstance(seq_or_selectable, (Selectable, TextClause)): - return _boolean_compare( - expr, op, seq_or_selectable, negate=negate_op, **kw - ) - elif isinstance(seq_or_selectable, ClauseElement): - if ( - isinstance(seq_or_selectable, BindParameter) - and seq_or_selectable.expanding - ): - - if isinstance(expr, Tuple): - seq_or_selectable = seq_or_selectable._with_expanding_in_types( - [elem.type for elem in expr] - ) - - return _boolean_compare( - expr, op, seq_or_selectable, negate=negate_op - ) - else: - raise exc.InvalidRequestError( - "in_() accepts" - " either a list of expressions, " - 'a selectable, or an "expanding" bound parameter: %r' - % seq_or_selectable - ) - - # Handle non selectable arguments as sequences - args = [] - for o in seq_or_selectable: - if not _is_literal(o): - if not isinstance(o, operators.ColumnOperators): - raise exc.InvalidRequestError( - "in_() accepts" - " either a list of expressions, " - 'a selectable, or an "expanding" bound parameter: %r' % o - ) - elif o is None: - o = Null() - else: - o = expr._bind_param(op, o) - args.append(o) - - if len(args) == 0: - op, negate_op = ( - (operators.empty_in_op, operators.empty_notin_op) - if op is operators.in_op - else (operators.empty_notin_op, operators.empty_in_op) - ) + seq_or_selectable = coercions.expect( + roles.InElementRole, seq_or_selectable, expr=expr, operator=op + ) + if "in_ops" in seq_or_selectable._annotations: + op, negate_op = seq_or_selectable._annotations["in_ops"] return _boolean_compare( - expr, op, ClauseList(*args).self_group(against=op), negate=negate_op + expr, op, seq_or_selectable, negate=negate_op, **kw ) def _getitem_impl(expr, op, other, **kw): if isinstance(expr.type, type_api.INDEXABLE): - other = _check_literal(expr, op, other) + other = coercions.expect( + roles.BinaryElementRole, other, expr=expr, operator=op + ) return _binary_operate(expr, op, other, **kw) else: _unsupported_impl(expr, op, other, **kw) @@ -257,7 +195,12 @@ def _match_impl(expr, op, other, **kw): return _boolean_compare( expr, operators.match_op, - _check_literal(expr, operators.match_op, other), + coercions.expect( + roles.BinaryElementRole, + other, + expr=expr, + operator=operators.match_op, + ), result_type=type_api.MATCHTYPE, negate=operators.notmatch_op if op is operators.match_op @@ -278,8 +221,18 @@ def _between_impl(expr, op, cleft, cright, **kw): return BinaryExpression( expr, ClauseList( - _check_literal(expr, operators.and_, cleft), - _check_literal(expr, operators.and_, cright), + coercions.expect( + roles.BinaryElementRole, + cleft, + expr=expr, + operator=operators.and_, + ), + coercions.expect( + roles.BinaryElementRole, + cright, + expr=expr, + operator=operators.and_, + ), operator=operators.and_, group=False, group_contents=False, @@ -349,22 +302,3 @@ operator_lookup = { "rshift": (_unsupported_impl,), "contains": (_unsupported_impl,), } - - -def _check_literal(expr, operator, other, bindparam_type=None): - if isinstance(other, (ColumnElement, TextClause)): - if isinstance(other, BindParameter) and other.type._isnull: - other = other._clone() - other.type = expr.type - return other - elif hasattr(other, "__clause_element__"): - other = other.__clause_element__() - elif isinstance(other, type_api.TypeEngine.Comparator): - other = other.expr - - if isinstance(other, (SelectBase, Alias)): - return other.as_scalar() - elif not isinstance(other, Visitable): - return expr._bind_param(operator, other, type_=bindparam_type) - else: - return other |