diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-08-18 01:00:44 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-08-18 01:00:44 +0000 |
commit | 820346549b7e50e927c519c9bc54934e9a440422 (patch) | |
tree | 1747619176907dac7663f54408a4288129e49a3a /lib/sqlalchemy/sql.py | |
parent | 74595d900c23fcefe75353d3099cb73a55a0b6cf (diff) | |
download | sqlalchemy-820346549b7e50e927c519c9bc54934e9a440422.tar.gz |
- modified SQL operator functions to be module-level operators, allowing
SQL expressions to be pickleable [ticket:735]
- small adjustment to mapper class.__init__ to allow for Py2.6 object.__init__()
behavior
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r-- | lib/sqlalchemy/sql.py | 277 |
1 files changed, 100 insertions, 177 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index bad60bb3c..2f1cd4b91 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -24,9 +24,9 @@ are less guaranteed to stay the same in future releases. """ -from sqlalchemy import util, exceptions +from sqlalchemy import util, exceptions, operators from sqlalchemy import types as sqltypes -import re, operator +import re __all__ = ['Alias', 'ClauseElement', 'ClauseParameters', 'ClauseVisitor', 'ColumnCollection', 'ColumnElement', @@ -47,7 +47,7 @@ def desc(column): order_by = [desc(table1.mycol)] """ - return _UnaryExpression(column, modifier=ColumnOperators.desc_op) + return _UnaryExpression(column, modifier=operators.desc_op) def asc(column): """Return an ascending ``ORDER BY`` clause element. @@ -56,7 +56,7 @@ def asc(column): order_by = [asc(table1.mycol)] """ - return _UnaryExpression(column, modifier=ColumnOperators.asc_op) + return _UnaryExpression(column, modifier=operators.asc_op) def outerjoin(left, right, onclause=None, **kwargs): """Return an ``OUTER JOIN`` clause element. @@ -337,7 +337,7 @@ def and_(*clauses): """ if len(clauses) == 1: return clauses[0] - return ClauseList(operator=operator.and_, *clauses) + return ClauseList(operator=operators.and_, *clauses) def or_(*clauses): """Join a list of clauses together using the ``OR`` operator. @@ -348,7 +348,7 @@ def or_(*clauses): if len(clauses) == 1: return clauses[0] - return ClauseList(operator=operator.or_, *clauses) + return ClauseList(operator=operators.or_, *clauses) def not_(clause): """Return a negation of the given clause, i.e. ``NOT(clause)``. @@ -357,12 +357,12 @@ def not_(clause): subclasses to produce the same result. """ - return operator.inv(clause) + return operators.inv(clause) def distinct(expr): """return a ``DISTINCT`` clause.""" - return _UnaryExpression(expr, operator=ColumnOperators.distinct_op) + return _UnaryExpression(expr, operator=operators.distinct_op) def between(ctest, cleft, cright): """Return a ``BETWEEN`` predicate clause. @@ -374,7 +374,7 @@ def between(ctest, cleft, cright): """ ctest = _literal_as_binds(ctest) - return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operator.and_, group=False), ColumnOperators.between_op) + return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operators.and_, group=False), operators.between_op) def case(whens, value=None, else_=None): @@ -420,7 +420,7 @@ def cast(clause, totype, **kwargs): def extract(field, expr): """Return the clause ``extract(field FROM expr)``.""" - expr = _BinaryExpression(text(field), expr, Operators.from_) + expr = _BinaryExpression(text(field), expr, operators.from_) return func.extract(expr) def exists(*args, **kwargs): @@ -1195,38 +1195,18 @@ class ClauseElement(object): if hasattr(self, 'negation_clause'): return self.negation_clause else: - return _UnaryExpression(self.self_group(against=operator.inv), operator=operator.inv, negate=None) + return _UnaryExpression(self.self_group(against=operators.inv), operator=operators.inv, negate=None) class Operators(object): - def from_(): - raise NotImplementedError() - from_ = staticmethod(from_) - - def as_(): - raise NotImplementedError() - as_ = staticmethod(as_) - - def exists(): - raise NotImplementedError() - exists = staticmethod(exists) - - def is_(): - raise NotImplementedError() - is_ = staticmethod(is_) - - def isnot(): - raise NotImplementedError() - isnot = staticmethod(isnot) - def __and__(self, other): - return self.operate(operator.and_, other) + return self.operate(operators.and_, other) def __or__(self, other): - return self.operate(operator.or_, other) + return self.operate(operators.or_, other) def __invert__(self): - return self.operate(operator.inv) + return self.operate(operators.inv) def clause_element(self): raise NotImplementedError() @@ -1239,137 +1219,80 @@ class Operators(object): class ColumnOperators(Operators): """defines comparison and math operations""" - - def like_op(a, b): - return a.like(b) - like_op = staticmethod(like_op) - - def notlike_op(a, b): - raise NotImplementedError() - notlike_op = staticmethod(notlike_op) - - def ilike_op(a, b): - return a.ilike(b) - ilike_op = staticmethod(ilike_op) - - def notilike_op(a, b): - raise NotImplementedError() - notilike_op = staticmethod(notilike_op) - - def between_op(a, b): - return a.between(b) - between_op = staticmethod(between_op) - - def in_op(a, b): - return a.in_(*b) - in_op = staticmethod(in_op) - - def notin_op(a, b): - raise NotImplementedError() - notin_op = staticmethod(notin_op) - - def distinct_op(a): - return a.distinct() - distinct_op = staticmethod(distinct_op) - - def startswith_op(a, b): - return a.startswith(b) - startswith_op = staticmethod(startswith_op) - - def endswith_op(a, b): - return a.endswith(b) - endswith_op = staticmethod(endswith_op) - - def comma_op(a, b): - raise NotImplementedError() - comma_op = staticmethod(comma_op) - - def concat_op(a, b): - return a.concat(b) - concat_op = staticmethod(concat_op) - - def desc_op(a): - return a.desc() - desc_op = staticmethod(desc_op) - - def asc_op(a): - return a.asc() - asc_op = staticmethod(asc_op) - def __lt__(self, other): - return self.operate(operator.lt, other) + return self.operate(operators.lt, other) def __le__(self, other): - return self.operate(operator.le, other) + return self.operate(operators.le, other) def __eq__(self, other): - return self.operate(operator.eq, other) + return self.operate(operators.eq, other) def __ne__(self, other): - return self.operate(operator.ne, other) + return self.operate(operators.ne, other) def __gt__(self, other): - return self.operate(operator.gt, other) + return self.operate(operators.gt, other) def __ge__(self, other): - return self.operate(operator.ge, other) + return self.operate(operators.ge, other) def concat(self, other): - return self.operate(ColumnOperators.concat_op, other) + return self.operate(operators.concat_op, other) def like(self, other): - return self.operate(ColumnOperators.like_op, other) + return self.operate(operators.like_op, other) def in_(self, *other): - return self.operate(ColumnOperators.in_op, other) + return self.operate(operators.in_op, other) def startswith(self, other): - return self.operate(ColumnOperators.startswith_op, other) + return self.operate(operators.startswith_op, other) def endswith(self, other): - return self.operate(ColumnOperators.endswith_op, other) + return self.operate(operators.endswith_op, other) def desc(self): - return self.operate(ColumnOperators.desc_op) + return self.operate(operators.desc_op) def asc(self): - return self.operate(ColumnOperators.asc_op) + return self.operate(operators.asc_op) def __radd__(self, other): - return self.reverse_operate(operator.add, other) + return self.reverse_operate(operators.add, other) def __rsub__(self, other): - return self.reverse_operate(operator.sub, other) + return self.reverse_operate(operators.sub, other) def __rmul__(self, other): - return self.reverse_operate(operator.mul, other) + return self.reverse_operate(operators.mul, other) def __rdiv__(self, other): - return self.reverse_operate(operator.div, other) + return self.reverse_operate(operators.div, other) def between(self, cleft, cright): - return self.operate(ColumnOperators.between_op, cleft, cright) + return self.operate(operators.between_op, cleft, cright) def distinct(self): - return self.operate(ColumnOperators.distinct_op) + return self.operate(operators.distinct_op) def __add__(self, other): - return self.operate(operator.add, other) + return self.operate(operators.add, other) def __sub__(self, other): - return self.operate(operator.sub, other) + return self.operate(operators.sub, other) def __mul__(self, other): - return self.operate(operator.mul, other) + return self.operate(operators.mul, other) def __div__(self, other): - return self.operate(operator.div, other) + return self.operate(operators.div, other) def __mod__(self, other): - return self.operate(operator.mod, other) + return self.operate(operators.mod, other) def __truediv__(self, other): - return self.operate(operator.truediv, other) + return self.operate(operators.truediv, other) # precedence ordering for common operators. if an operator is not present in this list, # it will be parenthesized when grouped against other operators @@ -1377,35 +1300,35 @@ _smallest = object() _largest = object() PRECEDENCE = { - Operators.from_:15, - operator.mul:7, - operator.div:7, - operator.mod:7, - operator.add:6, - operator.sub:6, - ColumnOperators.concat_op:6, - ColumnOperators.ilike_op:5, - ColumnOperators.notilike_op:5, - ColumnOperators.like_op:5, - ColumnOperators.notlike_op:5, - ColumnOperators.in_op:5, - ColumnOperators.notin_op:5, - Operators.is_:5, - Operators.isnot:5, - operator.eq:5, - operator.ne:5, - operator.gt:5, - operator.lt:5, - operator.ge:5, - operator.le:5, - ColumnOperators.between_op:5, - ColumnOperators.distinct_op:5, - operator.inv:4, - operator.and_:3, - operator.or_:2, - ColumnOperators.comma_op:-1, - Operators.as_:-1, - Operators.exists:0, + operators.from_:15, + operators.mul:7, + operators.div:7, + operators.mod:7, + operators.add:6, + operators.sub:6, + operators.concat_op:6, + operators.ilike_op:5, + operators.notilike_op:5, + operators.like_op:5, + operators.notlike_op:5, + operators.in_op:5, + operators.notin_op:5, + operators.is_:5, + operators.isnot:5, + operators.eq:5, + operators.ne:5, + operators.gt:5, + operators.lt:5, + operators.ge:5, + operators.le:5, + operators.between_op:5, + operators.distinct_op:5, + operators.inv:4, + operators.and_:3, + operators.or_:2, + operators.comma_op:-1, + operators.as_:-1, + operators.exists:0, _smallest: -1000, _largest: 1000 } @@ -1415,10 +1338,10 @@ class _CompareMixin(ColumnOperators): def __compare(self, op, obj, negate=None): if obj is None or isinstance(obj, _Null): - if op == operator.eq: - return _BinaryExpression(self.expression_element(), null(), Operators.is_, negate=Operators.isnot) - elif op == operator.ne: - return _BinaryExpression(self.expression_element(), null(), Operators.isnot, negate=Operators.is_) + if op == operators.eq: + return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot) + elif op == operators.ne: + return _BinaryExpression(self.expression_element(), null(), operators.isnot, negate=operators.is_) else: raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") else: @@ -1433,25 +1356,25 @@ class _CompareMixin(ColumnOperators): type_ = self._compare_type(obj) # TODO: generalize operator overloading like this out into the types module - if op == operator.add and isinstance(type_, (sqltypes.Concatenable)): - op = ColumnOperators.concat_op + if op == operators.add and isinstance(type_, (sqltypes.Concatenable)): + op = operators.concat_op return _BinaryExpression(self.expression_element(), obj, op, type_=type_) operators = { - operator.add : (__operate,), - operator.mul : (__operate,), - operator.sub : (__operate,), - operator.div : (__operate,), - operator.mod : (__operate,), - operator.truediv : (__operate,), - operator.lt : (__compare, operator.ge), - operator.le : (__compare, operator.gt), - operator.ne : (__compare, operator.eq), - operator.gt : (__compare, operator.le), - operator.ge : (__compare, operator.lt), - operator.eq : (__compare, operator.ne), - ColumnOperators.like_op : (__compare, ColumnOperators.notlike_op), + operators.add : (__operate,), + operators.mul : (__operate,), + operators.sub : (__operate,), + operators.div : (__operate,), + operators.mod : (__operate,), + operators.truediv : (__operate,), + operators.lt : (__compare, operators.ge), + operators.le : (__compare, operators.gt), + operators.ne : (__compare, operators.eq), + operators.gt : (__compare, operators.le), + operators.ge : (__compare, operators.lt), + operators.eq : (__compare, operators.ne), + operators.like_op : (__compare, operators.notlike_op), } def operate(self, op, *other): @@ -1462,7 +1385,7 @@ class _CompareMixin(ColumnOperators): return self._bind_param(other).operate(op, self) def in_(self, *other): - return self._in_impl(ColumnOperators.in_op, ColumnOperators.notin_op, *other) + return self._in_impl(operators.in_op, operators.notin_op, *other) def _in_impl(self, op, negate_op, *other): if len(other) == 0: @@ -1489,7 +1412,7 @@ class _CompareMixin(ColumnOperators): """produce the clause ``LIKE '<other>%'``""" perc = isinstance(other,(str,unicode)) and '%' or literal('%',type_= sqltypes.String) - return self.__compare(ColumnOperators.like_op, other + perc) + return self.__compare(operators.like_op, other + perc) def endswith(self, other): """produce the clause ``LIKE '%<other>'``""" @@ -1498,7 +1421,7 @@ class _CompareMixin(ColumnOperators): else: po = literal('%', type_=sqltypes.String) + other po.type = sqltypes.to_instance(sqltypes.String) #force! - return self.__compare(ColumnOperators.like_op, po) + return self.__compare(operators.like_op, po) def label(self, name): """produce a column label, i.e. ``<columnname> AS <name>``""" @@ -1516,12 +1439,12 @@ class _CompareMixin(ColumnOperators): def distinct(self): """produce a DISTINCT clause, i.e. ``DISTINCT <columnname>``""" - return _UnaryExpression(self, operator=ColumnOperators.distinct_op) + return _UnaryExpression(self, operator=operators.distinct_op) def between(self, cleft, cright): """produce a BETWEEN clause, i.e. ``<column> BETWEEN <cleft> AND <cright>``""" - return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operator.and_, group=False), ColumnOperators.between_op) + return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operators.and_, group=False), operators.between_op) def op(self, operator): """produce a generic operator function. @@ -2132,7 +2055,7 @@ class ClauseList(ClauseElement): def __init__(self, *clauses, **kwargs): self.clauses = [] - self.operator = kwargs.pop('operator', ColumnOperators.comma_op) + self.operator = kwargs.pop('operator', operators.comma_op) self.group = kwargs.pop('group', True) self.group_contents = kwargs.pop('group_contents', True) for c in clauses: @@ -2248,7 +2171,7 @@ class _Function(_CalculatedClause, FromClause): def __init__(self, name, *clauses, **kwargs): self.packagenames = kwargs.get('packagenames', None) or [] - kwargs['operator'] = ColumnOperators.comma_op + kwargs['operator'] = operators.comma_op _CalculatedClause.__init__(self, name, **kwargs) for c in clauses: self.append(c) @@ -2365,7 +2288,7 @@ class _BinaryExpression(ColumnElement): ( self.left.compare(other.left) and self.right.compare(other.right) or ( - self.operator in [operator.eq, operator.ne, operator.add, operator.mul] and + self.operator in [operators.eq, operators.ne, operators.add, operators.mul] and self.left.compare(other.right) and self.right.compare(other.left) ) ) @@ -2390,7 +2313,7 @@ class _Exists(_UnaryExpression): def __init__(self, *args, **kwargs): kwargs['correlate'] = True s = select(*args, **kwargs).as_scalar().self_group() - _UnaryExpression.__init__(self, s, operator=Operators.exists) + _UnaryExpression.__init__(self, s, operator=operators.exists) def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) @@ -2444,7 +2367,7 @@ class Join(FromClause): class BinaryVisitor(ClauseVisitor): def visit_binary(self, binary): - if binary.operator == operator.eq: + if binary.operator == operators.eq: add_equiv(binary.left, binary.right) BinaryVisitor().traverse(self.onclause) @@ -2526,7 +2449,7 @@ class Join(FromClause): equivs = util.Set() class LocateEquivs(NoColumnVisitor): def visit_binary(self, binary): - if binary.operator == operator.eq and binary.left.name == binary.right.name: + if binary.operator == operators.eq and binary.left.name == binary.right.name: equivs.add(binary.right) equivs.add(binary.left) LocateEquivs().traverse(self.onclause) @@ -2739,7 +2662,7 @@ class _Label(ColumnElement): obj = obj.obj self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon')) - self.obj = obj.self_group(against=Operators.as_) + self.obj = obj.self_group(against=operators.as_) self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None)) key = property(lambda s: s.name) @@ -3314,7 +3237,7 @@ class Select(_SelectBaseMixin, FromClause): column = _literal_as_column(column) if isinstance(column, _ScalarSelect): - column = column.self_group(against=ColumnOperators.comma_op) + column = column.self_group(against=operators.comma_op) self._raw_columns.append(column) |