diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:08:53 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:08:53 +0000 |
commit | ed4fc64bb0ac61c27bc4af32962fb129e74a36bf (patch) | |
tree | c1cf2fb7b1cafced82a8898e23d2a0bf5ced8526 /lib/sqlalchemy/sql.py | |
parent | 3a8e235af64e36b3b711df1f069d32359fe6c967 (diff) | |
download | sqlalchemy-ed4fc64bb0ac61c27bc4af32962fb129e74a36bf.tar.gz |
merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3.
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r-- | lib/sqlalchemy/sql.py | 1946 |
1 files changed, 1064 insertions, 882 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 8b454947e..01588e92d 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -24,54 +24,21 @@ are less guaranteed to stay the same in future releases. """ -from sqlalchemy import util, exceptions, logging +from sqlalchemy import util, exceptions from sqlalchemy import types as sqltypes -import string, re, random, sets +import re, operator - -__all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters', +__all__ = ['Alias', 'ClauseElement', 'ClauseParameters', 'ClauseVisitor', 'ColumnCollection', 'ColumnElement', - 'Compiled', 'CompoundSelect', 'Executor', 'FromClause', 'Join', - 'Select', 'Selectable', 'TableClause', 'alias', 'and_', 'asc', - 'between_', 'between', 'bindparam', 'case', 'cast', 'column', 'delete', + 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', + 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', + 'between', 'bindparam', 'case', 'cast', 'column', 'delete', 'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier', 'insert', 'intersect', 'intersect_all', 'join', 'literal', - 'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select', + 'literal_column', 'not_', 'null', 'or_', 'outparam', 'outerjoin', 'select', 'subquery', 'table', 'text', 'union', 'union_all', 'update',] -# precedence ordering for common operators. if an operator is not present in this list, -# it will be parenthesized when grouped against other operators -PRECEDENCE = { - 'FROM':15, - '*':7, - '/':7, - '%':7, - '+':6, - '-':6, - 'ILIKE':5, - 'NOT ILIKE':5, - 'LIKE':5, - 'NOT LIKE':5, - 'IN':5, - 'NOT IN':5, - 'IS':5, - 'IS NOT':5, - '=':5, - '!=':5, - '>':5, - '<':5, - '>=':5, - '<=':5, - 'BETWEEN':5, - 'NOT':4, - 'AND':3, - 'OR':2, - ',':-1, - 'AS':-1, - 'EXISTS':0, - '_smallest': -1000, - '_largest': 1000 -} +BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE) def desc(column): """Return a descending ``ORDER BY`` clause element. @@ -141,7 +108,7 @@ def join(left, right, onclause=None, **kwargs): return Join(left, right, onclause, **kwargs) -def select(columns=None, whereclause = None, from_obj = [], **kwargs): +def select(columns=None, whereclause=None, from_obj=[], **kwargs): """Returns a ``SELECT`` clause element. Similar functionality is also available via the ``select()`` method on any @@ -224,9 +191,6 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs): automatically bind to whatever ``Connectable`` instances can be located within its contained ``ClauseElement`` members. - engine=None - deprecated. a synonym for "bind". - limit=None a numerical value which usually compiles to a ``LIMIT`` expression in the resulting select. Databases that don't support ``LIMIT`` @@ -238,12 +202,8 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs): will attempt to provide similar functionality. scalar=False - when ``True``, indicates that the resulting ``Select`` object - is to be used in the "columns" clause of another select statement, - where the evaluated value of the column is the scalar result of - this statement. Normally, placing any ``Selectable`` within the - columns clause of a ``select()`` call will expand the member - columns of the ``Selectable`` individually. + deprecated. use select(...).as_scalar() to create a "scalar column" + proxy for an existing Select object. correlate=True indicates that this ``Select`` object should have its contained @@ -254,8 +214,12 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs): rendered in the ``FROM`` clause of this select statement. """ - - return Select(columns, whereclause = whereclause, from_obj = from_obj, **kwargs) + scalar = kwargs.pop('scalar', False) + s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs) + if scalar: + return s.as_scalar() + else: + return s def subquery(alias, *args, **kwargs): """Return an [sqlalchemy.sql#Alias] object derived from a [sqlalchemy.sql#Select]. @@ -271,7 +235,7 @@ def subquery(alias, *args, **kwargs): return Select(*args, **kwargs).alias(alias) def insert(table, values = None, **kwargs): - """Return an [sqlalchemy.sql#_Insert] clause element. + """Return an [sqlalchemy.sql#Insert] clause element. Similar functionality is available via the ``insert()`` method on [sqlalchemy.schema#Table]. @@ -304,10 +268,10 @@ def insert(table, values = None, **kwargs): against the ``INSERT`` statement. """ - return _Insert(table, values, **kwargs) + return Insert(table, values, **kwargs) def update(table, whereclause = None, values = None, **kwargs): - """Return an [sqlalchemy.sql#_Update] clause element. + """Return an [sqlalchemy.sql#Update] clause element. Similar functionality is available via the ``update()`` method on [sqlalchemy.schema#Table]. @@ -344,10 +308,10 @@ def update(table, whereclause = None, values = None, **kwargs): against the ``UPDATE`` statement. """ - return _Update(table, whereclause, values, **kwargs) + return Update(table, whereclause, values, **kwargs) def delete(table, whereclause = None, **kwargs): - """Return a [sqlalchemy.sql#_Delete] clause element. + """Return a [sqlalchemy.sql#Delete] clause element. Similar functionality is available via the ``delete()`` method on [sqlalchemy.schema#Table]. @@ -361,7 +325,7 @@ def delete(table, whereclause = None, **kwargs): """ - return _Delete(table, whereclause, **kwargs) + return Delete(table, whereclause, **kwargs) def and_(*clauses): """Join a list of clauses together using the ``AND`` operator. @@ -371,7 +335,7 @@ def and_(*clauses): """ if len(clauses) == 1: return clauses[0] - return ClauseList(operator='AND', *clauses) + return ClauseList(operator=operator.and_, *clauses) def or_(*clauses): """Join a list of clauses together using the ``OR`` operator. @@ -382,7 +346,7 @@ def or_(*clauses): if len(clauses) == 1: return clauses[0] - return ClauseList(operator='OR', *clauses) + return ClauseList(operator=operator.or_, *clauses) def not_(clause): """Return a negation of the given clause, i.e. ``NOT(clause)``. @@ -391,7 +355,7 @@ def not_(clause): subclasses to produce the same result. """ - return clause._negate() + return operator.inv(clause) def distinct(expr): """return a ``DISTINCT`` clause.""" @@ -407,12 +371,9 @@ def between(ctest, cleft, cright): provides similar functionality. """ - return _BinaryExpression(ctest, ClauseList(_literals_as_binds(cleft, type=ctest.type), _literals_as_binds(cright, type=ctest.type), operator='AND', group=False), 'BETWEEN') + ctest = _literal_as_binds(ctest) + return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operator.and_, group=False), ColumnOperators.between_op) -def between_(*args, **kwargs): - """synonym for [sqlalchemy.sql#between()] (deprecated).""" - - return between(*args, **kwargs) def case(whens, value=None, else_=None): """Produce a ``CASE`` statement. @@ -435,7 +396,7 @@ def case(whens, value=None, else_=None): type = list(whenlist[-1])[-1].type else: type = None - cc = _CalculatedClause(None, 'CASE', value, type=type, operator=None, group_contents=False, *whenlist + ['END']) + cc = _CalculatedClause(None, 'CASE', value, type_=type, operator=None, group_contents=False, *whenlist + ['END']) return cc def cast(clause, totype, **kwargs): @@ -457,7 +418,7 @@ def cast(clause, totype, **kwargs): def extract(field, expr): """Return the clause ``extract(field FROM expr)``.""" - expr = _BinaryExpression(text(field), expr, "FROM") + expr = _BinaryExpression(text(field), expr, Operators.from_) return func.extract(expr) def exists(*args, **kwargs): @@ -587,7 +548,7 @@ def alias(selectable, alias=None): return Alias(selectable, alias=alias) -def literal(value, type=None): +def literal(value, type_=None): """Return a literal clause, bound to a bind parameter. Literal clauses are created automatically when non- @@ -603,13 +564,13 @@ def literal(value, type=None): the underlying DBAPI, or is translatable via the given type argument. - type + type\_ an optional [sqlalchemy.types#TypeEngine] which will provide bind-parameter translation for this literal. """ - return _BindParamClause('literal', value, type=type, unique=True) + return _BindParamClause('literal', value, type_=type_, unique=True) def label(name, obj): """Return a [sqlalchemy.sql#_Label] object for the given [sqlalchemy.sql#ColumnElement]. @@ -630,7 +591,7 @@ def label(name, obj): return _Label(name, obj) -def column(text, type=None): +def column(text, type_=None): """Return a textual column clause, as would be in the columns clause of a ``SELECT`` statement. @@ -644,15 +605,15 @@ def column(text, type=None): constructs that are not to be quoted, use the [sqlalchemy.sql#literal_column()] function. - type + type\_ an optional [sqlalchemy.types#TypeEngine] object which will provide result-set translation for this column. """ - return _ColumnClause(text, type=type) + return _ColumnClause(text, type_=type_) -def literal_column(text, type=None): +def literal_column(text, type_=None): """Return a textual column clause, as would be in the columns clause of a ``SELECT`` statement. @@ -674,7 +635,7 @@ def literal_column(text, type=None): """ - return _ColumnClause(text, type=type, is_literal=True) + return _ColumnClause(text, type_=type_, is_literal=True) def table(name, *columns): """Return a [sqlalchemy.sql#Table] object. @@ -685,7 +646,7 @@ def table(name, *columns): return TableClause(name, *columns) -def bindparam(key, value=None, type=None, shortname=None, unique=False): +def bindparam(key, value=None, type_=None, shortname=None, unique=False): """Create a bind parameter clause with the given key. value @@ -707,11 +668,22 @@ def bindparam(key, value=None, type=None, shortname=None, unique=False): """ if isinstance(key, _ColumnClause): - return _BindParamClause(key.name, value, type=key.type, shortname=shortname, unique=unique) + return _BindParamClause(key.name, value, type_=key.type, shortname=shortname, unique=unique) else: - return _BindParamClause(key, value, type=type, shortname=shortname, unique=unique) + return _BindParamClause(key, value, type_=type_, shortname=shortname, unique=unique) -def text(text, bind=None, engine=None, *args, **kwargs): +def outparam(key, type_=None): + """create an 'OUT' parameter for usage in functions (stored procedures), for databases + whith support them. + + The ``outparam`` can be used like a regular function parameter. The "output" value will + be available from the [sqlalchemy.engine#ResultProxy] object via its ``out_parameters`` + attribute, which returns a dictionary containing the values. + """ + + return _BindParamClause(key, None, type_=type_, unique=False, isoutparam=True) + +def text(text, bind=None, *args, **kwargs): """Create literal text to be inserted into a query. When constructing a query from a ``select()``, ``update()``, @@ -729,9 +701,6 @@ def text(text, bind=None, engine=None, *args, **kwargs): bind An optional connection or engine to be used for this text query. - engine - deprecated. a synonym for 'bind'. - bindparams A list of ``bindparam()`` instances which can be used to define the types and/or initial values for the bind parameters within @@ -748,7 +717,7 @@ def text(text, bind=None, engine=None, *args, **kwargs): """ - return _TextClause(text, engine=engine, bind=bind, *args, **kwargs) + return _TextClause(text, bind=bind, *args, **kwargs) def null(): """Return a ``_Null`` object, which compiles to ``NULL`` in a sql statement.""" @@ -786,30 +755,44 @@ def _compound_select(keyword, *selects, **kwargs): def _is_literal(element): return not isinstance(element, ClauseElement) -def _literals_as_text(element): - if _is_literal(element): +def _literal_as_text(element): + if isinstance(element, Operators): + return element.expression_element() + elif _is_literal(element): return _TextClause(unicode(element)) else: return element -def _literals_as_binds(element, name='literal', type=None): - if _is_literal(element): +def _literal_as_column(element): + if isinstance(element, Operators): + return element.clause_element() + elif _is_literal(element): + return literal_column(str(element)) + else: + return element + +def _literal_as_binds(element, name='literal', type_=None): + if isinstance(element, Operators): + return element.expression_element() + elif _is_literal(element): if element is None: return null() else: - return _BindParamClause(name, element, shortname=name, type=type, unique=True) + return _BindParamClause(name, element, shortname=name, type_=type_, unique=True) else: return element + +def _selectable(element): + if hasattr(element, '__selectable__'): + return element.__selectable__() + elif isinstance(element, Selectable): + return element + else: + raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) def is_column(col): return isinstance(col, ColumnElement) -class AbstractDialect(object): - """Represent the behavior of a particular database. - - Used by ``Compiled`` objects.""" - pass - class ClauseParameters(object): """Represent a dictionary/iterator of bind parameter key names/values. @@ -822,52 +805,51 @@ class ClauseParameters(object): def __init__(self, dialect, positional=None): super(ClauseParameters, self).__init__() self.dialect = dialect - self.binds = {} - self.binds_to_names = {} - self.binds_to_values = {} + self.__binds = {} self.positional = positional or [] + def get_parameter(self, key): + return self.__binds[key] + def set_parameter(self, bindparam, value, name): - self.binds[bindparam.key] = bindparam - self.binds[name] = bindparam - self.binds_to_names[bindparam] = name - self.binds_to_values[bindparam] = value + self.__binds[name] = [bindparam, name, value] def get_original(self, key): - """Return the given parameter as it was originally placed in - this ``ClauseParameters`` object, without any ``Type`` - conversion.""" - return self.binds_to_values[self.binds[key]] + return self.__binds[key][2] + + def get_type(self, key): + return self.__binds[key][0].type def get_processed(self, key): - bind = self.binds[key] - value = self.binds_to_values[bind] + (bind, name, value) = self.__binds[key] return bind.typeprocess(value, self.dialect) def keys(self): - return self.binds_to_names.values() + return self.__binds.keys() + + def __iter__(self): + return iter(self.keys()) def __getitem__(self, key): return self.get_processed(key) def __contains__(self, key): - return key in self.binds + return key in self.__binds def set_value(self, key, value): - bind = self.binds[key] - self.binds_to_values[bind] = value + self.__binds[key][2] = value def get_original_dict(self): - return dict([(self.binds_to_names[b], self.binds_to_values[b]) for b in self.binds_to_names.keys()]) + return dict([(name, value) for (b, name, value) in self.__binds.values()]) def get_raw_list(self): return [self.get_processed(key) for key in self.positional] - def get_raw_dict(self): - d = {} - for k in self.binds_to_names.values(): - d[k] = self.get_processed(k) - return d + def get_raw_dict(self, encode_keys=False): + if encode_keys: + return dict([(key.encode(self.dialect.encoding), self.get_processed(key)) for key in self.keys()]) + else: + return dict([(key, self.get_processed(key)) for key in self.keys()]) def __repr__(self): return self.__class__.__name__ + ":" + repr(self.get_original_dict()) @@ -876,8 +858,8 @@ class ClauseVisitor(object): """A class that knows how to traverse and visit ``ClauseElements``. - Each ``ClauseElement``'s accept_visitor() method will call a - corresponding visit_XXXX() method here. Traversal of a + Calls visit_XXX() methods dynamically generated for each particualr + ``ClauseElement`` subclass encountered. Traversal of a hierarchy of ``ClauseElements`` is achieved via the ``traverse()`` method, which is passed the lead ``ClauseElement``. @@ -889,22 +871,44 @@ class ClauseVisitor(object): these options can indicate modifications to the set of elements returned, such as to not return column collections (column_collections=False) or to return Schema-level items - (schema_visitor=True).""" + (schema_visitor=True). + + ``ClauseVisitor`` also supports a simultaneous copy-and-traverse + operation, which will produce a copy of a given ``ClauseElement`` + structure while at the same time allowing ``ClauseVisitor`` subclasses + to modify the new structure in-place. + + """ __traverse_options__ = {} - def traverse(self, obj, stop_on=None): - stack = [obj] - traversal = [] - while len(stack) > 0: - t = stack.pop() - if stop_on is None or t not in stop_on: - traversal.insert(0, t) - for c in t.get_children(**self.__traverse_options__): - stack.append(c) - for target in traversal: - v = self - while v is not None: - target.accept_visitor(v) - v = getattr(v, '_next', None) + + def traverse_single(self, obj, **kwargs): + meth = getattr(self, "visit_%s" % obj.__visit_name__, None) + if meth: + return meth(obj, **kwargs) + + def traverse(self, obj, stop_on=None, clone=False): + if clone: + obj = obj._clone() + + v = self + visitors = [] + while v is not None: + visitors.append(v) + v = getattr(v, '_next', None) + + def _trav(obj): + if stop_on is not None and obj in stop_on: + return + if clone: + obj._copy_internals() + for c in obj.get_children(**self.__traverse_options__): + _trav(c) + + for v in visitors: + meth = getattr(v, "visit_%s" % obj.__visit_name__, None) + if meth: + meth(obj) + _trav(obj) return obj def chain(self, visitor): @@ -916,78 +920,6 @@ class ClauseVisitor(object): tail = tail._next tail._next = visitor return self - - def visit_column(self, column): - pass - def visit_table(self, table): - pass - def visit_fromclause(self, fromclause): - pass - def visit_bindparam(self, bindparam): - pass - def visit_textclause(self, textclause): - pass - def visit_compound(self, compound): - pass - def visit_compound_select(self, compound): - pass - def visit_binary(self, binary): - pass - def visit_unary(self, unary): - pass - def visit_alias(self, alias): - pass - def visit_select(self, select): - pass - def visit_join(self, join): - pass - def visit_null(self, null): - pass - def visit_clauselist(self, list): - pass - def visit_calculatedclause(self, calcclause): - pass - def visit_grouping(self, gr): - pass - def visit_function(self, func): - pass - def visit_cast(self, cast): - pass - def visit_label(self, label): - pass - def visit_typeclause(self, typeclause): - pass - -class LoggingClauseVisitor(ClauseVisitor): - """extends ClauseVisitor to include debug logging of all traversal. - - To install this visitor, set logging.DEBUG for - 'sqlalchemy.sql.ClauseVisitor' **before** you import the - sqlalchemy.sql module. - """ - - def traverse(self, obj, stop_on=None): - stack = [(obj, "")] - traversal = [] - while len(stack) > 0: - (t, indent) = stack.pop() - if stop_on is None or t not in stop_on: - traversal.insert(0, (t, indent)) - for c in t.get_children(**self.__traverse_options__): - stack.append((c, indent + " ")) - - for (target, indent) in traversal: - self.logger.debug(indent + repr(target)) - v = self - while v is not None: - target.accept_visitor(v) - v = getattr(v, '_next', None) - return obj - -LoggingClauseVisitor.logger = logging.class_logger(ClauseVisitor) - -if logging.is_debug_enabled(LoggingClauseVisitor.logger): - ClauseVisitor=LoggingClauseVisitor class NoColumnVisitor(ClauseVisitor): """a ClauseVisitor that will not traverse the exported Column @@ -1000,113 +932,35 @@ class NoColumnVisitor(ClauseVisitor): """ __traverse_options__ = {'column_collections':False} - -class Executor(object): - """Interface representing a "thing that can produce Compiled objects - and execute them".""" - - def execute_compiled(self, compiled, parameters, echo=None, **kwargs): - """Execute a Compiled object.""" - - raise NotImplementedError() - - def compiler(self, statement, parameters, **kwargs): - """Return a Compiled object for the given statement and parameters.""" - - raise NotImplementedError() - -class Compiled(ClauseVisitor): - """Represent a compiled SQL expression. - - The ``__str__`` method of the ``Compiled`` object should produce - the actual text of the statement. ``Compiled`` objects are - specific to their underlying database dialect, and also may - or may not be specific to the columns referenced within a - particular set of bind parameters. In no case should the - ``Compiled`` object be dependent on the actual values of those - bind parameters, even though it may reference those values as - defaults. - """ - - def __init__(self, dialect, statement, parameters, bind=None, engine=None): - """Construct a new ``Compiled`` object. - - statement - ``ClauseElement`` to be compiled. - - parameters - Optional dictionary indicating a set of bind parameters - specified with this ``Compiled`` object. These parameters - are the *default* values corresponding to the - ``ClauseElement``'s ``_BindParamClauses`` when the - ``Compiled`` is executed. In the case of an ``INSERT`` or - ``UPDATE`` statement, these parameters will also result in - the creation of new ``_BindParamClause`` objects for each - key and will also affect the generated column list in an - ``INSERT`` statement and the ``SET`` clauses of an - ``UPDATE`` statement. The keys of the parameter dictionary - can either be the string names of columns or - ``_ColumnClause`` objects. - - bind - optional engine or connection which will be bound to the - compiled object. - - engine - deprecated, a synonym for 'bind' - """ - self.dialect = dialect - self.statement = statement - self.parameters = parameters - self.bind = bind or engine - self.can_execute = statement.supports_execution() - - def compile(self): - self.traverse(self.statement) - self.after_compile() - - def __str__(self): - """Return the string text of the generated SQL statement.""" - - raise NotImplementedError() - def get_params(self, **params): - """Deprecated. use construct_params(). (supports unicode names) - """ - return self.construct_params(params) - - def construct_params(self, params): - """Return the bind params for this compiled object. - - Will start with the default parameters specified when this - ``Compiled`` object was first constructed, and will override - those values with those sent via `**params`, which are - key/value pairs. Each key should match one of the - ``_BindParamClause`` objects compiled into this object; either - the `key` or `shortname` property of the ``_BindParamClause``. - """ - raise NotImplementedError() +class _FigureVisitName(type): + def __init__(cls, clsname, bases, dict): + if not '__visit_name__' in cls.__dict__: + m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname) + x = m.group(1) + x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x) + cls.__visit_name__ = x.lower() + super(_FigureVisitName, cls).__init__(clsname, bases, dict) - def execute(self, *multiparams, **params): - """Execute this compiled object.""" - - e = self.bind - if e is None: - raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.") - return e.execute_compiled(self, *multiparams, **params) - - def scalar(self, *multiparams, **params): - """Execute this compiled object and return the result's scalar value.""" - - return self.execute(*multiparams, **params).scalar() - class ClauseElement(object): """Base class for elements of a programmatically constructed SQL expression. """ + __metaclass__ = _FigureVisitName + + def _clone(self): + """create a shallow copy of this ClauseElement. + + This method may be used by a generative API. + Its also used as part of the "deep" copy afforded + by a traversal that combines the _copy_internals() + method.""" + c = self.__class__.__new__(self.__class__) + c.__dict__ = self.__dict__.copy() + return c - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): """Return objects represented in this ``ClauseElement`` that should be added to the ``FROM`` list of a query, when this ``ClauseElement`` is placed in the column clause of a @@ -1115,7 +969,7 @@ class ClauseElement(object): raise NotImplementedError(repr(self)) - def _hide_froms(self): + def _hide_froms(self, **modifiers): """Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces. """ @@ -1131,13 +985,14 @@ class ClauseElement(object): return self is other - def accept_visitor(self, visitor): - """Accept a ``ClauseVisitor`` and call the appropriate - ``visit_xxx`` method. - """ - - raise NotImplementedError(repr(self)) - + def _copy_internals(self): + """reassign internal elements to be clones of themselves. + + called during a copy-and-traverse operation on newly + shallow-copied elements to create a deep copy.""" + + pass + def get_children(self, **kwargs): """return immediate child elements of this ``ClauseElement``. @@ -1160,18 +1015,6 @@ class ClauseElement(object): return False - def copy_container(self): - """Return a copy of this ``ClauseElement``, if this - ``ClauseElement`` contains other ``ClauseElements``. - - If this ``ClauseElement`` is not a container, it should return - self. This is used to create copies of expression trees that - still reference the same *leaf nodes*. The new structure can - then be restructured without affecting the original. - """ - - return self - def _find_engine(self): """Default strategy for locating an engine within the clause element. @@ -1195,7 +1038,6 @@ class ClauseElement(object): return None bind = property(lambda s:s._find_engine(), doc="""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""") - engine = bind def execute(self, *multiparams, **params): """Compile and execute this ``ClauseElement``.""" @@ -1213,7 +1055,7 @@ class ClauseElement(object): return self.execute(*multiparams, **params).scalar() - def compile(self, bind=None, engine=None, parameters=None, compiler=None, dialect=None): + def compile(self, bind=None, parameters=None, compiler=None, dialect=None): """Compile this SQL expression. Uses the given ``Compiler``, or the given ``AbstractDialect`` @@ -1236,7 +1078,7 @@ class ClauseElement(object): ``SET`` and ``VALUES`` clause of those statements. """ - if (isinstance(parameters, list) or isinstance(parameters, tuple)): + if isinstance(parameters, (list, tuple)): parameters = parameters[0] if compiler is None: @@ -1244,8 +1086,6 @@ class ClauseElement(object): compiler = dialect.compiler(self, parameters) elif bind is not None: compiler = bind.compiler(self, parameters) - elif engine is not None: - compiler = engine.compiler(self, parameters) elif self.bind is not None: compiler = self.bind.compiler(self, parameters) @@ -1268,49 +1108,257 @@ class ClauseElement(object): return self._negate() def _negate(self): - return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None) + if hasattr(self, 'negation_clause'): + return self.negation_clause + else: + return _UnaryExpression(self.self_group(against=operator.inv), operator=operator.inv, negate=None) + -class _CompareMixin(object): - """Defines comparison operations for ``ClauseElement`` instances. +class Operators(object): + def from_(): + raise NotImplementedError() + from_ = staticmethod(from_) - This is a mixin class that adds the capability to produce ``ClauseElement`` - instances based on regular Python operators. - These operations are achieved using Python's operator overload methods - (i.e. ``__eq__()``, ``__ne__()``, etc. + def as_(): + raise NotImplementedError() + as_ = staticmethod(as_) - Overridden operators include all comparison operators (i.e. '==', '!=', '<'), - math operators ('+', '-', '*', etc), the '&' and '|' operators which evaluate - to ``AND`` and ``OR`` respectively. + def exists(): + raise NotImplementedError() + exists = staticmethod(exists) - Other methods exist to create additional SQL clauses such as ``IN``, ``LIKE``, - ``DISTINCT``, etc. + def is_(): + raise NotImplementedError() + is_ = staticmethod(is_) - """ + def isnot(): + raise NotImplementedError() + isnot = staticmethod(isnot) + + def __and__(self, other): + return self.operate(operator.and_, other) + + def __or__(self, other): + return self.operate(operator.or_, other) + + def __invert__(self): + return self.operate(operator.inv) + + def clause_element(self): + raise NotImplementedError() + + def operate(self, op, *other, **kwargs): + raise NotImplementedError() + + def reverse_operate(self, op, *other, **kwargs): + raise NotImplementedError() + +class ColumnOperators(Operators): + """defines comparison and math operations""" + + def like_op(a, b): + return a.like(b) + like_op = staticmethod(like_op) + + def notlike_op(a, b): + raise NotImplementedError() + notlike_op = staticmethod(notlike_op) + def ilike_op(a, b): + return a.ilike(b) + ilike_op = staticmethod(ilike_op) + + def notilike_op(a, b): + raise NotImplementedError() + notilike_op = staticmethod(notilike_op) + + def between_op(a, b): + return a.between(b) + between_op = staticmethod(between_op) + + def in_op(a, b): + return a.in_(*b) + in_op = staticmethod(in_op) + + def notin_op(a, b): + raise NotImplementedError() + notin_op = staticmethod(notin_op) + + def startswith_op(a, b): + return a.startswith(b) + startswith_op = staticmethod(startswith_op) + + def endswith_op(a, b): + return a.endswith(b) + endswith_op = staticmethod(endswith_op) + + def comma_op(a, b): + raise NotImplementedError() + comma_op = staticmethod(comma_op) + + def concat_op(a, b): + return a.concat(b) + concat_op = staticmethod(concat_op) + def __lt__(self, other): - return self._compare('<', other) + return self.operate(operator.lt, other) def __le__(self, other): - return self._compare('<=', other) + return self.operate(operator.le, other) def __eq__(self, other): - return self._compare('=', other) + return self.operate(operator.eq, other) def __ne__(self, other): - return self._compare('!=', other) + return self.operate(operator.ne, other) def __gt__(self, other): - return self._compare('>', other) + return self.operate(operator.gt, other) def __ge__(self, other): - return self._compare('>=', other) + return self.operate(operator.ge, other) + def concat(self, other): + return self.operate(ColumnOperators.concat_op, other) + def like(self, other): - """produce a ``LIKE`` clause.""" - return self._compare('LIKE', other) + return self.operate(ColumnOperators.like_op, other) + + def in_(self, *other): + return self.operate(ColumnOperators.in_op, other) + + def startswith(self, other): + return self.operate(ColumnOperators.startswith_op, other) + + def endswith(self, other): + return self.operate(ColumnOperators.endswith_op, other) + + def __radd__(self, other): + return self.reverse_operate(operator.add, other) + + def __rsub__(self, other): + return self.reverse_operate(operator.sub, other) + + def __rmul__(self, other): + return self.reverse_operate(operator.mul, other) + + def __rdiv__(self, other): + return self.reverse_operate(operator.div, other) + + def between(self, cleft, cright): + return self.operate(Operators.between_op, (cleft, cright)) + + def __add__(self, other): + return self.operate(operator.add, other) + + def __sub__(self, other): + return self.operate(operator.sub, other) + + def __mul__(self, other): + return self.operate(operator.mul, other) + + def __div__(self, other): + return self.operate(operator.div, other) + + def __mod__(self, other): + return self.operate(operator.mod, other) + + def __truediv__(self, other): + return self.operate(operator.truediv, other) + +# precedence ordering for common operators. if an operator is not present in this list, +# it will be parenthesized when grouped against other operators +_smallest = object() +_largest = object() + +PRECEDENCE = { + Operators.from_:15, + operator.mul:7, + operator.div:7, + operator.mod:7, + operator.add:6, + operator.sub:6, + ColumnOperators.concat_op:6, + ColumnOperators.ilike_op:5, + ColumnOperators.notilike_op:5, + ColumnOperators.like_op:5, + ColumnOperators.notlike_op:5, + ColumnOperators.in_op:5, + ColumnOperators.notin_op:5, + Operators.is_:5, + Operators.isnot:5, + operator.eq:5, + operator.ne:5, + operator.gt:5, + operator.lt:5, + operator.ge:5, + operator.le:5, + ColumnOperators.between_op:5, + operator.inv:4, + operator.and_:3, + operator.or_:2, + ColumnOperators.comma_op:-1, + Operators.as_:-1, + Operators.exists:0, + _smallest: -1000, + _largest: 1000 +} + +class _CompareMixin(ColumnOperators): + """Defines comparison and math operations for ``ClauseElement`` instances.""" + + def __compare(self, op, obj, negate=None): + if obj is None or isinstance(obj, _Null): + if op == operator.eq: + return _BinaryExpression(self.expression_element(), null(), Operators.is_, negate=Operators.isnot) + elif op == operator.ne: + return _BinaryExpression(self.expression_element(), null(), Operators.isnot, negate=Operators.is_) + else: + raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") + else: + obj = self._check_literal(obj) + + + return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate) + + def __operate(self, op, obj): + obj = self._check_literal(obj) + + type_ = self._compare_type(obj) + + # TODO: generalize operator overloading like this out into the types module + if op == operator.add and isinstance(type_, (sqltypes.Concatenable)): + op = ColumnOperators.concat_op + + return _BinaryExpression(self.expression_element(), obj, op, type_=type_) + + operators = { + operator.add : (__operate,), + operator.mul : (__operate,), + operator.sub : (__operate,), + operator.div : (__operate,), + operator.mod : (__operate,), + operator.truediv : (__operate,), + operator.lt : (__compare, operator.ge), + operator.le : (__compare, operator.gt), + operator.ne : (__compare, operator.eq), + operator.gt : (__compare, operator.le), + operator.ge : (__compare, operator.lt), + operator.eq : (__compare, operator.ne), + ColumnOperators.like_op : (__compare, ColumnOperators.notlike_op), + } + + def operate(self, op, other): + o = _CompareMixin.operators[op] + return o[0](self, op, other, *o[1:]) + + def reverse_operate(self, op, other): + return self._bind_param(other).operate(op, self) def in_(self, *other): - """produce an ``IN`` clause.""" + return self._in_impl(ColumnOperators.in_op, ColumnOperators.notin_op, *other) + + def _in_impl(self, op, negate_op, *other): if len(other) == 0: return _Grouping(case([(self.__eq__(None), text('NULL'))], else_=text('0')).__eq__(text('1'))) elif len(other) == 1: @@ -1318,8 +1366,8 @@ class _CompareMixin(object): if _is_literal(o) or isinstance( o, _CompareMixin): return self.__eq__( o) #single item -> == else: - assert hasattr( o, '_selectable') #better check? - return self._compare( 'IN', o, negate='NOT IN') #single selectable + assert isinstance(o, Selectable) + return self.__compare( op, o, negate=negate_op) #single selectable args = [] for o in other: @@ -1329,29 +1377,22 @@ class _CompareMixin(object): else: o = self._bind_param(o) args.append(o) - return self._compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN') + return self.__compare(op, ClauseList(*args).self_group(against=op), negate=negate_op) def startswith(self, other): """produce the clause ``LIKE '<other>%'``""" - perc = isinstance(other,(str,unicode)) and '%' or literal('%',type= sqltypes.String) - return self._compare('LIKE', other + perc) + + perc = isinstance(other,(str,unicode)) and '%' or literal('%',type_= sqltypes.String) + return self.__compare(ColumnOperators.like_op, other + perc) def endswith(self, other): """produce the clause ``LIKE '%<other>'``""" + if isinstance(other,(str,unicode)): po = '%' + other else: - po = literal('%', type= sqltypes.String) + other - po.type = sqltypes.to_instance( sqltypes.String) #force! - return self._compare('LIKE', po) - - def __radd__(self, other): - return self._bind_param(other)._operate('+', self) - def __rsub__(self, other): - return self._bind_param(other)._operate('-', self) - def __rmul__(self, other): - return self._bind_param(other)._operate('*', self) - def __rdiv__(self, other): - return self._bind_param(other)._operate('/', self) + po = literal('%', type_=sqltypes.String) + other + po.type = sqltypes.to_instance(sqltypes.String) #force! + return self.__compare(ColumnOperators.like_op, po) def label(self, name): """produce a column label, i.e. ``<columnname> AS <name>``""" @@ -1363,7 +1404,8 @@ class _CompareMixin(object): def between(self, cleft, cright): """produce a BETWEEN clause, i.e. ``<column> BETWEEN <cleft> AND <cright>``""" - return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator='AND', group=False), 'BETWEEN') + + return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operator.and_, group=False), ColumnOperators.between_op) def op(self, operator): """produce a generic operator function. @@ -1382,59 +1424,25 @@ class _CompareMixin(object): passed to the generated function. """ - return lambda other: self._operate(operator, other) - - # and here come the math operators: - - def __add__(self, other): - return self._operate('+', other) - - def __sub__(self, other): - return self._operate('-', other) - - def __mul__(self, other): - return self._operate('*', other) - - def __div__(self, other): - return self._operate('/', other) - - def __mod__(self, other): - return self._operate('%', other) - - def __truediv__(self, other): - return self._operate('/', other) + return lambda other: self.__operate(operator, other) def _bind_param(self, obj): - return _BindParamClause('literal', obj, shortname=None, type=self.type, unique=True) + return _BindParamClause('literal', obj, shortname=None, type_=self.type, unique=True) def _check_literal(self, other): - if _is_literal(other): + if isinstance(other, Operators): + return other.expression_element() + elif _is_literal(other): return self._bind_param(other) else: return other + + def clause_element(self): + """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``.""" + return self - def _compare(self, operator, obj, negate=None): - if obj is None or isinstance(obj, _Null): - if operator == '=': - return _BinaryExpression(self._compare_self(), null(), 'IS', negate='IS NOT') - elif operator == '!=': - return _BinaryExpression(self._compare_self(), null(), 'IS NOT', negate='IS') - else: - raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") - else: - obj = self._check_literal(obj) - - return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj), negate=negate) - - def _operate(self, operator, obj): - if _is_literal(obj): - obj = self._bind_param(obj) - return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj)) - - def _compare_self(self): - """Allow ``ColumnImpl`` to return its ``Column`` object for - usage in ``ClauseElements``, all others to just return self. - """ + def expression_element(self): + """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions.""" return self @@ -1460,23 +1468,10 @@ class Selectable(ClauseElement): columns = util.NotImplProperty("""a [sqlalchemy.sql#ColumnCollection] containing ``ColumnElement`` instances.""") - def _selectable(self): - return self - - def accept_visitor(self, visitor): - raise NotImplementedError(repr(self)) - def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) - def _group_parenthesized(self): - """Indicate if this ``Selectable`` requires parenthesis when - grouped into a compound statement. - """ - return True - - class ColumnElement(Selectable, _CompareMixin): """Represent an element that is useable within the "column clause" portion of a ``SELECT`` statement. @@ -1616,8 +1611,10 @@ class ColumnCollection(util.OrderedProperties): l.append(c==local) return and_(*l) - def __contains__(self, col): - return self.contains_column(col) + def __contains__(self, other): + if not isinstance(other, basestring): + raise exceptions.ArgumentError("__contains__ requires a string argument") + return self.has_key(other) def contains_column(self, col): # have to use a Set here, because it will compare the identity @@ -1649,19 +1646,18 @@ class FromClause(Selectable): clause of a ``SELECT`` statement. """ + __visit_name__ = 'fromclause' + def __init__(self, name=None): self.name = name - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): # this could also be [self], at the moment it doesnt matter to the Select object return [] def default_order_by(self): return [self.oid_column] - def accept_visitor(self, visitor): - visitor.visit_fromclause(self) - def count(self, whereclause=None, **params): if len(self.primary_key): col = list(self.primary_key)[0] @@ -1703,6 +1699,13 @@ class FromClause(Selectable): FindCols().traverse(self) return ret + def is_derived_from(self, fromclause): + """return True if this FromClause is 'derived' from the given FromClause. + + An example would be an Alias of a Table is derived from that Table.""" + + return False + def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False): """Given a ``ColumnElement``, return the exported ``ColumnElement`` object from this ``Selectable`` which @@ -1730,9 +1733,10 @@ class FromClause(Selectable): it merely shares a common anscestor with one of the exported columns of this ``FromClause``. """ - if column in self.c: + + if self.c.contains_column(column): return column - + if require_embedded and column not in util.Set(self._get_all_embedded_columns()): if not raiseerr: return None @@ -1761,6 +1765,15 @@ class FromClause(Selectable): self._export_columns() return getattr(self, name) + def _clone_from_clause(self): + # delete all the "generated" collections of columns for a newly cloned FromClause, + # so that they will be re-derived from the item. + # this is because FromClause subclasses, when cloned, need to reestablish new "proxied" + # columns that are linked to the new item + for attr in ('_columns', '_primary_key' '_foreign_keys', '_orig_cols', '_oid_column'): + if hasattr(self, attr): + delattr(self, attr) + columns = property(lambda s:s._get_exported_attribute('_columns')) c = property(lambda s:s._get_exported_attribute('_columns')) primary_key = property(lambda s:s._get_exported_attribute('_primary_key')) @@ -1791,8 +1804,9 @@ class FromClause(Selectable): self._primary_key = ColumnSet() self._foreign_keys = util.Set() self._orig_cols = {} + if columns is None: - columns = self._adjusted_exportable_columns() + columns = self._flatten_exportable_columns() for co in columns: cp = self._proxy_column(co) for ci in cp.orig_set: @@ -1806,13 +1820,14 @@ class FromClause(Selectable): for ci in self.oid_column.orig_set: self._orig_cols[ci] = self.oid_column - def _adjusted_exportable_columns(self): + def _flatten_exportable_columns(self): """return the list of ColumnElements represented within this FromClause's _exportable_columns""" export = self._exportable_columns() for column in export: - try: - s = column._selectable() - except AttributeError: + # TODO: is this conditional needed ? + if isinstance(column, Selectable): + s = column + else: continue for co in s.columns: yield co @@ -1829,7 +1844,9 @@ class _BindParamClause(ClauseElement, _CompareMixin): Public constructor is the ``bindparam()`` function. """ - def __init__(self, key, value, shortname=None, type=None, unique=False): + __visit_name__ = 'bindparam' + + def __init__(self, key, value, shortname=None, type_=None, unique=False, isoutparam=False): """Construct a _BindParamClause. key @@ -1852,7 +1869,7 @@ class _BindParamClause(ClauseElement, _CompareMixin): execution may match either the key or the shortname of the corresponding ``_BindParamClause`` objects. - type + type\_ A ``TypeEngine`` object that will be used to pre-process the value corresponding to this ``_BindParamClause`` at execution time. @@ -1862,23 +1879,34 @@ class _BindParamClause(ClauseElement, _CompareMixin): modified if another ``_BindParamClause`` of the same name already has been located within the containing ``ClauseElement``. + + isoutparam + if True, the parameter should be treated like a stored procedure "OUT" + parameter. """ - self.key = key + self.key = key or "{ANON %d param}" % id(self) self.value = value self.shortname = shortname or key self.unique = unique - self.type = sqltypes.to_instance(type) - - def accept_visitor(self, visitor): - visitor.visit_bindparam(self) - - def _get_from_objects(self): + self.isoutparam = isoutparam + type_ = sqltypes.to_instance(type_) + if isinstance(type_, sqltypes.NullType) and type(value) in _BindParamClause.type_map: + self.type = sqltypes.to_instance(_BindParamClause.type_map[type(value)]) + else: + self.type = type_ + + # TODO: move to types module, obviously + type_map = { + str : sqltypes.String, + unicode : sqltypes.Unicode, + int : sqltypes.Integer, + float : sqltypes.Numeric + } + + def _get_from_objects(self, **modifiers): return [] - def copy_container(self): - return _BindParamClause(self.key, self.value, self.shortname, self.type, unique=self.unique) - def typeprocess(self, value, dialect): return self.type.dialect_impl(dialect).convert_bind_param(value, dialect) @@ -1893,7 +1921,7 @@ class _BindParamClause(ClauseElement, _CompareMixin): return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__ def __repr__(self): - return "_BindParamClause(%s, %s, type=%s)" % (repr(self.key), repr(self.value), repr(self.type)) + return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type)) class _TypeClause(ClauseElement): """Handle a type keyword in a SQL statement. @@ -1901,13 +1929,12 @@ class _TypeClause(ClauseElement): Used by the ``Case`` statement. """ + __visit_name__ = 'typeclause' + def __init__(self, type): self.type = type - def accept_visitor(self, visitor): - visitor.visit_typeclause(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): return [] class _TextClause(ClauseElement): @@ -1916,8 +1943,10 @@ class _TextClause(ClauseElement): Public constructor is the ``text()`` function. """ - def __init__(self, text = "", bind=None, engine=None, bindparams=None, typemap=None): - self._bind = bind or engine + __visit_name__ = 'textclause' + + def __init__(self, text = "", bind=None, bindparams=None, typemap=None): + self._bind = bind self.bindparams = {} self.typemap = typemap if typemap is not None: @@ -1930,7 +1959,7 @@ class _TextClause(ClauseElement): # scan the string and search for bind parameter names, add them # to the list of bindparams - self.text = re.compile(r'(?<!:):([\w_]+)', re.S).sub(repl, text) + self.text = BIND_PARAMS.sub(repl, text) if bindparams is not None: for b in bindparams: self.bindparams[b.key] = b @@ -1944,13 +1973,13 @@ class _TextClause(ClauseElement): columns = property(lambda s:[]) + def _copy_internals(self): + self.bindparams = [b._clone() for b in self.bindparams] + def get_children(self, **kwargs): return self.bindparams.values() - def accept_visitor(self, visitor): - visitor.visit_textclause(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): return [] def supports_execution(self): @@ -1965,10 +1994,7 @@ class _Null(ColumnElement): def __init__(self): self.type = sqltypes.NULLTYPE - def accept_visitor(self, visitor): - visitor.visit_null(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): return [] class ClauseList(ClauseElement): @@ -1976,14 +2002,16 @@ class ClauseList(ClauseElement): By default, is comma-separated, such as a column listing. """ - + __visit_name__ = 'clauselist' + def __init__(self, *clauses, **kwargs): self.clauses = [] - self.operator = kwargs.pop('operator', ',') + self.operator = kwargs.pop('operator', ColumnOperators.comma_op) self.group = kwargs.pop('group', True) self.group_contents = kwargs.pop('group_contents', True) for c in clauses: - if c is None: continue + if c is None: + continue self.append(c) def __iter__(self): @@ -1991,32 +2019,28 @@ class ClauseList(ClauseElement): def __len__(self): return len(self.clauses) - def copy_container(self): - clauses = [clause.copy_container() for clause in self.clauses] - return ClauseList(operator=self.operator, *clauses) - def append(self, clause): # TODO: not sure if i like the 'group_contents' flag. need to define the difference between # a ClauseList of ClauseLists, and a "flattened" ClauseList of ClauseLists. flatten() method ? if self.group_contents: - self.clauses.append(_literals_as_text(clause).self_group(against=self.operator)) + self.clauses.append(_literal_as_text(clause).self_group(against=self.operator)) else: - self.clauses.append(_literals_as_text(clause)) + self.clauses.append(_literal_as_text(clause)) + + def _copy_internals(self): + self.clauses = [clause._clone() for clause in self.clauses] def get_children(self, **kwargs): return self.clauses - def accept_visitor(self, visitor): - visitor.visit_clauselist(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): f = [] for c in self.clauses: - f += c._get_from_objects() + f += c._get_from_objects(**modifiers) return f def self_group(self, against=None): - if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']): + if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]): return _Grouping(self) else: return self @@ -2043,40 +2067,45 @@ class _CalculatedClause(ColumnElement): Extends ``ColumnElement`` to provide column-level comparison operators. """ - + __visit_name__ = 'calculatedclause' + def __init__(self, name, *clauses, **kwargs): self.name = name - self.type = sqltypes.to_instance(kwargs.get('type', None)) - self._bind = kwargs.get('bind', kwargs.get('engine', None)) + self.type = sqltypes.to_instance(kwargs.get('type_', None)) + self._bind = kwargs.get('bind', None) self.group = kwargs.pop('group', True) - self.clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses) + clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses) if self.group: - self.clause_expr = self.clauses.self_group() + self.clause_expr = clauses.self_group() else: - self.clause_expr = self.clauses + self.clause_expr = clauses key = property(lambda self:self.name or "_calc_") - def copy_container(self): - clauses = [clause.copy_container() for clause in self.clauses] - return _CalculatedClause(type=self.type, bind=self._bind, *clauses) - + def _copy_internals(self): + self.clause_expr = self.clause_expr._clone() + + def clauses(self): + if isinstance(self.clause_expr, _Grouping): + return self.clause_expr.elem + else: + return self.clause_expr + clauses = property(clauses) + def get_children(self, **kwargs): return self.clause_expr, - def accept_visitor(self, visitor): - visitor.visit_calculatedclause(self) - def _get_from_objects(self): - return self.clauses._get_from_objects() + def _get_from_objects(self, **modifiers): + return self.clauses._get_from_objects(**modifiers) def _bind_param(self, obj): - return _BindParamClause(self.name, obj, type=self.type, unique=True) + return _BindParamClause(self.name, obj, type_=self.type, unique=True) def select(self): return select([self]) def scalar(self): - return select([self]).scalar() + return select([self]).execute().scalar() def execute(self): return select([self]).execute() @@ -2092,28 +2121,26 @@ class _Function(_CalculatedClause, FromClause): """ def __init__(self, name, *clauses, **kwargs): - self.type = sqltypes.to_instance(kwargs.get('type', None)) self.packagenames = kwargs.get('packagenames', None) or [] - kwargs['operator'] = ',' - self._engine = kwargs.get('engine', None) + kwargs['operator'] = ColumnOperators.comma_op _CalculatedClause.__init__(self, name, **kwargs) for c in clauses: self.append(c) key = property(lambda self:self.name) + def _copy_internals(self): + _CalculatedClause._copy_internals(self) + self._clone_from_clause() - def append(self, clause): - self.clauses.append(_literals_as_binds(clause, self.name)) - - def copy_container(self): - clauses = [clause.copy_container() for clause in self.clauses] - return _Function(self.name, type=self.type, packagenames=self.packagenames, bind=self._bind, *clauses) + def get_children(self, **kwargs): + return _CalculatedClause.get_children(self, **kwargs) - def accept_visitor(self, visitor): - visitor.visit_function(self) + def append(self, clause): + self.clauses.append(_literal_as_binds(clause, self.name)) class _Cast(ColumnElement): + def __init__(self, clause, totype, **kwargs): if not hasattr(clause, 'label'): clause = literal(clause) @@ -2122,17 +2149,19 @@ class _Cast(ColumnElement): self.typeclause = _TypeClause(self.type) self._distance = 0 + def _copy_internals(self): + self.clause = self.clause._clone() + self.typeclause = self.typeclause._clone() + def get_children(self, **kwargs): return self.clause, self.typeclause - def accept_visitor(self, visitor): - visitor.visit_cast(self) - def _get_from_objects(self): - return self.clause._get_from_objects() + def _get_from_objects(self, **modifiers): + return self.clause._get_from_objects(**modifiers) def _make_proxy(self, selectable, name=None): if name is not None: - co = _ColumnClause(name, selectable, type=self.type) + co = _ColumnClause(name, selectable, type_=self.type) co._distance = self._distance + 1 co.orig_set = self.orig_set selectable.columns[name]= co @@ -2142,26 +2171,23 @@ class _Cast(ColumnElement): class _UnaryExpression(ColumnElement): - def __init__(self, element, operator=None, modifier=None, type=None, negate=None): + def __init__(self, element, operator=None, modifier=None, type_=None, negate=None): self.operator = operator self.modifier = modifier - self.element = _literals_as_text(element).self_group(against=self.operator or self.modifier) - self.type = sqltypes.to_instance(type) + self.element = _literal_as_text(element).self_group(against=self.operator or self.modifier) + self.type = sqltypes.to_instance(type_) self.negate = negate - def copy_container(self): - return self.__class__(self.element.copy_container(), operator=self.operator, modifier=self.modifier, type=self.type, negate=self.negate) + def _get_from_objects(self, **modifiers): + return self.element._get_from_objects(**modifiers) - def _get_from_objects(self): - return self.element._get_from_objects() + def _copy_internals(self): + self.element = self.element._clone() def get_children(self, **kwargs): return self.element, - def accept_visitor(self, visitor): - visitor.visit_unary(self) - def compare(self, other): """Compare this ``_UnaryExpression`` against the given ``ClauseElement``.""" @@ -2170,14 +2196,15 @@ class _UnaryExpression(ColumnElement): self.modifier == other.modifier and self.element.compare(other.element) ) + def _negate(self): if self.negate is not None: - return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type) + return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type_=self.type) else: return super(_UnaryExpression, self)._negate() def self_group(self, against): - if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']): + if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]): return _Grouping(self) else: return self @@ -2186,25 +2213,23 @@ class _UnaryExpression(ColumnElement): class _BinaryExpression(ColumnElement): """Represent an expression that is ``LEFT <operator> RIGHT``.""" - def __init__(self, left, right, operator, type=None, negate=None): - self.left = _literals_as_text(left).self_group(against=operator) - self.right = _literals_as_text(right).self_group(against=operator) + def __init__(self, left, right, operator, type_=None, negate=None): + self.left = _literal_as_text(left).self_group(against=operator) + self.right = _literal_as_text(right).self_group(against=operator) self.operator = operator - self.type = sqltypes.to_instance(type) + self.type = sqltypes.to_instance(type_) self.negate = negate - def copy_container(self): - return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator) + def _get_from_objects(self, **modifiers): + return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) - def _get_from_objects(self): - return self.left._get_from_objects() + self.right._get_from_objects() + def _copy_internals(self): + self.left = self.left._clone() + self.right = self.right._clone() def get_children(self, **kwargs): return self.left, self.right - def accept_visitor(self, visitor): - visitor.visit_binary(self) - def compare(self, other): """Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``.""" @@ -2213,7 +2238,7 @@ class _BinaryExpression(ColumnElement): ( self.left.compare(other.left) and self.right.compare(other.right) or ( - self.operator in ['=', '!=', '+', '*'] and + self.operator in [operator.eq, operator.ne, operator.add, operator.mul] and self.left.compare(other.right) and self.right.compare(other.left) ) ) @@ -2221,25 +2246,27 @@ class _BinaryExpression(ColumnElement): def self_group(self, against=None): # use small/large defaults for comparison so that unknown operators are always parenthesized - if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest'])): + if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest])): return _Grouping(self) else: return self def _negate(self): if self.negate is not None: - return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type) + return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type_=self.type) else: return super(_BinaryExpression, self)._negate() class _Exists(_UnaryExpression): + __visit_name__ = _UnaryExpression.__visit_name__ + def __init__(self, *args, **kwargs): kwargs['correlate'] = True s = select(*args, **kwargs).self_group() - _UnaryExpression.__init__(self, s, operator="EXISTS") + _UnaryExpression.__init__(self, s, operator=Operators.exists) - def _hide_froms(self): - return self._get_from_objects() + def _hide_froms(self, **modifiers): + return self._get_from_objects(**modifiers) class Join(FromClause): """represent a ``JOIN`` construct between two ``FromClause`` @@ -2251,8 +2278,8 @@ class Join(FromClause): """ def __init__(self, left, right, onclause=None, isouter = False): - self.left = left._selectable() - self.right = right._selectable() + self.left = _selectable(left) + self.right = _selectable(right).self_group() if onclause is None: self.onclause = self._match_primaries(self.left, self.right) else: @@ -2265,8 +2292,8 @@ class Join(FromClause): encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace')) def _init_primary_key(self): - pkcol = util.Set([c for c in self._adjusted_exportable_columns() if c.primary_key]) - + pkcol = util.Set([c for c in self._flatten_exportable_columns() if c.primary_key]) + equivs = {} def add_equiv(a, b): for x, y in ((a, b), (b, a)): @@ -2277,7 +2304,7 @@ class Join(FromClause): class BinaryVisitor(ClauseVisitor): def visit_binary(self, binary): - if binary.operator == '=': + if binary.operator == operator.eq: add_equiv(binary.left, binary.right) BinaryVisitor().traverse(self.onclause) @@ -2294,9 +2321,12 @@ class Join(FromClause): omit.add(p) p = c - self.__primary_key = ColumnSet([c for c in self._adjusted_exportable_columns() if c.primary_key and c not in omit]) + self.__primary_key = ColumnSet([c for c in self._flatten_exportable_columns() if c.primary_key and c not in omit]) primary_key = property(lambda s:s.__primary_key) + + def self_group(self, against=None): + return _Grouping(self) def _locate_oid_column(self): return self.left.oid_column @@ -2310,6 +2340,17 @@ class Join(FromClause): self._foreign_keys.add(f) return column + def _copy_internals(self): + self._clone_from_clause() + self.left = self.left._clone() + self.right = self.right._clone() + self.onclause = self.onclause._clone() + self.__folded_equivalents = None + self._init_primary_key() + + def get_children(self, **kwargs): + return self.left, self.right, self.onclause + def _match_primaries(self, primary, secondary): crit = [] constraints = util.Set() @@ -2338,9 +2379,6 @@ class Join(FromClause): else: return and_(*crit) - def _group_parenthesized(self): - return True - def _get_folded_equivalents(self, equivs=None): if self.__folded_equivalents is not None: return self.__folded_equivalents @@ -2348,7 +2386,7 @@ class Join(FromClause): equivs = util.Set() class LocateEquivs(NoColumnVisitor): def visit_binary(self, binary): - if binary.operator == '=' and binary.left.name == binary.right.name: + if binary.operator == operator.eq and binary.left.name == binary.right.name: equivs.add(binary.right) equivs.add(binary.left) LocateEquivs().traverse(self.onclause) @@ -2401,13 +2439,7 @@ class Join(FromClause): return select(collist, whereclause, from_obj=[self], **kwargs) - def get_children(self, **kwargs): - return self.left, self.right, self.onclause - - def accept_visitor(self, visitor): - visitor.visit_join(self) - - engine = property(lambda s:s.left.engine or s.right.engine) + bind = property(lambda s:s.left.bind or s.right.bind) def alias(self, name=None): """Create a ``Select`` out of this ``Join`` clause and return an ``Alias`` of it. @@ -2417,11 +2449,11 @@ class Join(FromClause): return self.select(use_labels=True, correlate=False).alias(name) - def _hide_froms(self): - return self.left._get_from_objects() + self.right._get_from_objects() + def _hide_froms(self, **modifiers): + return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) - def _get_from_objects(self): - return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects() + def _get_from_objects(self, **modifiers): + return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) class Alias(FromClause): """represent an alias, as typically applied to any @@ -2443,15 +2475,22 @@ class Alias(FromClause): if alias is None: if self.original.named_with_column(): alias = getattr(self.original, 'name', None) - if alias is None: - alias = 'anon' - elif len(alias) > 15: - alias = alias[0:15] - alias = alias + "_" + hex(random.randint(0, 65535))[2:] + alias = '{ANON %d %s}' % (id(self), alias or 'anon') self.name = alias self.encodedname = alias.encode('ascii', 'backslashreplace') self.case_sensitive = getattr(baseselectable, "case_sensitive", True) + def is_derived_from(self, fromclause): + x = self.selectable + while True: + if x is fromclause: + return True + if isinstance(x, Alias): + x = x.selectable + else: + break + return False + def supports_execution(self): return self.original.supports_execution() @@ -2468,46 +2507,57 @@ class Alias(FromClause): #return self.selectable._exportable_columns() return self.selectable.columns + def _copy_internals(self): + self._clone_from_clause() + self.selectable = self.selectable._clone() + baseselectable = self.selectable + while isinstance(baseselectable, Alias): + baseselectable = baseselectable.selectable + self.original = baseselectable + def get_children(self, **kwargs): for c in self.c: yield c yield self.selectable - def accept_visitor(self, visitor): - visitor.visit_alias(self) - def _get_from_objects(self): return [self] - def _group_parenthesized(self): - return False - bind = property(lambda s: s.selectable.bind) - engine = bind -class _Grouping(ColumnElement): +class _ColumnElementAdapter(ColumnElement): + """adapts a ClauseElement which may or may not be a + ColumnElement subclass itself into an object which + acts like a ColumnElement. + """ + def __init__(self, elem): self.elem = elem self.type = getattr(elem, 'type', None) - + self.orig_set = getattr(elem, 'orig_set', util.Set()) + key = property(lambda s: s.elem.key) _label = property(lambda s: s.elem._label) - orig_set = property(lambda s:s.elem.orig_set) - - def copy_container(self): - return _Grouping(self.elem.copy_container()) - - def accept_visitor(self, visitor): - visitor.visit_grouping(self) + columns = c = property(lambda s:s.elem.columns) + + def _copy_internals(self): + self.elem = self.elem._clone() + def get_children(self, **kwargs): return self.elem, - def _hide_froms(self): - return self.elem._hide_froms() - def _get_from_objects(self): - return self.elem._get_from_objects() + + def _hide_froms(self, **modifiers): + return self.elem._hide_froms(**modifiers) + + def _get_from_objects(self, **modifiers): + return self.elem._get_from_objects(**modifiers) + def __getattr__(self, attr): return getattr(self.elem, attr) - + +class _Grouping(_ColumnElementAdapter): + pass + class _Label(ColumnElement): """represent a label, as typically applied to any column-level element using the ``AS`` sql keyword. @@ -2518,32 +2568,33 @@ class _Label(ColumnElement): """ - def __init__(self, name, obj, type=None): - self.name = name + def __init__(self, name, obj, type_=None): while isinstance(obj, _Label): obj = obj.obj - self.obj = obj.self_group(against='AS') + self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon')) + + self.obj = obj.self_group(against=Operators.as_) self.case_sensitive = getattr(obj, "case_sensitive", True) - self.type = sqltypes.to_instance(type or getattr(obj, 'type', None)) + self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None)) key = property(lambda s: s.name) _label = property(lambda s: s.name) orig_set = property(lambda s:s.obj.orig_set) - def _compare_self(self): + def expression_element(self): return self.obj - + + def _copy_internals(self): + self.obj = self.obj._clone() + def get_children(self, **kwargs): return self.obj, - def accept_visitor(self, visitor): - visitor.visit_label(self) + def _get_from_objects(self, **modifiers): + return self.obj._get_from_objects(**modifiers) - def _get_from_objects(self): - return self.obj._get_from_objects() - - def _hide_froms(self): - return self.obj._hide_froms() + def _hide_froms(self, **modifiers): + return self.obj._hide_froms(**modifiers) def _make_proxy(self, selectable, name = None): if isinstance(self.obj, Selectable): @@ -2551,8 +2602,6 @@ class _Label(ColumnElement): else: return column(self.name)._make_proxy(selectable=selectable) -legal_characters = util.Set(string.ascii_letters + string.digits + '_') - class _ColumnClause(ColumnElement): """Represents a generic column expression from any textual string. This includes columns associated with tables, aliases and select @@ -2584,17 +2633,21 @@ class _ColumnClause(ColumnElement): """ - def __init__(self, text, selectable=None, type=None, _is_oid=False, case_sensitive=True, is_literal=False): + def __init__(self, text, selectable=None, type_=None, _is_oid=False, case_sensitive=True, is_literal=False): self.key = self.name = text self.encodedname = isinstance(self.name, unicode) and self.name.encode('ascii', 'backslashreplace') or self.name self.table = selectable - self.type = sqltypes.to_instance(type) + self.type = sqltypes.to_instance(type_) self._is_oid = _is_oid self._distance = 0 self.__label = None self.case_sensitive = case_sensitive self.is_literal = is_literal - + + def _clone(self): + # ColumnClause is immutable + return self + def _get_label(self): """Generate a 'label' for this column. @@ -2617,7 +2670,6 @@ class _ColumnClause(ColumnElement): counter += 1 else: self.__label = self.name - self.__label = "".join([x for x in self.__label if x in legal_characters]) return self.__label is_labeled = property(lambda self:self.name != list(self.orig_set)[0].name) @@ -2632,23 +2684,20 @@ class _ColumnClause(ColumnElement): else: return super(_ColumnClause, self).label(name) - def accept_visitor(self, visitor): - visitor.visit_column(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): if self.table is not None: return [self.table] else: return [] def _bind_param(self, obj): - return _BindParamClause(self._label, obj, shortname = self.name, type=self.type, unique=True) + return _BindParamClause(self._label, obj, shortname=self.name, type_=self.type, unique=True) def _make_proxy(self, selectable, name = None): # propigate the "is_literal" flag only if we are keeping our name, # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) - c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type=self.type, is_literal=is_literal) + c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal) c.orig_set = self.orig_set c._distance = self._distance + 1 if not self._is_oid: @@ -2658,9 +2707,6 @@ class _ColumnClause(ColumnElement): def _compare_type(self, obj): return self.type - def _group_parenthesized(self): - return False - class TableClause(FromClause): """represents a "table" construct. @@ -2677,6 +2723,10 @@ class TableClause(FromClause): self._oid_column = _ColumnClause('oid', self, _is_oid=True) self._export_columns(columns) + def _clone(self): + # TableClause is immutable + return self + def named_with_column(self): return True @@ -2709,15 +2759,9 @@ class TableClause(FromClause): else: return [] - def accept_visitor(self, visitor): - visitor.visit_table(self) - def _exportable_columns(self): raise NotImplementedError() - def _group_parenthesized(self): - return False - def count(self, whereclause=None, **params): if len(self.primary_key): col = list(self.primary_key)[0] @@ -2746,68 +2790,120 @@ class TableClause(FromClause): def delete(self, whereclause = None): return delete(self, whereclause) - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): return [self] + class _SelectBaseMixin(object): """Base class for ``Select`` and ``CompoundSelects``.""" + def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None): + self.use_labels = use_labels + self.for_update = for_update + self._limit = limit + self._offset = offset + self._bind = bind + + self.append_order_by(*util.to_list(order_by, [])) + self.append_group_by(*util.to_list(group_by, [])) + + def as_scalar(self): + return _ScalarSelect(self) + + def label(self, name): + return self.as_scalar().label(name) + def supports_execution(self): return True + def _generate(self): + s = self._clone() + s._clone_from_clause() + return s + + def limit(self, limit): + s = self._generate() + s._limit = limit + return s + + def offset(self, offset): + s = self._generate() + s._offset = offset + return s + def order_by(self, *clauses): - if len(clauses) == 1 and clauses[0] is None: - self.order_by_clause = ClauseList() - elif getattr(self, 'order_by_clause', None): - self.order_by_clause = ClauseList(*(list(self.order_by_clause.clauses) + list(clauses))) - else: - self.order_by_clause = ClauseList(*clauses) + s = self._generate() + s.append_order_by(*clauses) + return s def group_by(self, *clauses): - if len(clauses) == 1 and clauses[0] is None: - self.group_by_clause = ClauseList() - elif getattr(self, 'group_by_clause', None): - self.group_by_clause = ClauseList(*(list(clauses)+list(self.group_by_clause.clauses))) + s = self._generate() + s.append_group_by(*clauses) + return s + + def append_order_by(self, *clauses): + if clauses == [None]: + self._order_by_clause = ClauseList() else: - self.group_by_clause = ClauseList(*clauses) + if getattr(self, '_order_by_clause', None): + clauses = list(self._order_by_clause) + list(clauses) + self._order_by_clause = ClauseList(*clauses) + def append_group_by(self, *clauses): + if clauses == [None]: + self._group_by_clause = ClauseList() + else: + if getattr(self, '_group_by_clause', None): + clauses = list(self._group_by_clause) + list(clauses) + self._group_by_clause = ClauseList(*clauses) + def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) - def _get_from_objects(self): - if self.is_where or self.is_scalar: + def _get_from_objects(self, is_where=False, **modifiers): + if is_where: return [] else: return [self] +class _ScalarSelect(_Grouping): + __visit_name__ = 'grouping' + + def __init__(self, elem): + super(_ScalarSelect, self).__init__(elem) + self.type = list(elem.inner_columns)[0].type + + columns = property(lambda self:[self]) + + def self_group(self, **kwargs): + return self + + def _make_proxy(self, selectable, name): + return list(self.inner_columns)[0]._make_proxy(selectable, name) + + def _get_from_objects(self, **modifiers): + return [] + class CompoundSelect(_SelectBaseMixin, FromClause): def __init__(self, keyword, *selects, **kwargs): - _SelectBaseMixin.__init__(self) + self._should_correlate = kwargs.pop('correlate', False) self.keyword = keyword - self.use_labels = kwargs.pop('use_labels', False) - self.should_correlate = kwargs.pop('correlate', False) - self.for_update = kwargs.pop('for_update', False) - self.nowait = kwargs.pop('nowait', False) - self.limit = kwargs.pop('limit', None) - self.offset = kwargs.pop('offset', None) - self.is_compound = True - self.is_where = False - self.is_scalar = False - self.is_subquery = False - - # unions group from left to right, so don't group first select - self.selects = [n and select.self_group(self) or select for n,select in enumerate(selects)] + self.selects = [] # some DBs do not like ORDER BY in the inner queries of a UNION, etc. - for s in selects: - s.order_by(None) + for n, s in enumerate(selects): + if len(s._order_by_clause): + s = s.order_by(None) + # unions group from left to right, so don't group first select + if n: + self.selects.append(s.self_group(self)) + else: + self.selects.append(s) - self.group_by(*kwargs.pop('group_by', [None])) - self.order_by(*kwargs.pop('order_by', [None])) - if len(kwargs): - raise TypeError("invalid keyword argument(s) for CompoundSelect: %s" % repr(kwargs.keys())) self._col_map = {} + _SelectBaseMixin.__init__(self, **kwargs) + name = property(lambda s:s.keyword + " statement") def self_group(self, against=None): @@ -2835,12 +2931,18 @@ class CompoundSelect(_SelectBaseMixin, FromClause): col.orig_set = colset return col + def _copy_internals(self): + self._clone_from_clause() + self._col_map = {} + self.selects = [s._clone() for s in self.selects] + for attr in ('_order_by_clause', '_group_by_clause'): + if getattr(self, attr) is not None: + setattr(self, attr, getattr(self, attr)._clone()) + def get_children(self, column_collections=True, **kwargs): return (column_collections and list(self.c) or []) + \ - [self.order_by_clause, self.group_by_clause] + list(self.selects) - def accept_visitor(self, visitor): - visitor.visit_compound_select(self) - + [self._order_by_clause, self._group_by_clause] + list(self.selects) + def _find_engine(self): for s in self.selects: e = s._find_engine() @@ -2855,160 +2957,287 @@ class Select(_SelectBaseMixin, FromClause): """ - def __init__(self, columns=None, whereclause=None, from_obj=[], - order_by=None, group_by=None, having=None, - use_labels=False, distinct=False, for_update=False, - engine=None, bind=None, limit=None, offset=None, scalar=False, - correlate=True): + def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, prefixes=None, **kwargs): """construct a Select object. The public constructor for Select is the [sqlalchemy.sql#select()] function; see that function for argument descriptions. """ - _SelectBaseMixin.__init__(self) - self.__froms = util.OrderedSet() - self.__hide_froms = util.Set([self]) - self.use_labels = use_labels - self.whereclause = None - self.having = None - self._bind = bind or engine - self.limit = limit - self.offset = offset - self.for_update = for_update - self.is_compound = False - - # indicates that this select statement should not expand its columns - # into the column clause of an enclosing select, and should instead - # act like a single scalar column - self.is_scalar = scalar - if scalar: - # allow corresponding_column to return None - self.orig_set = util.Set() - - # indicates if this select statement, as a subquery, should automatically correlate - # its FROM clause to that of an enclosing select, update, or delete statement. - # note that the "correlate" method can be used to explicitly add a value to be correlated. - self.should_correlate = correlate - - # indicates if this select statement is a subquery inside another query - self.is_subquery = False - - # indicates if this select statement is in the from clause of another query - self.is_selected_from = False - - # indicates if this select statement is a subquery as a criterion - # inside of a WHERE clause - self.is_where = False + + self._should_correlate = correlate + self._distinct = distinct - self.distinct = distinct self._raw_columns = [] - self.__correlated = {} - self.__correlator = Select._CorrelatedVisitor(self, False) - self.__wherecorrelator = Select._CorrelatedVisitor(self, True) - self.__fromvisitor = Select._FromVisitor(self) - - - self.order_by_clause = self.group_by_clause = None + self.__correlate = util.Set() + self._froms = util.OrderedSet() + self._whereclause = None + self._having = None + self._prefixes = [] if columns is not None: for c in columns: self.append_column(c) - if order_by: - order_by = util.to_list(order_by) - if group_by: - group_by = util.to_list(group_by) - self.order_by(*(order_by or [None])) - self.group_by(*(group_by or [None])) - for c in self.order_by_clause: - self.__correlator.traverse(c) - for c in self.group_by_clause: - self.__correlator.traverse(c) - - for f in from_obj: - self.append_from(f) - - # whereclauses must be appended after the columns/FROM, since it affects - # the correlation of subqueries. see test/sql/select.py SelectTest.testwheresubquery + if from_obj is not None: + for f in from_obj: + self.append_from(f) + if whereclause is not None: self.append_whereclause(whereclause) + if having is not None: self.append_having(having) + _SelectBaseMixin.__init__(self, **kwargs) - class _CorrelatedVisitor(NoColumnVisitor): - """Visit a clause, locate any ``Select`` clauses, and tell - them that they should correlate their ``FROM`` list to that of - their parent. + def _get_display_froms(self, correlation_state=None): + """return the full list of 'from' clauses to be displayed. + + takes into account an optional 'correlation_state' + dictionary which contains information about this Select's + correlation to an enclosing select, which may cause some 'from' + clauses to not display in this Select's FROM clause. + this dictionary is generated during compile time by the + _calculate_correlations() method. + """ + froms = util.OrderedSet() + hide_froms = util.Set() + + for col in self._raw_columns: + for f in col._hide_froms(): + hide_froms.add(f) + for f in col._get_from_objects(): + froms.add(f) + + if self._whereclause is not None: + for f in self._whereclause._get_from_objects(is_where=True): + froms.add(f) + + for elem in self._froms: + froms.add(elem) + for f in elem._get_from_objects(): + froms.add(f) + + for elem in froms: + for f in elem._hide_froms(): + hide_froms.add(f) + + froms = froms.difference(hide_froms) + + if len(froms) > 1: + corr = self.__correlate + if correlation_state is not None: + corr = correlation_state[self].get('correlate', util.Set()).union(corr) + f = froms.difference(corr) + if len(f) == 0: + raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate())) + return f + else: + return froms + + froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""") + + def locate_all_froms(self): + froms = util.Set() + for col in self._raw_columns: + for f in col._get_from_objects(): + froms.add(f) + + if self._whereclause is not None: + for f in self._whereclause._get_from_objects(is_where=True): + froms.add(f) + + for elem in self._froms: + froms.add(elem) + for f in elem._get_from_objects(): + froms.add(f) + return froms + + def _calculate_correlations(self, correlation_state): + """generate a 'correlation_state' dictionary used by the _get_display_froms() method. + + The dictionary is passed in initially empty, or already + containing the state information added by an enclosing + Select construct. The method will traverse through all + embedded Select statements and add information about their + position and "from" objects to the dictionary. Those Select + statements will later consult the 'correlation_state' dictionary + when their list of 'FROM' clauses are generated using their + _get_display_froms() method. + """ + + if self not in correlation_state: + correlation_state[self] = {} - def __init__(self, select, is_where): - NoColumnVisitor.__init__(self) - self.select = select - self.is_where = is_where - - def visit_compound_select(self, cs): - self.visit_select(cs) - - def visit_column(self, c): - pass - - def visit_table(self, c): - pass - - def visit_select(self, select): - if select is self.select: - return - select.is_where = self.is_where - select.is_subquery = True - if not select.should_correlate: - return - [select.correlate(x) for x in self.select._Select__froms] + display_froms = self._get_display_froms(correlation_state) + + class CorrelatedVisitor(NoColumnVisitor): + def __init__(self, is_where=False, is_column=False, is_from=False): + self.is_where = is_where + self.is_column = is_column + self.is_from = is_from + + def visit_compound_select(self, cs): + self.visit_select(cs) - class _FromVisitor(NoColumnVisitor): - def __init__(self, select): - NoColumnVisitor.__init__(self) - self.select = select + def visit_select(s, select): + if select not in correlation_state: + correlation_state[select] = {} + + if select is self: + return + + select_state = correlation_state[select] + if s.is_from: + select_state['is_selected_from'] = True + if s.is_where: + select_state['is_where'] = True + select_state['is_subquery'] = True + + if select._should_correlate: + corr = select_state.setdefault('correlate', util.Set()) + # not crazy about this part. need to be clearer on what elements in the + # subquery correspond to elements in the enclosing query. + for f in display_froms: + corr.add(f) + for f2 in f._get_from_objects(): + corr.add(f2) + + col_vis = CorrelatedVisitor(is_column=True) + where_vis = CorrelatedVisitor(is_where=True) + from_vis = CorrelatedVisitor(is_from=True) + + for col in self._raw_columns: + col_vis.traverse(col) + for f in col._get_from_objects(): + if f is not self: + from_vis.traverse(f) + + for col in list(self._order_by_clause) + list(self._group_by_clause): + col_vis.traverse(col) - def visit_select(self, select): - if select is self.select: - return - select.is_selected_from = True - select.is_subquery = True + if self._whereclause is not None: + where_vis.traverse(self._whereclause) + for f in self._whereclause._get_from_objects(is_where=True): + if f is not self: + from_vis.traverse(f) + + for elem in self._froms: + from_vis.traverse(elem) + + def _get_inner_columns(self): + for c in self._raw_columns: + if isinstance(c, Selectable): + for co in c.columns: + yield co + else: + yield c + + inner_columns = property(_get_inner_columns) + + def _copy_internals(self): + self._clone_from_clause() + self._raw_columns = [c._clone() for c in self._raw_columns] + self._recorrelate_froms([(f, f._clone()) for f in self._froms]) + for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'): + if getattr(self, attr) is not None: + setattr(self, attr, getattr(self, attr)._clone()) + def get_children(self, column_collections=True, **kwargs): + return (column_collections and list(self.columns) or []) + \ + list(self._froms) + \ + [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] + + def _recorrelate_froms(self, froms): + newcorrelate = util.Set() + newfroms = util.Set() + oldfroms = util.Set(self._froms) + for old, new in froms: + if old in self.__correlate: + newcorrelate.add(new) + self.__correlate.remove(old) + if old in oldfroms: + newfroms.add(new) + oldfroms.remove(old) + self.__correlate = self.__correlate.union(newcorrelate) + self._froms = [f for f in oldfroms.union(newfroms)] + + def column(self, column): + s = self._generate() + s.append_column(column) + return s + + def where(self, whereclause): + s = self._generate() + s.append_whereclause(whereclause) + return s + + def having(self, having): + s = self._generate() + s.append_having(having) + return s + + def distinct(self): + s = self._generate() + s.distinct = True + return s + + def prefix_with(self, clause): + s = self._generate() + s.append_prefix(clause) + return s + + def select_from(self, fromclause): + s = self._generate() + s.append_from(fromclause) + return s + + def __dont_correlate(self): + s = self._generate() + s._should_correlate = False + return s + + def correlate(self, fromclause): + s = self._generate() + s._should_correlate=False + if fromclause is None: + s.__correlate = util.Set() + else: + s.append_correlation(fromclause) + return s + + def append_correlation(self, fromclause): + self.__correlate.add(fromclause) + def append_column(self, column): - if _is_literal(column): - column = literal_column(str(column)) + column = _literal_as_column(column) - if isinstance(column, Select) and column.is_scalar: - column = column.self_group(against=',') + if isinstance(column, _ScalarSelect): + column = column.self_group(against=ColumnOperators.comma_op) self._raw_columns.append(column) - - if self.is_scalar and not hasattr(self, 'type'): - self.type = column.type + + def append_prefix(self, clause): + clause = _literal_as_text(clause) + self._prefixes.append(clause) - # if the column is a Select statement itself, - # accept visitor - self.__correlator.traverse(column) - - # visit the FROM objects of the column looking for more Selects - for f in column._get_from_objects(): - if f is not self: - self.__correlator.traverse(f) - self._process_froms(column, False) - - def _make_proxy(self, selectable, name): - if self.is_scalar: - return self._raw_columns[0]._make_proxy(selectable, name) + def append_whereclause(self, whereclause): + if self._whereclause is not None: + self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) else: - raise exceptions.InvalidRequestError("Not a scalar select statement") - - def label(self, name): - if not self.is_scalar: - raise exceptions.InvalidRequestError("Not a scalar select statement") + self._whereclause = _literal_as_text(whereclause) + + def append_having(self, having): + if self._having is not None: + self._having = and_(self._having, _literal_as_text(having)) else: - return label(name, self) + self._having = _literal_as_text(having) + + def append_from(self, fromclause): + if _is_literal(fromclause): + fromclause = FromClause(fromclause) + self._froms.add(fromclause) def _exportable_columns(self): return [c for c in self._raw_columns if isinstance(c, Selectable)] @@ -3019,53 +3248,13 @@ class Select(_SelectBaseMixin, FromClause): else: return column._make_proxy(self) - def _process_froms(self, elem, asfrom): - for f in elem._get_from_objects(): - self.__fromvisitor.traverse(f) - self.__froms.add(f) - if asfrom: - self.__froms.add(elem) - for f in elem._hide_froms(): - self.__hide_froms.add(f) - def self_group(self, against=None): if isinstance(against, CompoundSelect): return self return _Grouping(self) - - def append_whereclause(self, whereclause): - self._append_condition('whereclause', whereclause) - - def append_having(self, having): - self._append_condition('having', having) - - def _append_condition(self, attribute, condition): - if isinstance(condition, basestring): - condition = _TextClause(condition) - self.__wherecorrelator.traverse(condition) - self._process_froms(condition, False) - if getattr(self, attribute) is not None: - setattr(self, attribute, and_(getattr(self, attribute), condition)) - else: - setattr(self, attribute, condition) - - def correlate(self, from_obj): - """Given a ``FROM`` object, correlate this ``SELECT`` statement to it. - - This basically means the given from object will not come out - in this select statement's ``FROM`` clause when printed. - """ - - self.__correlated[from_obj] = from_obj - - def append_from(self, fromclause): - if isinstance(fromclause, basestring): - fromclause = FromClause(fromclause) - self.__correlator.traverse(fromclause) - self._process_froms(fromclause, True) def _locate_oid_column(self): - for f in self.__froms: + for f in self.locate_all_froms(): if f is self: # we might be in our own _froms list if a column with us as the parent is attached, # which includes textual columns. @@ -3076,25 +3265,6 @@ class Select(_SelectBaseMixin, FromClause): else: return None - def _calc_froms(self): - f = self.__froms.difference(self.__hide_froms) - if (len(f) > 1): - return f.difference(self.__correlated) - else: - return f - - froms = property(_calc_froms, - doc="""A collection containing all elements - of the ``FROM`` clause.""") - - def get_children(self, column_collections=True, **kwargs): - return (column_collections and list(self.columns) or []) + \ - list(self.froms) + \ - [x for x in (self.whereclause, self.having, self.order_by_clause, self.group_by_clause) if x is not None] - - def accept_visitor(self, visitor): - visitor.visit_select(self) - def union(self, other, **kwargs): return union(self, other, **kwargs) @@ -3108,7 +3278,7 @@ class Select(_SelectBaseMixin, FromClause): if self._bind is not None: return self._bind - for f in self.__froms: + for f in self._froms: if f is self: continue e = f.bind @@ -3133,20 +3303,24 @@ class _UpdateBase(ClauseElement): def supports_execution(self): return True - class _SelectCorrelator(NoColumnVisitor): - def __init__(self, table): - NoColumnVisitor.__init__(self) - self.table = table - - def visit_select(self, select): - if select.should_correlate: - select.correlate(self.table) - - def _process_whereclause(self, whereclause): - if whereclause is not None: - _UpdateBase._SelectCorrelator(self.table).traverse(whereclause) - return whereclause + def _calculate_correlations(self, correlate_state): + class SelectCorrelator(NoColumnVisitor): + def visit_select(s, select): + if select._should_correlate: + select_state = correlate_state.setdefault(select, {}) + corr = select_state.setdefault('correlate', util.Set()) + corr.add(self.table) + + vis = SelectCorrelator() + if self._whereclause is not None: + vis.traverse(self._whereclause) + + if getattr(self, 'parameters', None) is not None: + for key, value in self.parameters.items(): + if isinstance(value, ClauseElement): + vis.traverse(value) + def _process_colparams(self, parameters): """Receive the *values* of an ``INSERT`` or ``UPDATE`` statement and construct appropriate bind parameters. @@ -3155,7 +3329,7 @@ class _UpdateBase(ClauseElement): if parameters is None: return None - if isinstance(parameters, list) or isinstance(parameters, tuple): + if isinstance(parameters, (list, tuple)): pp = {} i = 0 for c in self.table.c: @@ -3163,11 +3337,10 @@ class _UpdateBase(ClauseElement): i +=1 parameters = pp - correlator = _UpdateBase._SelectCorrelator(self.table) for key in parameters.keys(): value = parameters[key] if isinstance(value, ClauseElement): - correlator.traverse(value) + parameters[key] = value.self_group() elif _is_literal(value): if _is_literal(key): col = self.table.c[key] @@ -3182,7 +3355,7 @@ class _UpdateBase(ClauseElement): def _find_engine(self): return self.table.bind -class _Insert(_UpdateBase): +class Insert(_UpdateBase): def __init__(self, table, values=None): self.table = table self.select = None @@ -3193,32 +3366,41 @@ class _Insert(_UpdateBase): return self.select, else: return () - def accept_visitor(self, visitor): - visitor.visit_insert(self) -class _Update(_UpdateBase): +class Update(_UpdateBase): def __init__(self, table, whereclause, values=None): self.table = table - self.whereclause = self._process_whereclause(whereclause) + self._whereclause = whereclause self.parameters = self._process_colparams(values) def get_children(self, **kwargs): - if self.whereclause is not None: - return self.whereclause, + if self._whereclause is not None: + return self._whereclause, else: return () - def accept_visitor(self, visitor): - visitor.visit_update(self) -class _Delete(_UpdateBase): +class Delete(_UpdateBase): def __init__(self, table, whereclause): self.table = table - self.whereclause = self._process_whereclause(whereclause) + self._whereclause = whereclause def get_children(self, **kwargs): - if self.whereclause is not None: - return self.whereclause, + if self._whereclause is not None: + return self._whereclause, else: return () - def accept_visitor(self, visitor): - visitor.visit_delete(self) + +class _IdentifiedClause(ClauseElement): + def __init__(self, ident): + self.ident = ident + def supports_execution(self): + return True + +class SavepointClause(_IdentifiedClause): + pass + +class RollbackToSavepointClause(_IdentifiedClause): + pass + +class ReleaseSavepointClause(_IdentifiedClause): + pass |