diff options
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r-- | lib/sqlalchemy/sql.py | 410 |
1 files changed, 223 insertions, 187 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 5adef46f2..e27181a9a 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -34,11 +34,30 @@ __all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters', 'Compiled', 'CompoundSelect', 'Executor', 'FromClause', 'Join', 'Select', 'Selectable', 'TableClause', 'alias', 'and_', 'asc', 'between_', 'bindparam', 'case', 'cast', 'column', 'delete', - 'desc', 'except_', 'except_all', 'exists', 'extract', 'func', + 'desc', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier', 'insert', 'intersect', 'intersect_all', 'join', 'literal', 'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select', 'subquery', 'table', 'text', 'union', 'union_all', 'update',] +# precedence ordering for common operators. if an operator is not present in this list, +# its precedence is assumed to be '0' which will cause it to be parenthesized when grouped against other operators +PRECEDENCE = { + 'FROM':15, + 'AS':15, + 'NOT':10, + 'AND':3, + 'OR':3, + '=':7, + '!=':7, + '>':7, + '<':7, + '+':5, + '-':5, + '*':5, + '/':5, + ',':0 +} + def desc(column): """Return a descending ``ORDER BY`` clause element. @@ -46,7 +65,7 @@ def desc(column): order_by = [desc(table1.mycol)] """ - return _CompoundClause(None, column, "DESC") + return _UnaryExpression(column, modifier="DESC") def asc(column): """Return an ascending ``ORDER BY`` clause element. @@ -55,7 +74,7 @@ def asc(column): order_by = [asc(table1.mycol)] """ - return _CompoundClause(None, column, "ASC") + return _UnaryExpression(column, modifier="ASC") def outerjoin(left, right, onclause=None, **kwargs): """Return an ``OUTER JOIN`` clause element. @@ -332,8 +351,9 @@ def and_(*clauses): The ``&`` operator is also overloaded on all [sqlalchemy.sql#_CompareMixin] subclasses to produce the same result. """ - - return _compound_clause('AND', *clauses) + if len(clauses) == 1: + return clauses[0] + return ClauseList(operator='AND', *clauses) def or_(*clauses): """Join a list of clauses together using the ``OR`` operator. @@ -342,7 +362,9 @@ def or_(*clauses): subclasses to produce the same result. """ - return _compound_clause('OR', *clauses) + if len(clauses) == 1: + return clauses[0] + return ClauseList(operator='OR', *clauses) def not_(clause): """Return a negation of the given clause, i.e. ``NOT(clause)``. @@ -362,7 +384,7 @@ def between(ctest, cleft, cright): provides similar functionality. """ - return _BooleanExpression(ctest, and_(_check_literal(cleft, ctest.type), _check_literal(cright, ctest.type)), 'BETWEEN') + return _BinaryExpression(ctest, and_(_literals_as_binds(cleft, type=ctest.type), _literals_as_binds(cright, type=ctest.type)), 'BETWEEN') def between_(*args, **kwargs): """synonym for [sqlalchemy.sql#between()] (deprecated).""" @@ -383,16 +405,14 @@ def case(whens, value=None, else_=None): """ - whenlist = [_CompoundClause(None, 'WHEN', c, 'THEN', r) for (c,r) in whens] + whenlist = [ClauseList('WHEN', c, 'THEN', r, operator=None) for (c,r) in whens] if not else_ is None: - whenlist.append(_CompoundClause(None, 'ELSE', else_)) + whenlist.append(ClauseList('ELSE', else_, operator=None)) if len(whenlist): type = list(whenlist[-1])[-1].type else: type = None - cc = _CalculatedClause(None, 'CASE', value, type=type, *whenlist + ['END']) - for c in cc.clauses: - c.parens = False + cc = _CalculatedClause(None, 'CASE', value, type=type, operator=None, group_contents=False, *whenlist + ['END']) return cc def cast(clause, totype, **kwargs): @@ -414,7 +434,7 @@ def cast(clause, totype, **kwargs): def extract(field, expr): """Return the clause ``extract(field FROM expr)``.""" - expr = _BinaryClause(text(field), expr, "FROM") + expr = _BinaryExpression(text(field), expr, "FROM") return func.extract(expr) def exists(*args, **kwargs): @@ -543,11 +563,6 @@ def alias(selectable, alias=None): return Alias(selectable, alias=alias) -def _check_literal(value, type): - if _is_literal(value): - return literal(value, type) - else: - return value def literal(value, type=None): """Return a literal clause, bound to a bind parameter. @@ -714,20 +729,27 @@ def null(): return _Null() -class _FunctionGateway(object): - """Return a callable based on an attribute name, which then - returns a ``_Function`` object with that name. - """ +class _FunctionGenerator(object): + """Generate ``_Function`` objects based on getattr calls.""" + + def __init__(self, **opts): + self.__names = [] + self.opts = opts def __getattr__(self, name): if name[-1] == '_': name = name[0:-1] - return getattr(_FunctionGenerator(), name) + f = _FunctionGenerator(**self.opts) + f.__names = list(self.__names) + [name] + return f -func = _FunctionGateway() + def __call__(self, *c, **kwargs): + o = self.opts.copy() + o.update(kwargs) + return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o) -def _compound_clause(keyword, *clauses): - return _CompoundClause(keyword, *clauses) +func = _FunctionGenerator() +modifier = _FunctionGenerator(group=False) def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) @@ -735,6 +757,21 @@ def _compound_select(keyword, *selects, **kwargs): def _is_literal(element): return not isinstance(element, ClauseElement) +def _literals_as_text(element): + if _is_literal(element): + return _TextClause(unicode(element)) + else: + return element + +def _literals_as_binds(element, name='literal', type=None): + if _is_literal(element): + if element is None: + return null() + else: + return _BindParamClause(name, element, shortname=name, type=type, unique=True) + else: + return element + def is_column(col): return isinstance(col, ColumnElement) @@ -825,13 +862,23 @@ class ClauseVisitor(object): (column_collections=False) or to return Schema-level items (schema_visitor=True).""" __traverse_options__ = {} - def traverse(self, obj): - for n in obj.get_children(**self.__traverse_options__): - self.traverse(n) - v = self - while v is not None: - obj.accept_visitor(v) - v = getattr(v, '_next', None) + def traverse(self, obj, stop_on=None, echo=False): + stack = [obj] + traversal = [] + while len(stack) > 0: + t = stack.pop() + if stop_on is None or t not in stop_on: + traversal.insert(0, t) + for c in t.get_children(**self.__traverse_options__): + stack.append(c) + for target in traversal: + v = self + if echo: + print "VISITING", repr(target), "STOP ON", stop_on + while v is not None: + target.accept_visitor(v) + v = getattr(v, '_next', None) + return obj def chain(self, visitor): """'chain' an additional ClauseVisitor onto this ClauseVisitor. @@ -859,6 +906,8 @@ class ClauseVisitor(object): pass def visit_binary(self, binary): pass + def visit_unary(self, unary): + pass def visit_alias(self, alias): pass def visit_select(self, select): @@ -871,6 +920,8 @@ class ClauseVisitor(object): pass def visit_calculatedclause(self, calcclause): pass + def visit_grouping(self, gr): + pass def visit_function(self, func): pass def visit_cast(self, cast): @@ -1060,7 +1111,10 @@ class ClauseElement(object): child items from a different context (such as schema-level collections instead of clause-level).""" return [] - + + def self_group(self, against=None): + return self + def supports_execution(self): """Return True if this clause element represents a complete executable statement. @@ -1175,8 +1229,7 @@ class ClauseElement(object): return self._negate() def _negate(self): - self.parens=True - return _BooleanExpression(_TextClause("NOT"), self, None) + return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None) class _CompareMixin(object): """Defines comparison operations for ``ClauseElement`` instances. @@ -1237,7 +1290,7 @@ class _CompareMixin(object): else: o = self._bind_param(o) args.append(o) - return self._compare( 'IN', ClauseList( parens=True, *args), negate='NOT IN') + return self._compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN') def startswith(self, other): """produce the clause ``LIKE '<other>%'``""" @@ -1267,11 +1320,11 @@ class _CompareMixin(object): def distinct(self): """produce a DISTINCT clause, i.e. ``DISTINCT <columnname>``""" - return _CompoundClause(None,"DISTINCT", self) + return _UnaryExpression(self, operator="DISTINCT") def between(self, cleft, cright): """produce a BETWEEN clause, i.e. ``<column> BETWEEN <cleft> AND <cright>``""" - return _BooleanExpression(self, and_(self._check_literal(cleft), self._check_literal(cright)), 'BETWEEN') + return _BinaryExpression(self, and_(self._check_literal(cleft), self._check_literal(cright)), 'BETWEEN') def op(self, operator): """produce a generic operator function. @@ -1324,15 +1377,15 @@ class _CompareMixin(object): def _compare(self, operator, obj, negate=None): if obj is None or isinstance(obj, _Null): if operator == '=': - return _BooleanExpression(self._compare_self(), null(), 'IS', negate='IS NOT') + return _BinaryExpression(self._compare_self(), null(), 'IS', negate='IS NOT') elif operator == '!=': - return _BooleanExpression(self._compare_self(), null(), 'IS NOT', negate='IS') + return _BinaryExpression(self._compare_self(), null(), 'IS NOT', negate='IS') else: raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") else: obj = self._check_literal(obj) - return _BooleanExpression(self._compare_self(), obj, operator, type=self._compare_type(obj), negate=negate) + return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj), negate=negate) def _operate(self, operator, obj): if _is_literal(obj): @@ -1348,7 +1401,7 @@ class _CompareMixin(object): def _compare_type(self, obj): """Allow subclasses to override the type used in constructing - ``_BinaryClause`` objects. + ``_BinaryExpression`` objects. Default return value is the type of the given object. """ @@ -1384,6 +1437,7 @@ class Selectable(ClauseElement): return True + class ColumnElement(Selectable, _CompareMixin): """Represent an element that is useable within the "column clause" portion of a ``SELECT`` statement. @@ -1789,7 +1843,6 @@ class _TextClause(ClauseElement): """ def __init__(self, text = "", engine=None, bindparams=None, typemap=None): - self.parens = False self._engine = engine self.bindparams = {} self.typemap = typemap @@ -1845,29 +1898,42 @@ class _Null(ColumnElement): return [] class ClauseList(ClauseElement): - """Describe a list of clauses. + """Describe a list of clauses, separated by an operator. By default, is comma-separated, such as a column listing. """ def __init__(self, *clauses, **kwargs): self.clauses = [] + self.operator = kwargs.pop('operator', ',') + self.group = kwargs.pop('group', True) + self.group_contents = kwargs.pop('group_contents', True) for c in clauses: if c is None: continue self.append(c) - self.parens = kwargs.get('parens', False) def __iter__(self): return iter(self.clauses) - + def __len__(self): + return len(self.clauses) + def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] - return ClauseList(parens=self.parens, *clauses) + return ClauseList(operator=self.operator, *clauses) + + def self_group(self, against=None): + if self.group: + return _Grouping(self) + else: + return self def append(self, clause): - if _is_literal(clause): - clause = _TextClause(unicode(clause)) - self.clauses.append(clause) + # TODO: not sure if i like the 'group_contents' flag. need to define the difference between + # a ClauseList of ClauseLists, and a "flattened" ClauseList of ClauseLists. flatten() method ? + if self.group_contents: + self.clauses.append(_literals_as_text(clause).self_group(against=self.operator)) + else: + self.clauses.append(_literals_as_text(clause)) def get_children(self, **kwargs): return self.clauses @@ -1881,6 +1947,12 @@ class ClauseList(ClauseElement): f += c._get_from_objects() return f + def self_group(self, against=None): + if PRECEDENCE.get(self.operator, 0) <= PRECEDENCE.get(against, 0): + return _Grouping(self) + else: + return self + def compare(self, other): """Compare this ``ClauseList`` to the given ``ClauseList``, including a comparison of all the clause items. @@ -1891,59 +1963,11 @@ class ClauseList(ClauseElement): if not self.clauses[i].compare(other.clauses[i]): return False else: - return True - else: - return False - -class _CompoundClause(ClauseList): - """Represent a list of clauses joined by an operator, such as ``AND`` or ``OR``. - - Extends ``ClauseList`` to add the operator as well as a - `from_objects` accessor to help determine ``FROM`` objects in a - ``SELECT`` statement. - """ - - def __init__(self, operator, *clauses, **kwargs): - ClauseList.__init__(self, *clauses, **kwargs) - self.operator = operator - - def copy_container(self): - clauses = [clause.copy_container() for clause in self.clauses] - return _CompoundClause(self.operator, *clauses) - - def append(self, clause): - if isinstance(clause, _CompoundClause): - clause.parens = True - ClauseList.append(self, clause) - - def get_children(self, **kwargs): - return self.clauses - def accept_visitor(self, visitor): - visitor.visit_compound(self) - - def _get_from_objects(self): - f = [] - for c in self.clauses: - f += c._get_from_objects() - return f - - def compare(self, other): - """Compare this ``_CompoundClause`` to the given item. - - In addition to the regular comparison, has the special case - that it returns True if this ``_CompoundClause`` has only one - item, and that item matches the given item. - """ - - if not isinstance(other, _CompoundClause): - if len(self.clauses) == 1: - return self.clauses[0].compare(other) - if ClauseList.compare(self, other): - return self.operator == other.operator + return self.operator == other.operator else: return False -class _CalculatedClause(ClauseList, ColumnElement): +class _CalculatedClause(ColumnElement): """Describe a calculated SQL expression that has a type, like ``CASE``. Extends ``ColumnElement`` to provide column-level comparison @@ -1954,8 +1978,13 @@ class _CalculatedClause(ClauseList, ColumnElement): self.name = name self.type = sqltypes.to_instance(kwargs.get('type', None)) self._engine = kwargs.get('engine', None) - ClauseList.__init__(self, *clauses) - + self.group = kwargs.pop('group', True) + self.clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses) + if self.group: + self.clause_expr = self.clauses.self_group() + else: + self.clause_expr = self.clauses + key = property(lambda self:self.name or "_calc_") def copy_container(self): @@ -1963,9 +1992,12 @@ class _CalculatedClause(ClauseList, ColumnElement): return _CalculatedClause(type=self.type, engine=self._engine, *clauses) def get_children(self, **kwargs): - return self.clauses + return self.clause_expr, + def accept_visitor(self, visitor): visitor.visit_calculatedclause(self) + def _get_from_objects(self): + return self.clauses._get_from_objects() def _bind_param(self, obj): return _BindParamClause(self.name, obj, type=self.type, unique=True) @@ -1990,28 +2022,24 @@ class _Function(_CalculatedClause, FromClause): """ def __init__(self, name, *clauses, **kwargs): - self.name = name self.type = sqltypes.to_instance(kwargs.get('type', None)) self.packagenames = kwargs.get('packagenames', None) or [] + kwargs['operator'] = ',' self._engine = kwargs.get('engine', None) - ClauseList.__init__(self, parens=True, *[c is None and _Null() or c for c in clauses]) + _CalculatedClause.__init__(self, name, **kwargs) + for c in clauses: + self.append(c) key = property(lambda self:self.name) + def append(self, clause): - if _is_literal(clause): - if clause is None: - clause = null() - else: - clause = _BindParamClause(self.name, clause, shortname=self.name, type=None, unique=True) - self.clauses.append(clause) + self.clauses.append(_literals_as_binds(clause, self.name)) def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] return _Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses) - - def get_children(self, **kwargs): - return self.clauses + def accept_visitor(self, visitor): visitor.visit_function(self) @@ -2040,39 +2068,52 @@ class _Cast(ColumnElement): else: return self -class _FunctionGenerator(object): - """Generate ``_Function`` objects based on getattr calls.""" - def __init__(self, engine=None): - self.__engine = engine - self.__names = [] +class _UnaryExpression(ColumnElement): + def __init__(self, element, operator=None, modifier=None, type=None, negate=None): + self.operator = operator + self.modifier = modifier + + self.element = _literals_as_text(element).self_group(against=self.operator or self.modifier) + self.type = sqltypes.to_instance(type) + self.negate = negate + + def copy_container(self): + return self.__class__(self.element.copy_container(), operator=self.operator, modifier=self.modifier, type=self.type, negate=self.negate) - def __getattr__(self, name): - self.__names.append(name) - return self + def _get_from_objects(self): + return self.element._get_from_objects() - def __call__(self, *c, **kwargs): - kwargs.setdefault('engine', self.__engine) - return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **kwargs) + def get_children(self, **kwargs): + return self.element, -class _BinaryClause(ClauseElement): - """Represent two clauses with an operator in between. - - This class serves as the base class for ``_BinaryExpression`` - and ``_BooleanExpression``, both of which add additional - semantics to the base ``_BinaryClause`` construct. - """ + def accept_visitor(self, visitor): + visitor.visit_unary(self) + + def compare(self, other): + """Compare this ``_UnaryClause`` against the given ``ClauseElement``.""" - def __init__(self, left, right, operator, type=None): - self.left = left - self.right = right + return ( + isinstance(other, _UnaryClause) and self.operator == other.operator and + self.modifier == other.modifier and + self.element.compare(other.element) + ) + def _negate(self): + if self.negate is not None: + return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type) + else: + return super(_UnaryExpression, self)._negate() + + +class _BinaryExpression(ColumnElement): + """Represent an expression that is ``LEFT <operator> RIGHT``.""" + + def __init__(self, left, right, operator, type=None, negate=None): + self.left = _literals_as_text(left).self_group(against=operator) + self.right = _literals_as_text(right).self_group(against=operator) self.operator = operator self.type = sqltypes.to_instance(type) - self.parens = False - if isinstance(self.left, _BinaryClause) or hasattr(self.left, '_selectable'): - self.left.parens = True - if isinstance(self.right, _BinaryClause) or hasattr(self.right, '_selectable'): - self.right.parens = True + self.negate = negate def copy_container(self): return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator) @@ -2086,58 +2127,31 @@ class _BinaryClause(ClauseElement): def accept_visitor(self, visitor): visitor.visit_binary(self) - def swap(self): - c = self.left - self.left = self.right - self.right = c - def compare(self, other): - """Compare this ``_BinaryClause`` against the given ``_BinaryClause``.""" + """Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``.""" return ( - isinstance(other, _BinaryClause) and self.operator == other.operator and + isinstance(other, _BinaryExpression) and self.operator == other.operator and self.left.compare(other.left) and self.right.compare(other.right) ) - -class _BinaryExpression(_BinaryClause, ColumnElement): - """Represent a binary expression, which can be in a ``WHERE`` - criterion or in the column list of a ``SELECT``. - - This class differs from ``_BinaryClause`` in that it mixes - in ``ColumnElement``. The effect is that elements of this - type become ``Selectable`` units which can be placed in the - column list of a ``select()`` construct. - - """ - - pass - -class _BooleanExpression(_BinaryExpression): - """Represent a boolean expression. - - ``_BooleanExpression`` is constructed as the result of compare operations - involving ``CompareMixin`` subclasses, such as when comparing a ``ColumnElement`` - to a scalar value via the ``==`` operator, ``CompareMixin``'s ``__eq__()`` method - produces a ``_BooleanExpression`` consisting of the ``ColumnElement`` and a - ``_BindParamClause``. + + def self_group(self, against=None): + if PRECEDENCE.get(self.operator, 0) <= PRECEDENCE.get(against, 0): + return _Grouping(self) + else: + return self - """ - - def __init__(self, *args, **kwargs): - self.negate = kwargs.pop('negate', None) - super(_BooleanExpression, self).__init__(*args, **kwargs) - def _negate(self): if self.negate is not None: - return _BooleanExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type) + return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type) else: - return super(_BooleanExpression, self)._negate() + return super(_BinaryExpression, self)._negate() -class _Exists(_BooleanExpression): +class _Exists(_UnaryExpression): def __init__(self, *args, **kwargs): kwargs['correlate'] = True - s = select(*args, **kwargs) - _BooleanExpression.__init__(self, _TextClause("EXISTS"), s, None) + s = select(*args, **kwargs).self_group() + _UnaryExpression.__init__(self, s, operator="EXISTS") def _hide_froms(self): return self._get_from_objects() @@ -2358,6 +2372,27 @@ class Alias(FromClause): engine = property(lambda s: s.selectable.engine) +class _Grouping(ColumnElement): + def __init__(self, elem): + self.elem = elem + self.type = getattr(elem, 'type', None) + + key = property(lambda s: s.elem.key) + _label = property(lambda s: s.elem._label) + orig_set = property(lambda s:s.elem.orig_set) + + def copy_container(self): + return _Grouping(self.elem.copy_container()) + + def accept_visitor(self, visitor): + visitor.visit_grouping(self) + def get_children(self, **kwargs): + return self.elem, + def _hide_froms(self): + return self.elem._hide_froms() + def _get_from_objects(self): + return self.elem._get_from_objects() + class _Label(ColumnElement): """represent a label, as typically applied to any column-level element using the ``AS`` sql keyword. @@ -2372,10 +2407,9 @@ class _Label(ColumnElement): self.name = name while isinstance(obj, _Label): obj = obj.obj - self.obj = obj + self.obj = obj.self_group(against='AS') self.case_sensitive = getattr(obj, "case_sensitive", True) self.type = sqltypes.to_instance(type or getattr(obj, 'type', None)) - obj.parens=True key = property(lambda s: s.name) _label = property(lambda s: s.name) @@ -2633,7 +2667,6 @@ class CompoundSelect(_SelectBaseMixin, FromClause): _SelectBaseMixin.__init__(self) self.keyword = keyword self.use_labels = kwargs.pop('use_labels', False) - self.parens = kwargs.pop('parens', False) self.should_correlate = kwargs.pop('correlate', False) self.for_update = kwargs.pop('for_update', False) self.nowait = kwargs.pop('nowait', False) @@ -2794,8 +2827,6 @@ class Select(_SelectBaseMixin, FromClause): def visit_compound_select(self, cs): self.visit_select(cs) - for s in cs.selects: - s.parens = False def visit_column(self, c): pass @@ -2808,7 +2839,6 @@ class Select(_SelectBaseMixin, FromClause): return select.is_where = self.is_where select.is_subquery = True - select.parens = True if not select.should_correlate: return [select.correlate(x) for x in self.select._Select__froms] @@ -2828,6 +2858,9 @@ class Select(_SelectBaseMixin, FromClause): if _is_literal(column): column = literal_column(str(column)) + if isinstance(column, Select) and column.is_scalar: + column = column.self_group(against=',') + self._raw_columns.append(column) if self.is_scalar and not hasattr(self, 'type'): @@ -2873,6 +2906,9 @@ class Select(_SelectBaseMixin, FromClause): for f in elem._hide_froms(): self.__hide_froms.add(f) + def self_group(self, against=None): + return _Grouping(self) + def append_whereclause(self, whereclause): self._append_condition('whereclause', whereclause) |