diff options
-rw-r--r-- | CHANGES | 6 | ||||
-rw-r--r-- | examples/sharding/attribute_shard.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 70 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/operators.py | 61 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/sql.py | 277 | ||||
-rw-r--r-- | test/orm/sharding/shard.py | 8 | ||||
-rw-r--r-- | test/sql/select.py | 5 |
9 files changed, 222 insertions, 226 deletions
@@ -24,6 +24,12 @@ CHANGES connection; a close() will ensure that connection transactional state is the same as that which existed on it before being bound to the Session. +- modified SQL operator functions to be module-level operators, allowing + SQL expressions to be pickleable [ticket:735] + +- small adjustment to mapper class.__init__ to allow for Py2.6 object.__init__() + behavior + 0.4.0beta3 ---------- diff --git a/examples/sharding/attribute_shard.py b/examples/sharding/attribute_shard.py index df3f7467f..25da98872 100644 --- a/examples/sharding/attribute_shard.py +++ b/examples/sharding/attribute_shard.py @@ -21,8 +21,8 @@ To set up a sharding system, you need: from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm.shard import ShardedSession -from sqlalchemy.sql import ColumnOperators -import datetime, operator +from sqlalchemy.sql import operators +import datetime # step 2. databases echo = True @@ -133,9 +133,9 @@ def query_chooser(query): class FindContinent(sql.ClauseVisitor): def visit_binary(self, binary): if binary.left is weather_locations.c.continent: - if binary.operator == operator.eq: + if binary.operator == operators.eq: ids.append(shard_lookup[binary.right.value]) - elif binary.operator == ColumnOperators.in_op: + elif binary.operator == operators.in_op: for bind in binary.right.clauses: ids.append(shard_lookup[bind.value]) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index dd4065f39..5f5e1c171 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -12,7 +12,7 @@ module. import string, re, sets, operator -from sqlalchemy import schema, sql, engine, util, exceptions +from sqlalchemy import schema, sql, engine, util, exceptions, operators from sqlalchemy.engine import default @@ -50,39 +50,39 @@ BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.UNICODE) ANONYMOUS_LABEL = re.compile(r'{ANON (-?\d+) (.*)}') OPERATORS = { - operator.and_ : 'AND', - operator.or_ : 'OR', - operator.inv : 'NOT', - operator.add : '+', - operator.mul : '*', - operator.sub : '-', - operator.div : '/', - operator.mod : '%', - operator.truediv : '/', - operator.lt : '<', - operator.le : '<=', - operator.ne : '!=', - operator.gt : '>', - operator.ge : '>=', - operator.eq : '=', - sql.ColumnOperators.distinct_op : 'DISTINCT', - sql.ColumnOperators.concat_op : '||', - sql.ColumnOperators.like_op : 'LIKE', - sql.ColumnOperators.notlike_op : 'NOT LIKE', - sql.ColumnOperators.ilike_op : 'ILIKE', - sql.ColumnOperators.notilike_op : 'NOT ILIKE', - sql.ColumnOperators.between_op : 'BETWEEN', - sql.ColumnOperators.in_op : 'IN', - sql.ColumnOperators.notin_op : 'NOT IN', - sql.ColumnOperators.comma_op : ', ', - sql.ColumnOperators.desc_op : 'DESC', - sql.ColumnOperators.asc_op : 'ASC', + operators.and_ : 'AND', + operators.or_ : 'OR', + operators.inv : 'NOT', + operators.add : '+', + operators.mul : '*', + operators.sub : '-', + operators.div : '/', + operators.mod : '%', + operators.truediv : '/', + operators.lt : '<', + operators.le : '<=', + operators.ne : '!=', + operators.gt : '>', + operators.ge : '>=', + operators.eq : '=', + operators.distinct_op : 'DISTINCT', + operators.concat_op : '||', + operators.like_op : 'LIKE', + operators.notlike_op : 'NOT LIKE', + operators.ilike_op : 'ILIKE', + operators.notilike_op : 'NOT ILIKE', + operators.between_op : 'BETWEEN', + operators.in_op : 'IN', + operators.notin_op : 'NOT IN', + operators.comma_op : ', ', + operators.desc_op : 'DESC', + operators.asc_op : 'ASC', - sql.Operators.from_ : 'FROM', - sql.Operators.as_ : 'AS', - sql.Operators.exists : 'EXISTS', - sql.Operators.is_ : 'IS', - sql.Operators.isnot : 'IS NOT' + operators.from_ : 'FROM', + operators.as_ : 'AS', + operators.exists : 'EXISTS', + operators.is_ : 'IS', + operators.isnot : 'IS NOT' } class ANSIDialect(default.DefaultDialect): @@ -284,7 +284,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): if isinstance(label.obj, sql._ColumnClause): self.column_labels[label.obj._label] = labelname self.column_labels[label.name] = labelname - return " ".join([self.process(label.obj), self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)]) + return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) def visit_column(self, column, **kwargs): # there is actually somewhat of a ruleset when you would *not* necessarily @@ -343,7 +343,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): sep = clauselist.operator if sep is None: sep = " " - elif sep == sql.ColumnOperators.comma_op: + elif sep == operators.comma_op: sep = ', ' else: sep = " " + self.operator_string(clauselist.operator) + " " diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 4cdc7962a..303a44552 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -123,10 +123,12 @@ Notes page on the wiki at http://sqlalchemy.org is a good resource for timely information affecting MySQL in SQLAlchemy. """ -import re, datetime, inspect, warnings, operator, sys +import re, datetime, inspect, warnings, sys from array import array as _array from sqlalchemy import ansisql, exceptions, logging, schema, sql, util +from sqlalchemy import operators as sql_operators + from sqlalchemy.engine import base as engine_base, default import sqlalchemy.types as sqltypes @@ -1735,9 +1737,9 @@ class MySQLCompiler(ansisql.ANSICompiler): operators = ansisql.ANSICompiler.operators.copy() operators.update( { - sql.ColumnOperators.concat_op: \ + sql_operators.concat_op: \ lambda x, y: "concat(%s, %s)" % (x, y), - operator.mod: '%%' + sql_operators.mod: '%%' } ) diff --git a/lib/sqlalchemy/operators.py b/lib/sqlalchemy/operators.py new file mode 100644 index 000000000..b8aca3d26 --- /dev/null +++ b/lib/sqlalchemy/operators.py @@ -0,0 +1,61 @@ +"""define opeators used in SQL expressions""" + +from operator import and_, or_, inv, add, mul, sub, div, mod, truediv, lt, le, ne, gt, ge, eq + +def from_(): + raise NotImplementedError() + +def as_(): + raise NotImplementedError() + +def exists(): + raise NotImplementedError() + +def is_(): + raise NotImplementedError() + +def isnot(): + raise NotImplementedError() + +def like_op(a, b): + return a.like(b) + +def notlike_op(a, b): + raise NotImplementedError() + +def ilike_op(a, b): + return a.ilike(b) + +def notilike_op(a, b): + raise NotImplementedError() + +def between_op(a, b): + return a.between(b) + +def in_op(a, b): + return a.in_(*b) + +def notin_op(a, b): + raise NotImplementedError() + +def distinct_op(a): + return a.distinct() + +def startswith_op(a, b): + return a.startswith(b) + +def endswith_op(a, b): + return a.endswith(b) + +def comma_op(a, b): + raise NotImplementedError() + +def concat_op(a, b): + return a.concat(b) + +def desc_op(a): + return a.desc() + +def asc_op(a): + return a.asc() + diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 30a9525f1..b48368417 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -670,14 +670,13 @@ class Mapper(object): attribute_manager.reset_class_managed(self.class_) oldinit = self.class_.__init__ - if oldinit is object.__init__: - oldinit = None + doinit = oldinit is not None and oldinit is not object.__init__ def init(instance, *args, **kwargs): self.compile() self.extension.init_instance(self, self.class_, oldinit, instance, args, kwargs) - if oldinit is not None: + if doinit: try: oldinit(instance, *args, **kwargs) except: diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index bad60bb3c..2f1cd4b91 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -24,9 +24,9 @@ are less guaranteed to stay the same in future releases. """ -from sqlalchemy import util, exceptions +from sqlalchemy import util, exceptions, operators from sqlalchemy import types as sqltypes -import re, operator +import re __all__ = ['Alias', 'ClauseElement', 'ClauseParameters', 'ClauseVisitor', 'ColumnCollection', 'ColumnElement', @@ -47,7 +47,7 @@ def desc(column): order_by = [desc(table1.mycol)] """ - return _UnaryExpression(column, modifier=ColumnOperators.desc_op) + return _UnaryExpression(column, modifier=operators.desc_op) def asc(column): """Return an ascending ``ORDER BY`` clause element. @@ -56,7 +56,7 @@ def asc(column): order_by = [asc(table1.mycol)] """ - return _UnaryExpression(column, modifier=ColumnOperators.asc_op) + return _UnaryExpression(column, modifier=operators.asc_op) def outerjoin(left, right, onclause=None, **kwargs): """Return an ``OUTER JOIN`` clause element. @@ -337,7 +337,7 @@ def and_(*clauses): """ if len(clauses) == 1: return clauses[0] - return ClauseList(operator=operator.and_, *clauses) + return ClauseList(operator=operators.and_, *clauses) def or_(*clauses): """Join a list of clauses together using the ``OR`` operator. @@ -348,7 +348,7 @@ def or_(*clauses): if len(clauses) == 1: return clauses[0] - return ClauseList(operator=operator.or_, *clauses) + return ClauseList(operator=operators.or_, *clauses) def not_(clause): """Return a negation of the given clause, i.e. ``NOT(clause)``. @@ -357,12 +357,12 @@ def not_(clause): subclasses to produce the same result. """ - return operator.inv(clause) + return operators.inv(clause) def distinct(expr): """return a ``DISTINCT`` clause.""" - return _UnaryExpression(expr, operator=ColumnOperators.distinct_op) + return _UnaryExpression(expr, operator=operators.distinct_op) def between(ctest, cleft, cright): """Return a ``BETWEEN`` predicate clause. @@ -374,7 +374,7 @@ def between(ctest, cleft, cright): """ ctest = _literal_as_binds(ctest) - return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operator.and_, group=False), ColumnOperators.between_op) + return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operators.and_, group=False), operators.between_op) def case(whens, value=None, else_=None): @@ -420,7 +420,7 @@ def cast(clause, totype, **kwargs): def extract(field, expr): """Return the clause ``extract(field FROM expr)``.""" - expr = _BinaryExpression(text(field), expr, Operators.from_) + expr = _BinaryExpression(text(field), expr, operators.from_) return func.extract(expr) def exists(*args, **kwargs): @@ -1195,38 +1195,18 @@ class ClauseElement(object): if hasattr(self, 'negation_clause'): return self.negation_clause else: - return _UnaryExpression(self.self_group(against=operator.inv), operator=operator.inv, negate=None) + return _UnaryExpression(self.self_group(against=operators.inv), operator=operators.inv, negate=None) class Operators(object): - def from_(): - raise NotImplementedError() - from_ = staticmethod(from_) - - def as_(): - raise NotImplementedError() - as_ = staticmethod(as_) - - def exists(): - raise NotImplementedError() - exists = staticmethod(exists) - - def is_(): - raise NotImplementedError() - is_ = staticmethod(is_) - - def isnot(): - raise NotImplementedError() - isnot = staticmethod(isnot) - def __and__(self, other): - return self.operate(operator.and_, other) + return self.operate(operators.and_, other) def __or__(self, other): - return self.operate(operator.or_, other) + return self.operate(operators.or_, other) def __invert__(self): - return self.operate(operator.inv) + return self.operate(operators.inv) def clause_element(self): raise NotImplementedError() @@ -1239,137 +1219,80 @@ class Operators(object): class ColumnOperators(Operators): """defines comparison and math operations""" - - def like_op(a, b): - return a.like(b) - like_op = staticmethod(like_op) - - def notlike_op(a, b): - raise NotImplementedError() - notlike_op = staticmethod(notlike_op) - - def ilike_op(a, b): - return a.ilike(b) - ilike_op = staticmethod(ilike_op) - - def notilike_op(a, b): - raise NotImplementedError() - notilike_op = staticmethod(notilike_op) - - def between_op(a, b): - return a.between(b) - between_op = staticmethod(between_op) - - def in_op(a, b): - return a.in_(*b) - in_op = staticmethod(in_op) - - def notin_op(a, b): - raise NotImplementedError() - notin_op = staticmethod(notin_op) - - def distinct_op(a): - return a.distinct() - distinct_op = staticmethod(distinct_op) - - def startswith_op(a, b): - return a.startswith(b) - startswith_op = staticmethod(startswith_op) - - def endswith_op(a, b): - return a.endswith(b) - endswith_op = staticmethod(endswith_op) - - def comma_op(a, b): - raise NotImplementedError() - comma_op = staticmethod(comma_op) - - def concat_op(a, b): - return a.concat(b) - concat_op = staticmethod(concat_op) - - def desc_op(a): - return a.desc() - desc_op = staticmethod(desc_op) - - def asc_op(a): - return a.asc() - asc_op = staticmethod(asc_op) - def __lt__(self, other): - return self.operate(operator.lt, other) + return self.operate(operators.lt, other) def __le__(self, other): - return self.operate(operator.le, other) + return self.operate(operators.le, other) def __eq__(self, other): - return self.operate(operator.eq, other) + return self.operate(operators.eq, other) def __ne__(self, other): - return self.operate(operator.ne, other) + return self.operate(operators.ne, other) def __gt__(self, other): - return self.operate(operator.gt, other) + return self.operate(operators.gt, other) def __ge__(self, other): - return self.operate(operator.ge, other) + return self.operate(operators.ge, other) def concat(self, other): - return self.operate(ColumnOperators.concat_op, other) + return self.operate(operators.concat_op, other) def like(self, other): - return self.operate(ColumnOperators.like_op, other) + return self.operate(operators.like_op, other) def in_(self, *other): - return self.operate(ColumnOperators.in_op, other) + return self.operate(operators.in_op, other) def startswith(self, other): - return self.operate(ColumnOperators.startswith_op, other) + return self.operate(operators.startswith_op, other) def endswith(self, other): - return self.operate(ColumnOperators.endswith_op, other) + return self.operate(operators.endswith_op, other) def desc(self): - return self.operate(ColumnOperators.desc_op) + return self.operate(operators.desc_op) def asc(self): - return self.operate(ColumnOperators.asc_op) + return self.operate(operators.asc_op) def __radd__(self, other): - return self.reverse_operate(operator.add, other) + return self.reverse_operate(operators.add, other) def __rsub__(self, other): - return self.reverse_operate(operator.sub, other) + return self.reverse_operate(operators.sub, other) def __rmul__(self, other): - return self.reverse_operate(operator.mul, other) + return self.reverse_operate(operators.mul, other) def __rdiv__(self, other): - return self.reverse_operate(operator.div, other) + return self.reverse_operate(operators.div, other) def between(self, cleft, cright): - return self.operate(ColumnOperators.between_op, cleft, cright) + return self.operate(operators.between_op, cleft, cright) def distinct(self): - return self.operate(ColumnOperators.distinct_op) + return self.operate(operators.distinct_op) def __add__(self, other): - return self.operate(operator.add, other) + return self.operate(operators.add, other) def __sub__(self, other): - return self.operate(operator.sub, other) + return self.operate(operators.sub, other) def __mul__(self, other): - return self.operate(operator.mul, other) + return self.operate(operators.mul, other) def __div__(self, other): - return self.operate(operator.div, other) + return self.operate(operators.div, other) def __mod__(self, other): - return self.operate(operator.mod, other) + return self.operate(operators.mod, other) def __truediv__(self, other): - return self.operate(operator.truediv, other) + return self.operate(operators.truediv, other) # precedence ordering for common operators. if an operator is not present in this list, # it will be parenthesized when grouped against other operators @@ -1377,35 +1300,35 @@ _smallest = object() _largest = object() PRECEDENCE = { - Operators.from_:15, - operator.mul:7, - operator.div:7, - operator.mod:7, - operator.add:6, - operator.sub:6, - ColumnOperators.concat_op:6, - ColumnOperators.ilike_op:5, - ColumnOperators.notilike_op:5, - ColumnOperators.like_op:5, - ColumnOperators.notlike_op:5, - ColumnOperators.in_op:5, - ColumnOperators.notin_op:5, - Operators.is_:5, - Operators.isnot:5, - operator.eq:5, - operator.ne:5, - operator.gt:5, - operator.lt:5, - operator.ge:5, - operator.le:5, - ColumnOperators.between_op:5, - ColumnOperators.distinct_op:5, - operator.inv:4, - operator.and_:3, - operator.or_:2, - ColumnOperators.comma_op:-1, - Operators.as_:-1, - Operators.exists:0, + operators.from_:15, + operators.mul:7, + operators.div:7, + operators.mod:7, + operators.add:6, + operators.sub:6, + operators.concat_op:6, + operators.ilike_op:5, + operators.notilike_op:5, + operators.like_op:5, + operators.notlike_op:5, + operators.in_op:5, + operators.notin_op:5, + operators.is_:5, + operators.isnot:5, + operators.eq:5, + operators.ne:5, + operators.gt:5, + operators.lt:5, + operators.ge:5, + operators.le:5, + operators.between_op:5, + operators.distinct_op:5, + operators.inv:4, + operators.and_:3, + operators.or_:2, + operators.comma_op:-1, + operators.as_:-1, + operators.exists:0, _smallest: -1000, _largest: 1000 } @@ -1415,10 +1338,10 @@ class _CompareMixin(ColumnOperators): def __compare(self, op, obj, negate=None): if obj is None or isinstance(obj, _Null): - if op == operator.eq: - return _BinaryExpression(self.expression_element(), null(), Operators.is_, negate=Operators.isnot) - elif op == operator.ne: - return _BinaryExpression(self.expression_element(), null(), Operators.isnot, negate=Operators.is_) + if op == operators.eq: + return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot) + elif op == operators.ne: + return _BinaryExpression(self.expression_element(), null(), operators.isnot, negate=operators.is_) else: raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") else: @@ -1433,25 +1356,25 @@ class _CompareMixin(ColumnOperators): type_ = self._compare_type(obj) # TODO: generalize operator overloading like this out into the types module - if op == operator.add and isinstance(type_, (sqltypes.Concatenable)): - op = ColumnOperators.concat_op + if op == operators.add and isinstance(type_, (sqltypes.Concatenable)): + op = operators.concat_op return _BinaryExpression(self.expression_element(), obj, op, type_=type_) operators = { - operator.add : (__operate,), - operator.mul : (__operate,), - operator.sub : (__operate,), - operator.div : (__operate,), - operator.mod : (__operate,), - operator.truediv : (__operate,), - operator.lt : (__compare, operator.ge), - operator.le : (__compare, operator.gt), - operator.ne : (__compare, operator.eq), - operator.gt : (__compare, operator.le), - operator.ge : (__compare, operator.lt), - operator.eq : (__compare, operator.ne), - ColumnOperators.like_op : (__compare, ColumnOperators.notlike_op), + operators.add : (__operate,), + operators.mul : (__operate,), + operators.sub : (__operate,), + operators.div : (__operate,), + operators.mod : (__operate,), + operators.truediv : (__operate,), + operators.lt : (__compare, operators.ge), + operators.le : (__compare, operators.gt), + operators.ne : (__compare, operators.eq), + operators.gt : (__compare, operators.le), + operators.ge : (__compare, operators.lt), + operators.eq : (__compare, operators.ne), + operators.like_op : (__compare, operators.notlike_op), } def operate(self, op, *other): @@ -1462,7 +1385,7 @@ class _CompareMixin(ColumnOperators): return self._bind_param(other).operate(op, self) def in_(self, *other): - return self._in_impl(ColumnOperators.in_op, ColumnOperators.notin_op, *other) + return self._in_impl(operators.in_op, operators.notin_op, *other) def _in_impl(self, op, negate_op, *other): if len(other) == 0: @@ -1489,7 +1412,7 @@ class _CompareMixin(ColumnOperators): """produce the clause ``LIKE '<other>%'``""" perc = isinstance(other,(str,unicode)) and '%' or literal('%',type_= sqltypes.String) - return self.__compare(ColumnOperators.like_op, other + perc) + return self.__compare(operators.like_op, other + perc) def endswith(self, other): """produce the clause ``LIKE '%<other>'``""" @@ -1498,7 +1421,7 @@ class _CompareMixin(ColumnOperators): else: po = literal('%', type_=sqltypes.String) + other po.type = sqltypes.to_instance(sqltypes.String) #force! - return self.__compare(ColumnOperators.like_op, po) + return self.__compare(operators.like_op, po) def label(self, name): """produce a column label, i.e. ``<columnname> AS <name>``""" @@ -1516,12 +1439,12 @@ class _CompareMixin(ColumnOperators): def distinct(self): """produce a DISTINCT clause, i.e. ``DISTINCT <columnname>``""" - return _UnaryExpression(self, operator=ColumnOperators.distinct_op) + return _UnaryExpression(self, operator=operators.distinct_op) def between(self, cleft, cright): """produce a BETWEEN clause, i.e. ``<column> BETWEEN <cleft> AND <cright>``""" - return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operator.and_, group=False), ColumnOperators.between_op) + return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operators.and_, group=False), operators.between_op) def op(self, operator): """produce a generic operator function. @@ -2132,7 +2055,7 @@ class ClauseList(ClauseElement): def __init__(self, *clauses, **kwargs): self.clauses = [] - self.operator = kwargs.pop('operator', ColumnOperators.comma_op) + self.operator = kwargs.pop('operator', operators.comma_op) self.group = kwargs.pop('group', True) self.group_contents = kwargs.pop('group_contents', True) for c in clauses: @@ -2248,7 +2171,7 @@ class _Function(_CalculatedClause, FromClause): def __init__(self, name, *clauses, **kwargs): self.packagenames = kwargs.get('packagenames', None) or [] - kwargs['operator'] = ColumnOperators.comma_op + kwargs['operator'] = operators.comma_op _CalculatedClause.__init__(self, name, **kwargs) for c in clauses: self.append(c) @@ -2365,7 +2288,7 @@ class _BinaryExpression(ColumnElement): ( self.left.compare(other.left) and self.right.compare(other.right) or ( - self.operator in [operator.eq, operator.ne, operator.add, operator.mul] and + self.operator in [operators.eq, operators.ne, operators.add, operators.mul] and self.left.compare(other.right) and self.right.compare(other.left) ) ) @@ -2390,7 +2313,7 @@ class _Exists(_UnaryExpression): def __init__(self, *args, **kwargs): kwargs['correlate'] = True s = select(*args, **kwargs).as_scalar().self_group() - _UnaryExpression.__init__(self, s, operator=Operators.exists) + _UnaryExpression.__init__(self, s, operator=operators.exists) def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) @@ -2444,7 +2367,7 @@ class Join(FromClause): class BinaryVisitor(ClauseVisitor): def visit_binary(self, binary): - if binary.operator == operator.eq: + if binary.operator == operators.eq: add_equiv(binary.left, binary.right) BinaryVisitor().traverse(self.onclause) @@ -2526,7 +2449,7 @@ class Join(FromClause): equivs = util.Set() class LocateEquivs(NoColumnVisitor): def visit_binary(self, binary): - if binary.operator == operator.eq and binary.left.name == binary.right.name: + if binary.operator == operators.eq and binary.left.name == binary.right.name: equivs.add(binary.right) equivs.add(binary.left) LocateEquivs().traverse(self.onclause) @@ -2739,7 +2662,7 @@ class _Label(ColumnElement): obj = obj.obj self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon')) - self.obj = obj.self_group(against=Operators.as_) + self.obj = obj.self_group(against=operators.as_) self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None)) key = property(lambda s: s.name) @@ -3314,7 +3237,7 @@ class Select(_SelectBaseMixin, FromClause): column = _literal_as_column(column) if isinstance(column, _ScalarSelect): - column = column.self_group(against=ColumnOperators.comma_op) + column = column.self_group(against=operators.comma_op) self._raw_columns.append(column) diff --git a/test/orm/sharding/shard.py b/test/orm/sharding/shard.py index e80f8ab05..d4b4c249a 100644 --- a/test/orm/sharding/shard.py +++ b/test/orm/sharding/shard.py @@ -3,8 +3,8 @@ from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm.shard import ShardedSession -from sqlalchemy.sql import ColumnOperators -import datetime, operator, os +from sqlalchemy.sql import operators +import datetime, os from testlib import PersistTest # TODO: ShardTest can be turned into a base for further subclasses @@ -81,9 +81,9 @@ class ShardTest(PersistTest): class FindContinent(sql.ClauseVisitor): def visit_binary(self, binary): if binary.left is weather_locations.c.continent: - if binary.operator == operator.eq: + if binary.operator == operators.eq: ids.append(shard_lookup[binary.right.value]) - elif binary.operator == ColumnOperators.in_op: + elif binary.operator == operators.in_op: for bind in binary.right.clauses: ids.append(shard_lookup[bind.value]) diff --git a/test/sql/select.py b/test/sql/select.py index 865f1ec48..c075d4e3b 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -1,6 +1,7 @@ import testbase import re, operator from sqlalchemy import * +from sqlalchemy import util from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql from testlib import * @@ -406,6 +407,10 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A table1.select(table1.c.myid.op('hoho')(12)==14), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (mytable.myid hoho :mytable_myid) = :literal" ) + + # test that clauses can be pickled (operators need to be module-level, etc.) + clause = (table1.c.myid == 12) & table1.c.myid.between(15, 20) & table1.c.myid.like('hoho') + assert str(clause) == str(util.pickle.loads(util.pickle.dumps(clause))) def testunicodestartswith(self): string = u"hi \xf6 \xf5" |