diff options
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 295 |
1 files changed, 164 insertions, 131 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 2c1c047b5..84d7c1a29 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1773,32 +1773,6 @@ class ClauseElement(Visitable): return self - @util.deprecated('0.7', - 'Only SQL expressions which subclass ' - ':class:`.Executable` may provide the ' - ':func:`.execute` method.') - def execute(self, *multiparams, **params): - """Compile and execute this :class:`.ClauseElement`. - - """ - e = self.bind - if e is None: - label = getattr(self, 'description', self.__class__.__name__) - msg = ('This %s does not support direct execution.' % label) - raise exc.UnboundExecutionError(msg) - return e._execute_clauseelement(self, multiparams, params) - - @util.deprecated('0.7', - 'Only SQL expressions which subclass ' - ':class:`.Executable` may provide the ' - ':func:`.scalar` method.') - def scalar(self, *multiparams, **params): - """Compile and execute this :class:`.ClauseElement`, returning - the result's scalar representation. - - """ - return self.execute(*multiparams, **params).scalar() - def compile(self, bind=None, dialect=None, **kw): """Compile this SQL expression. @@ -1901,7 +1875,7 @@ class Immutable(object): return self -class CompareMixin(ColumnOperators): +class _DefaultColumnComparator(ColumnOperators): """Defines comparison and math operations. The :class:`.CompareMixin` is part of the interface provided @@ -1913,91 +1887,74 @@ class CompareMixin(ColumnOperators): """ - def __compare(self, op, obj, negate=None, reverse=False, + def __compare(self, expr, op, obj, negate=None, reverse=False, **kwargs ): if obj is None or isinstance(obj, Null): if op == operators.eq: - return BinaryExpression(self, null(), operators.is_, + return BinaryExpression(expr, null(), operators.is_, negate=operators.isnot) elif op == operators.ne: - return BinaryExpression(self, null(), operators.isnot, + return BinaryExpression(expr, null(), operators.isnot, negate=operators.is_) else: raise exc.ArgumentError("Only '='/'!=' operators can " "be used with NULL") else: - obj = self._check_literal(op, obj) + obj = self._check_literal(expr, op, obj) if reverse: return BinaryExpression(obj, - self, + expr, op, type_=sqltypes.BOOLEANTYPE, negate=negate, modifiers=kwargs) else: - return BinaryExpression(self, + return BinaryExpression(expr, obj, op, type_=sqltypes.BOOLEANTYPE, negate=negate, modifiers=kwargs) - def __operate(self, op, obj, reverse=False): - obj = self._check_literal(op, obj) - + def __operate(self, expr, op, obj, reverse=False): + obj = self._check_literal(expr, op, obj) + comparator_factory = None if reverse: - left, right = obj, self + left, right = obj, expr else: - left, right = self, obj + left, right = expr, obj if left.type is None: op, result_type = sqltypes.NULLTYPE._adapt_expression(op, right.type) + result_type = sqltypes.to_instance(result_type) + if right.type._compare_type_affinity(result_type): + comparator_factory = right.comparator_factory elif right.type is None: op, result_type = left.type._adapt_expression(op, sqltypes.NULLTYPE) + result_type = sqltypes.to_instance(result_type) + if left.type._compare_type_affinity(result_type): + comparator_factory = left.comparator_factory else: - op, result_type = left.type._adapt_expression(op, - right.type) - return BinaryExpression(left, right, op, type_=result_type) + op, result_type = left.type._adapt_expression(op, right.type) + result_type = sqltypes.to_instance(result_type) + if left.type._compare_type_affinity(result_type): + comparator_factory = left.comparator_factory + elif right.type._compare_type_affinity(result_type): + comparator_factory = right.comparator_factory - # a mapping of operators with the method they use, along with their negated - # operator for comparison operators - operators = { - "add": (__operate,), - "mul": (__operate,), - "sub": (__operate,), - "div": (__operate,), - "mod": (__operate,), - "truediv": (__operate,), - "custom_op": (__operate,), - "lt": (__compare, operators.ge), - "le": (__compare, operators.gt), - "ne": (__compare, operators.eq), - "gt": (__compare, operators.le), - "ge": (__compare, operators.lt), - "eq": (__compare, operators.ne), - "like_op": (__compare, operators.notlike_op), - "ilike_op": (__compare, operators.notilike_op), - } + return BinaryExpression(left, right, op, type_=result_type, + comparator_factory=comparator_factory) - def operate(self, op, *other, **kwargs): - o = CompareMixin.operators[op.__name__] - return o[0](self, op, other[0], *o[1:], **kwargs) + def __scalar(self, expr, op, fn, **kw): + return fn(expr) - def reverse_operate(self, op, other, **kwargs): - o = CompareMixin.operators[op.__name__] - return o[0](self, op, other, reverse=True, *o[1:], **kwargs) - - def in_(self, other): - """See :meth:`.ColumnOperators.in_`.""" - return self._in_impl(operators.in_op, operators.notin_op, other) - - def _in_impl(self, op, negate_op, seq_or_selectable): + def _in_impl(self, expr, op, seq_or_selectable, negate_op): seq_or_selectable = _clause_element_as_expr(seq_or_selectable) if isinstance(seq_or_selectable, ScalarSelect): - return self.__compare(op, seq_or_selectable, + return self.__compare(expr, op, seq_or_selectable, negate=negate_op) elif isinstance(seq_or_selectable, SelectBase): @@ -2006,10 +1963,10 @@ class CompareMixin(ColumnOperators): # as_scalar() to produce a multi- column selectable that # does not export itself as a FROM clause - return self.__compare(op, seq_or_selectable.as_scalar(), + return self.__compare(expr, op, seq_or_selectable.as_scalar(), negate=negate_op) elif isinstance(seq_or_selectable, (Selectable, TextClause)): - return self.__compare(op, seq_or_selectable, + return self.__compare(expr, op, seq_or_selectable, negate=negate_op) @@ -2018,12 +1975,12 @@ class CompareMixin(ColumnOperators): args = [] for o in seq_or_selectable: if not _is_literal(o): - if not isinstance(o, CompareMixin): + if not isinstance(o, ColumnOperators): raise exc.InvalidRequestError('in() function accept' 's either a list of non-selectable values, ' 'or a selectable: %r' % o) else: - o = self._bind_param(op, o) + o = expr._bind_param(op, o) args.append(o) if len(args) == 0: @@ -2037,18 +1994,17 @@ class CompareMixin(ColumnOperators): 'empty sequence. This results in a ' 'contradiction, which nonetheless can be ' 'expensive to evaluate. Consider alternative ' - 'strategies for improved performance.' % self) - return self != self + 'strategies for improved performance.' % expr) + return expr != expr - return self.__compare(op, + return self.__compare(expr, op, ClauseList(*args).self_group(against=op), negate=negate_op) - - def __neg__(self): + def _neg_impl(self): """See :meth:`.ColumnOperators.__neg__`.""" - return UnaryExpression(self, operator=operators.neg) + return UnaryExpression(self.expr, operator=operators.neg) - def startswith(self, other, escape=None): + def _startswith_impl(self, other, escape=None): """See :meth:`.ColumnOperators.startswith`.""" # use __radd__ to force string concat behavior return self.__compare( @@ -2058,7 +2014,7 @@ class CompareMixin(ColumnOperators): ), escape=escape) - def endswith(self, other, escape=None): + def _endswith_impl(self, other, escape=None): """See :meth:`.ColumnOperators.endswith`.""" return self.__compare( operators.like_op, @@ -2066,7 +2022,7 @@ class CompareMixin(ColumnOperators): self._check_literal(operators.like_op, other), escape=escape) - def contains(self, other, escape=None): + def _contains_impl(self, other, escape=None): """See :meth:`.ColumnOperators.contains`.""" return self.__compare( operators.like_op, @@ -2075,44 +2031,18 @@ class CompareMixin(ColumnOperators): literal_column("'%'", type_=sqltypes.String), escape=escape) - def match(self, other): + def _match_impl(self, other): """See :meth:`.ColumnOperators.match`.""" return self.__compare(operators.match_op, self._check_literal(operators.match_op, other)) - def label(self, name): - """Produce a column label, i.e. ``<columnname> AS <name>``. - - This is a shortcut to the :func:`~.expression.label` function. - - if 'name' is None, an anonymous label name will be generated. - - """ - return Label(name, self, self.type) - - def desc(self): - """See :meth:`.ColumnOperators.desc`.""" - return desc(self) - - def asc(self): - """See :meth:`.ColumnOperators.asc`.""" - return asc(self) - - def nullsfirst(self): - """See :meth:`.ColumnOperators.nullsfirst`.""" - return nullsfirst(self) - - def nullslast(self): - """See :meth:`.ColumnOperators.nullslast`.""" - return nullslast(self) - - def distinct(self): + def _distinct_impl(self): """See :meth:`.ColumnOperators.distinct`.""" return UnaryExpression(self, operator=operators.distinct_op, type_=self.type) - def between(self, cleft, cright): + def _between_impl(self, cleft, cright): """See :meth:`.ColumnOperators.between`.""" return BinaryExpression( self, @@ -2123,23 +2053,52 @@ class CompareMixin(ColumnOperators): group=False), operators.between_op) - def collate(self, collation): - """See :meth:`.ColumnOperators.collate`.""" + def _collate_impl(self, expr, op, other): + return collate(expr, other) + + # a mapping of operators with the method they use, along with their negated + # operator for comparison operators + operators = { + "add": (__operate,), + "mul": (__operate,), + "sub": (__operate,), + "div": (__operate,), + "mod": (__operate,), + "truediv": (__operate,), + "custom_op": (__operate,), + "lt": (__compare, operators.ge), + "le": (__compare, operators.gt), + "ne": (__compare, operators.eq), + "gt": (__compare, operators.le), + "ge": (__compare, operators.lt), + "eq": (__compare, operators.ne), + "like_op": (__compare, operators.notlike_op), + "ilike_op": (__compare, operators.notilike_op), + "desc_op": (__scalar, desc), + "asc_op": (__scalar, asc), + "nullsfirst_op": (__scalar, nullsfirst), + "nullslast_op": (__scalar, nullslast), + "in_op": (_in_impl, operators.notin_op), + "collate": (_collate_impl,), + } + + def operate(self, expr, op, *other, **kwargs): + o = self.operators[op.__name__] + return o[0](self, expr, op, *(other + o[1:]), **kwargs) + + def reverse_operate(self, expr, op, other, **kwargs): + o = self.operators[op.__name__] + return o[0](self, expr, op, other, reverse=True, *o[1:], **kwargs) - return collate(self, collation) - def _bind_param(self, operator, obj): - return BindParameter(None, obj, - _compared_to_operator=operator, - _compared_to_type=self.type, unique=True) - def _check_literal(self, operator, other): + def _check_literal(self, expr, operator, other): if isinstance(other, BindParameter) and \ isinstance(other.type, sqltypes.NullType): # TODO: perhaps we should not mutate the incoming bindparam() # here and instead make a copy of it. this might # be the only place that we're mutating an incoming construct. - other.type = self.type + other.type = expr.type return other elif hasattr(other, '__clause_element__'): other = other.__clause_element__() @@ -2147,14 +2106,16 @@ class CompareMixin(ColumnOperators): other = other.as_scalar() return other elif not isinstance(other, ClauseElement): - return self._bind_param(operator, other) + return expr._bind_param(operator, other) elif isinstance(other, (SelectBase, Alias)): return other.as_scalar() else: return other +_DEFAULT_COMPARATOR = _DefaultColumnComparator() -class ColumnElement(ClauseElement, CompareMixin): + +class ColumnElement(ClauseElement, ColumnOperators): """Represent an element that is usable within the "column clause" portion of a ``SELECT`` statement. @@ -2203,6 +2164,32 @@ class ColumnElement(ClauseElement, CompareMixin): _key_label = None _alt_names = () + comparator = None + + class Comparator(operators.ColumnOperators): + def __init__(self, expr): + self.expr = expr + + def operate(self, op, *other, **kwargs): + return _DEFAULT_COMPARATOR.operate(self.expr, op, *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, other, + **kwargs) + + def __getattr__(self, key): + if self.comparator is None: + raise AttributeError(key) + try: + return getattr(self.comparator, key) + except AttributeError: + raise AttributeError( + 'Neither %r object nor %r object has an attribute %r' % ( + type(self).__name__, + type(self.comparator).__name__, + key) + ) + @property def expression(self): """Return a column expression. @@ -2212,6 +2199,23 @@ class ColumnElement(ClauseElement, CompareMixin): """ return self + def operate(self, op, *other, **kwargs): + if self.comparator: + return op(self.comparator, *other, **kwargs) + else: + return _DEFAULT_COMPARATOR.operate(self, op, *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + if self.comparator: + return op(other, self.comparator, **kwargs) + else: + return _DEFAULT_COMPARATOR.reverse_operate(self, op, *other, **kwargs) + + def _bind_param(self, operator, obj): + return BindParameter(None, obj, + _compared_to_operator=operator, + _compared_to_type=self.type, unique=True) + @property def _select_iterable(self): return (self, ) @@ -2292,6 +2296,16 @@ class ColumnElement(ClauseElement, CompareMixin): else: return False + def label(self, name): + """Produce a column label, i.e. ``<columnname> AS <name>``. + + This is a shortcut to the :func:`~.expression.label` function. + + if 'name' is None, an anonymous label name will be generated. + + """ + return Label(name, self, self.type) + @util.memoized_property def anon_label(self): """provides a constant 'anonymous label' for this ColumnElement. @@ -3544,7 +3558,7 @@ class BinaryExpression(ColumnElement): __visit_name__ = 'binary' def __init__(self, left, right, operator, type_=None, - negate=None, modifiers=None): + negate=None, modifiers=None, comparator_factory=None): # allow compatibility with libraries that # refer to BinaryExpression directly and pass strings if isinstance(operator, basestring): @@ -3554,6 +3568,11 @@ class BinaryExpression(ColumnElement): self.operator = operator self.type = sqltypes.to_instance(type_) self.negate = negate + + self.comparator_factory = comparator_factory + if comparator_factory is not None: + self.comparator = comparator_factory(self) + if modifiers is None: self.modifiers = {} else: @@ -3959,12 +3978,16 @@ class Grouping(ColumnElement): return getattr(self.element, attr) def __getstate__(self): - return {'element':self.element, 'type':self.type} + return {'element': self.element, 'type': self.type} def __setstate__(self, state): self.element = state['element'] self.type = state['type'] + def compare(self, other, **kw): + return isinstance(other, Grouping) and \ + self.element.compare(other.element) + class FromGrouping(FromClause): """Represent a grouping of a FROM clause""" __visit_name__ = 'grouping' @@ -4186,6 +4209,12 @@ class ColumnClause(Immutable, ColumnElement): :func:`literal_column()` function is usually used to create such a :class:`.ColumnClause`. + :param comparator_factory: a :class:`.operators.ColumnOperators` subclass + which will produce custom operator behavior. + + .. versionadded: 0.8 support for pluggable operators in + core column expressions. + """ __visit_name__ = 'column' @@ -4193,11 +4222,15 @@ class ColumnClause(Immutable, ColumnElement): _memoized_property = util.group_expirable_memoized_property() - def __init__(self, text, selectable=None, type_=None, is_literal=False): + def __init__(self, text, selectable=None, type_=None, is_literal=False, + comparator_factory=None): self.key = self.name = text self.table = selectable self.type = sqltypes.to_instance(type_) self.is_literal = is_literal + self.comparator_factory = comparator_factory + if comparator_factory: + self.comparator = comparator_factory(self) def _compare_name_for_result(self, other): if self.table is not None and hasattr(other, 'proxy_set'): @@ -4283,7 +4316,8 @@ class ColumnClause(Immutable, ColumnElement): _as_truncated(name if name else self.name), selectable=selectable, type_=self.type, - is_literal=is_literal + is_literal=is_literal, + comparator_factory=self.comparator_factory ) c.proxies = [self] if selectable._is_clone_of is not None: @@ -5916,7 +5950,6 @@ class ReleaseSavepointClause(_IdentifiedClause): # old names for compatibility _BindParamClause = BindParameter _Label = Label -_CompareMixin = CompareMixin _SelectBase = SelectBase _BinaryExpression = BinaryExpression _Cast = Cast |