summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/expression.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r--lib/sqlalchemy/sql/expression.py295
1 files changed, 164 insertions, 131 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 2c1c047b5..84d7c1a29 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1773,32 +1773,6 @@ class ClauseElement(Visitable):
return self
- @util.deprecated('0.7',
- 'Only SQL expressions which subclass '
- ':class:`.Executable` may provide the '
- ':func:`.execute` method.')
- def execute(self, *multiparams, **params):
- """Compile and execute this :class:`.ClauseElement`.
-
- """
- e = self.bind
- if e is None:
- label = getattr(self, 'description', self.__class__.__name__)
- msg = ('This %s does not support direct execution.' % label)
- raise exc.UnboundExecutionError(msg)
- return e._execute_clauseelement(self, multiparams, params)
-
- @util.deprecated('0.7',
- 'Only SQL expressions which subclass '
- ':class:`.Executable` may provide the '
- ':func:`.scalar` method.')
- def scalar(self, *multiparams, **params):
- """Compile and execute this :class:`.ClauseElement`, returning
- the result's scalar representation.
-
- """
- return self.execute(*multiparams, **params).scalar()
-
def compile(self, bind=None, dialect=None, **kw):
"""Compile this SQL expression.
@@ -1901,7 +1875,7 @@ class Immutable(object):
return self
-class CompareMixin(ColumnOperators):
+class _DefaultColumnComparator(ColumnOperators):
"""Defines comparison and math operations.
The :class:`.CompareMixin` is part of the interface provided
@@ -1913,91 +1887,74 @@ class CompareMixin(ColumnOperators):
"""
- def __compare(self, op, obj, negate=None, reverse=False,
+ def __compare(self, expr, op, obj, negate=None, reverse=False,
**kwargs
):
if obj is None or isinstance(obj, Null):
if op == operators.eq:
- return BinaryExpression(self, null(), operators.is_,
+ return BinaryExpression(expr, null(), operators.is_,
negate=operators.isnot)
elif op == operators.ne:
- return BinaryExpression(self, null(), operators.isnot,
+ return BinaryExpression(expr, null(), operators.isnot,
negate=operators.is_)
else:
raise exc.ArgumentError("Only '='/'!=' operators can "
"be used with NULL")
else:
- obj = self._check_literal(op, obj)
+ obj = self._check_literal(expr, op, obj)
if reverse:
return BinaryExpression(obj,
- self,
+ expr,
op,
type_=sqltypes.BOOLEANTYPE,
negate=negate, modifiers=kwargs)
else:
- return BinaryExpression(self,
+ return BinaryExpression(expr,
obj,
op,
type_=sqltypes.BOOLEANTYPE,
negate=negate, modifiers=kwargs)
- def __operate(self, op, obj, reverse=False):
- obj = self._check_literal(op, obj)
-
+ def __operate(self, expr, op, obj, reverse=False):
+ obj = self._check_literal(expr, op, obj)
+ comparator_factory = None
if reverse:
- left, right = obj, self
+ left, right = obj, expr
else:
- left, right = self, obj
+ left, right = expr, obj
if left.type is None:
op, result_type = sqltypes.NULLTYPE._adapt_expression(op,
right.type)
+ result_type = sqltypes.to_instance(result_type)
+ if right.type._compare_type_affinity(result_type):
+ comparator_factory = right.comparator_factory
elif right.type is None:
op, result_type = left.type._adapt_expression(op,
sqltypes.NULLTYPE)
+ result_type = sqltypes.to_instance(result_type)
+ if left.type._compare_type_affinity(result_type):
+ comparator_factory = left.comparator_factory
else:
- op, result_type = left.type._adapt_expression(op,
- right.type)
- return BinaryExpression(left, right, op, type_=result_type)
+ op, result_type = left.type._adapt_expression(op, right.type)
+ result_type = sqltypes.to_instance(result_type)
+ if left.type._compare_type_affinity(result_type):
+ comparator_factory = left.comparator_factory
+ elif right.type._compare_type_affinity(result_type):
+ comparator_factory = right.comparator_factory
- # a mapping of operators with the method they use, along with their negated
- # operator for comparison operators
- operators = {
- "add": (__operate,),
- "mul": (__operate,),
- "sub": (__operate,),
- "div": (__operate,),
- "mod": (__operate,),
- "truediv": (__operate,),
- "custom_op": (__operate,),
- "lt": (__compare, operators.ge),
- "le": (__compare, operators.gt),
- "ne": (__compare, operators.eq),
- "gt": (__compare, operators.le),
- "ge": (__compare, operators.lt),
- "eq": (__compare, operators.ne),
- "like_op": (__compare, operators.notlike_op),
- "ilike_op": (__compare, operators.notilike_op),
- }
+ return BinaryExpression(left, right, op, type_=result_type,
+ comparator_factory=comparator_factory)
- def operate(self, op, *other, **kwargs):
- o = CompareMixin.operators[op.__name__]
- return o[0](self, op, other[0], *o[1:], **kwargs)
+ def __scalar(self, expr, op, fn, **kw):
+ return fn(expr)
- def reverse_operate(self, op, other, **kwargs):
- o = CompareMixin.operators[op.__name__]
- return o[0](self, op, other, reverse=True, *o[1:], **kwargs)
-
- def in_(self, other):
- """See :meth:`.ColumnOperators.in_`."""
- return self._in_impl(operators.in_op, operators.notin_op, other)
-
- def _in_impl(self, op, negate_op, seq_or_selectable):
+ def _in_impl(self, expr, op, seq_or_selectable, negate_op):
seq_or_selectable = _clause_element_as_expr(seq_or_selectable)
if isinstance(seq_or_selectable, ScalarSelect):
- return self.__compare(op, seq_or_selectable,
+ return self.__compare(expr, op, seq_or_selectable,
negate=negate_op)
elif isinstance(seq_or_selectable, SelectBase):
@@ -2006,10 +1963,10 @@ class CompareMixin(ColumnOperators):
# as_scalar() to produce a multi- column selectable that
# does not export itself as a FROM clause
- return self.__compare(op, seq_or_selectable.as_scalar(),
+ return self.__compare(expr, op, seq_or_selectable.as_scalar(),
negate=negate_op)
elif isinstance(seq_or_selectable, (Selectable, TextClause)):
- return self.__compare(op, seq_or_selectable,
+ return self.__compare(expr, op, seq_or_selectable,
negate=negate_op)
@@ -2018,12 +1975,12 @@ class CompareMixin(ColumnOperators):
args = []
for o in seq_or_selectable:
if not _is_literal(o):
- if not isinstance(o, CompareMixin):
+ if not isinstance(o, ColumnOperators):
raise exc.InvalidRequestError('in() function accept'
's either a list of non-selectable values, '
'or a selectable: %r' % o)
else:
- o = self._bind_param(op, o)
+ o = expr._bind_param(op, o)
args.append(o)
if len(args) == 0:
@@ -2037,18 +1994,17 @@ class CompareMixin(ColumnOperators):
'empty sequence. This results in a '
'contradiction, which nonetheless can be '
'expensive to evaluate. Consider alternative '
- 'strategies for improved performance.' % self)
- return self != self
+ 'strategies for improved performance.' % expr)
+ return expr != expr
- return self.__compare(op,
+ return self.__compare(expr, op,
ClauseList(*args).self_group(against=op),
negate=negate_op)
-
- def __neg__(self):
+ def _neg_impl(self):
"""See :meth:`.ColumnOperators.__neg__`."""
- return UnaryExpression(self, operator=operators.neg)
+ return UnaryExpression(self.expr, operator=operators.neg)
- def startswith(self, other, escape=None):
+ def _startswith_impl(self, other, escape=None):
"""See :meth:`.ColumnOperators.startswith`."""
# use __radd__ to force string concat behavior
return self.__compare(
@@ -2058,7 +2014,7 @@ class CompareMixin(ColumnOperators):
),
escape=escape)
- def endswith(self, other, escape=None):
+ def _endswith_impl(self, other, escape=None):
"""See :meth:`.ColumnOperators.endswith`."""
return self.__compare(
operators.like_op,
@@ -2066,7 +2022,7 @@ class CompareMixin(ColumnOperators):
self._check_literal(operators.like_op, other),
escape=escape)
- def contains(self, other, escape=None):
+ def _contains_impl(self, other, escape=None):
"""See :meth:`.ColumnOperators.contains`."""
return self.__compare(
operators.like_op,
@@ -2075,44 +2031,18 @@ class CompareMixin(ColumnOperators):
literal_column("'%'", type_=sqltypes.String),
escape=escape)
- def match(self, other):
+ def _match_impl(self, other):
"""See :meth:`.ColumnOperators.match`."""
return self.__compare(operators.match_op,
self._check_literal(operators.match_op,
other))
- def label(self, name):
- """Produce a column label, i.e. ``<columnname> AS <name>``.
-
- This is a shortcut to the :func:`~.expression.label` function.
-
- if 'name' is None, an anonymous label name will be generated.
-
- """
- return Label(name, self, self.type)
-
- def desc(self):
- """See :meth:`.ColumnOperators.desc`."""
- return desc(self)
-
- def asc(self):
- """See :meth:`.ColumnOperators.asc`."""
- return asc(self)
-
- def nullsfirst(self):
- """See :meth:`.ColumnOperators.nullsfirst`."""
- return nullsfirst(self)
-
- def nullslast(self):
- """See :meth:`.ColumnOperators.nullslast`."""
- return nullslast(self)
-
- def distinct(self):
+ def _distinct_impl(self):
"""See :meth:`.ColumnOperators.distinct`."""
return UnaryExpression(self, operator=operators.distinct_op,
type_=self.type)
- def between(self, cleft, cright):
+ def _between_impl(self, cleft, cright):
"""See :meth:`.ColumnOperators.between`."""
return BinaryExpression(
self,
@@ -2123,23 +2053,52 @@ class CompareMixin(ColumnOperators):
group=False),
operators.between_op)
- def collate(self, collation):
- """See :meth:`.ColumnOperators.collate`."""
+ def _collate_impl(self, expr, op, other):
+ return collate(expr, other)
+
+ # a mapping of operators with the method they use, along with their negated
+ # operator for comparison operators
+ operators = {
+ "add": (__operate,),
+ "mul": (__operate,),
+ "sub": (__operate,),
+ "div": (__operate,),
+ "mod": (__operate,),
+ "truediv": (__operate,),
+ "custom_op": (__operate,),
+ "lt": (__compare, operators.ge),
+ "le": (__compare, operators.gt),
+ "ne": (__compare, operators.eq),
+ "gt": (__compare, operators.le),
+ "ge": (__compare, operators.lt),
+ "eq": (__compare, operators.ne),
+ "like_op": (__compare, operators.notlike_op),
+ "ilike_op": (__compare, operators.notilike_op),
+ "desc_op": (__scalar, desc),
+ "asc_op": (__scalar, asc),
+ "nullsfirst_op": (__scalar, nullsfirst),
+ "nullslast_op": (__scalar, nullslast),
+ "in_op": (_in_impl, operators.notin_op),
+ "collate": (_collate_impl,),
+ }
+
+ def operate(self, expr, op, *other, **kwargs):
+ o = self.operators[op.__name__]
+ return o[0](self, expr, op, *(other + o[1:]), **kwargs)
+
+ def reverse_operate(self, expr, op, other, **kwargs):
+ o = self.operators[op.__name__]
+ return o[0](self, expr, op, other, reverse=True, *o[1:], **kwargs)
- return collate(self, collation)
- def _bind_param(self, operator, obj):
- return BindParameter(None, obj,
- _compared_to_operator=operator,
- _compared_to_type=self.type, unique=True)
- def _check_literal(self, operator, other):
+ def _check_literal(self, expr, operator, other):
if isinstance(other, BindParameter) and \
isinstance(other.type, sqltypes.NullType):
# TODO: perhaps we should not mutate the incoming bindparam()
# here and instead make a copy of it. this might
# be the only place that we're mutating an incoming construct.
- other.type = self.type
+ other.type = expr.type
return other
elif hasattr(other, '__clause_element__'):
other = other.__clause_element__()
@@ -2147,14 +2106,16 @@ class CompareMixin(ColumnOperators):
other = other.as_scalar()
return other
elif not isinstance(other, ClauseElement):
- return self._bind_param(operator, other)
+ return expr._bind_param(operator, other)
elif isinstance(other, (SelectBase, Alias)):
return other.as_scalar()
else:
return other
+_DEFAULT_COMPARATOR = _DefaultColumnComparator()
-class ColumnElement(ClauseElement, CompareMixin):
+
+class ColumnElement(ClauseElement, ColumnOperators):
"""Represent an element that is usable within the "column clause" portion
of a ``SELECT`` statement.
@@ -2203,6 +2164,32 @@ class ColumnElement(ClauseElement, CompareMixin):
_key_label = None
_alt_names = ()
+ comparator = None
+
+ class Comparator(operators.ColumnOperators):
+ def __init__(self, expr):
+ self.expr = expr
+
+ def operate(self, op, *other, **kwargs):
+ return _DEFAULT_COMPARATOR.operate(self.expr, op, *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, other,
+ **kwargs)
+
+ def __getattr__(self, key):
+ if self.comparator is None:
+ raise AttributeError(key)
+ try:
+ return getattr(self.comparator, key)
+ except AttributeError:
+ raise AttributeError(
+ 'Neither %r object nor %r object has an attribute %r' % (
+ type(self).__name__,
+ type(self.comparator).__name__,
+ key)
+ )
+
@property
def expression(self):
"""Return a column expression.
@@ -2212,6 +2199,23 @@ class ColumnElement(ClauseElement, CompareMixin):
"""
return self
+ def operate(self, op, *other, **kwargs):
+ if self.comparator:
+ return op(self.comparator, *other, **kwargs)
+ else:
+ return _DEFAULT_COMPARATOR.operate(self, op, *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ if self.comparator:
+ return op(other, self.comparator, **kwargs)
+ else:
+ return _DEFAULT_COMPARATOR.reverse_operate(self, op, *other, **kwargs)
+
+ def _bind_param(self, operator, obj):
+ return BindParameter(None, obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type, unique=True)
+
@property
def _select_iterable(self):
return (self, )
@@ -2292,6 +2296,16 @@ class ColumnElement(ClauseElement, CompareMixin):
else:
return False
+ def label(self, name):
+ """Produce a column label, i.e. ``<columnname> AS <name>``.
+
+ This is a shortcut to the :func:`~.expression.label` function.
+
+ if 'name' is None, an anonymous label name will be generated.
+
+ """
+ return Label(name, self, self.type)
+
@util.memoized_property
def anon_label(self):
"""provides a constant 'anonymous label' for this ColumnElement.
@@ -3544,7 +3558,7 @@ class BinaryExpression(ColumnElement):
__visit_name__ = 'binary'
def __init__(self, left, right, operator, type_=None,
- negate=None, modifiers=None):
+ negate=None, modifiers=None, comparator_factory=None):
# allow compatibility with libraries that
# refer to BinaryExpression directly and pass strings
if isinstance(operator, basestring):
@@ -3554,6 +3568,11 @@ class BinaryExpression(ColumnElement):
self.operator = operator
self.type = sqltypes.to_instance(type_)
self.negate = negate
+
+ self.comparator_factory = comparator_factory
+ if comparator_factory is not None:
+ self.comparator = comparator_factory(self)
+
if modifiers is None:
self.modifiers = {}
else:
@@ -3959,12 +3978,16 @@ class Grouping(ColumnElement):
return getattr(self.element, attr)
def __getstate__(self):
- return {'element':self.element, 'type':self.type}
+ return {'element': self.element, 'type': self.type}
def __setstate__(self, state):
self.element = state['element']
self.type = state['type']
+ def compare(self, other, **kw):
+ return isinstance(other, Grouping) and \
+ self.element.compare(other.element)
+
class FromGrouping(FromClause):
"""Represent a grouping of a FROM clause"""
__visit_name__ = 'grouping'
@@ -4186,6 +4209,12 @@ class ColumnClause(Immutable, ColumnElement):
:func:`literal_column()` function is usually used to create such a
:class:`.ColumnClause`.
+ :param comparator_factory: a :class:`.operators.ColumnOperators` subclass
+ which will produce custom operator behavior.
+
+ .. versionadded: 0.8 support for pluggable operators in
+ core column expressions.
+
"""
__visit_name__ = 'column'
@@ -4193,11 +4222,15 @@ class ColumnClause(Immutable, ColumnElement):
_memoized_property = util.group_expirable_memoized_property()
- def __init__(self, text, selectable=None, type_=None, is_literal=False):
+ def __init__(self, text, selectable=None, type_=None, is_literal=False,
+ comparator_factory=None):
self.key = self.name = text
self.table = selectable
self.type = sqltypes.to_instance(type_)
self.is_literal = is_literal
+ self.comparator_factory = comparator_factory
+ if comparator_factory:
+ self.comparator = comparator_factory(self)
def _compare_name_for_result(self, other):
if self.table is not None and hasattr(other, 'proxy_set'):
@@ -4283,7 +4316,8 @@ class ColumnClause(Immutable, ColumnElement):
_as_truncated(name if name else self.name),
selectable=selectable,
type_=self.type,
- is_literal=is_literal
+ is_literal=is_literal,
+ comparator_factory=self.comparator_factory
)
c.proxies = [self]
if selectable._is_clone_of is not None:
@@ -5916,7 +5950,6 @@ class ReleaseSavepointClause(_IdentifiedClause):
# old names for compatibility
_BindParamClause = BindParameter
_Label = Label
-_CompareMixin = CompareMixin
_SelectBase = SelectBase
_BinaryExpression = BinaryExpression
_Cast = Cast