summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-07-27 04:08:53 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-07-27 04:08:53 +0000
commited4fc64bb0ac61c27bc4af32962fb129e74a36bf (patch)
treec1cf2fb7b1cafced82a8898e23d2a0bf5ced8526 /lib/sqlalchemy/sql.py
parent3a8e235af64e36b3b711df1f069d32359fe6c967 (diff)
downloadsqlalchemy-ed4fc64bb0ac61c27bc4af32962fb129e74a36bf.tar.gz
merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3.
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r--lib/sqlalchemy/sql.py1946
1 files changed, 1064 insertions, 882 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 8b454947e..01588e92d 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -24,54 +24,21 @@ are less guaranteed to stay the same in future releases.
"""
-from sqlalchemy import util, exceptions, logging
+from sqlalchemy import util, exceptions
from sqlalchemy import types as sqltypes
-import string, re, random, sets
+import re, operator
-
-__all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters',
+__all__ = ['Alias', 'ClauseElement', 'ClauseParameters',
'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
- 'Compiled', 'CompoundSelect', 'Executor', 'FromClause', 'Join',
- 'Select', 'Selectable', 'TableClause', 'alias', 'and_', 'asc',
- 'between_', 'between', 'bindparam', 'case', 'cast', 'column', 'delete',
+ 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join',
+ 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc',
+ 'between', 'bindparam', 'case', 'cast', 'column', 'delete',
'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier',
'insert', 'intersect', 'intersect_all', 'join', 'literal',
- 'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select',
+ 'literal_column', 'not_', 'null', 'or_', 'outparam', 'outerjoin', 'select',
'subquery', 'table', 'text', 'union', 'union_all', 'update',]
-# precedence ordering for common operators. if an operator is not present in this list,
-# it will be parenthesized when grouped against other operators
-PRECEDENCE = {
- 'FROM':15,
- '*':7,
- '/':7,
- '%':7,
- '+':6,
- '-':6,
- 'ILIKE':5,
- 'NOT ILIKE':5,
- 'LIKE':5,
- 'NOT LIKE':5,
- 'IN':5,
- 'NOT IN':5,
- 'IS':5,
- 'IS NOT':5,
- '=':5,
- '!=':5,
- '>':5,
- '<':5,
- '>=':5,
- '<=':5,
- 'BETWEEN':5,
- 'NOT':4,
- 'AND':3,
- 'OR':2,
- ',':-1,
- 'AS':-1,
- 'EXISTS':0,
- '_smallest': -1000,
- '_largest': 1000
-}
+BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
def desc(column):
"""Return a descending ``ORDER BY`` clause element.
@@ -141,7 +108,7 @@ def join(left, right, onclause=None, **kwargs):
return Join(left, right, onclause, **kwargs)
-def select(columns=None, whereclause = None, from_obj = [], **kwargs):
+def select(columns=None, whereclause=None, from_obj=[], **kwargs):
"""Returns a ``SELECT`` clause element.
Similar functionality is also available via the ``select()`` method on any
@@ -224,9 +191,6 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs):
automatically bind to whatever ``Connectable`` instances can be located
within its contained ``ClauseElement`` members.
- engine=None
- deprecated. a synonym for "bind".
-
limit=None
a numerical value which usually compiles to a ``LIMIT`` expression
in the resulting select. Databases that don't support ``LIMIT``
@@ -238,12 +202,8 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs):
will attempt to provide similar functionality.
scalar=False
- when ``True``, indicates that the resulting ``Select`` object
- is to be used in the "columns" clause of another select statement,
- where the evaluated value of the column is the scalar result of
- this statement. Normally, placing any ``Selectable`` within the
- columns clause of a ``select()`` call will expand the member
- columns of the ``Selectable`` individually.
+ deprecated. use select(...).as_scalar() to create a "scalar column"
+ proxy for an existing Select object.
correlate=True
indicates that this ``Select`` object should have its contained
@@ -254,8 +214,12 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs):
rendered in the ``FROM`` clause of this select statement.
"""
-
- return Select(columns, whereclause = whereclause, from_obj = from_obj, **kwargs)
+ scalar = kwargs.pop('scalar', False)
+ s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs)
+ if scalar:
+ return s.as_scalar()
+ else:
+ return s
def subquery(alias, *args, **kwargs):
"""Return an [sqlalchemy.sql#Alias] object derived from a [sqlalchemy.sql#Select].
@@ -271,7 +235,7 @@ def subquery(alias, *args, **kwargs):
return Select(*args, **kwargs).alias(alias)
def insert(table, values = None, **kwargs):
- """Return an [sqlalchemy.sql#_Insert] clause element.
+ """Return an [sqlalchemy.sql#Insert] clause element.
Similar functionality is available via the ``insert()``
method on [sqlalchemy.schema#Table].
@@ -304,10 +268,10 @@ def insert(table, values = None, **kwargs):
against the ``INSERT`` statement.
"""
- return _Insert(table, values, **kwargs)
+ return Insert(table, values, **kwargs)
def update(table, whereclause = None, values = None, **kwargs):
- """Return an [sqlalchemy.sql#_Update] clause element.
+ """Return an [sqlalchemy.sql#Update] clause element.
Similar functionality is available via the ``update()``
method on [sqlalchemy.schema#Table].
@@ -344,10 +308,10 @@ def update(table, whereclause = None, values = None, **kwargs):
against the ``UPDATE`` statement.
"""
- return _Update(table, whereclause, values, **kwargs)
+ return Update(table, whereclause, values, **kwargs)
def delete(table, whereclause = None, **kwargs):
- """Return a [sqlalchemy.sql#_Delete] clause element.
+ """Return a [sqlalchemy.sql#Delete] clause element.
Similar functionality is available via the ``delete()``
method on [sqlalchemy.schema#Table].
@@ -361,7 +325,7 @@ def delete(table, whereclause = None, **kwargs):
"""
- return _Delete(table, whereclause, **kwargs)
+ return Delete(table, whereclause, **kwargs)
def and_(*clauses):
"""Join a list of clauses together using the ``AND`` operator.
@@ -371,7 +335,7 @@ def and_(*clauses):
"""
if len(clauses) == 1:
return clauses[0]
- return ClauseList(operator='AND', *clauses)
+ return ClauseList(operator=operator.and_, *clauses)
def or_(*clauses):
"""Join a list of clauses together using the ``OR`` operator.
@@ -382,7 +346,7 @@ def or_(*clauses):
if len(clauses) == 1:
return clauses[0]
- return ClauseList(operator='OR', *clauses)
+ return ClauseList(operator=operator.or_, *clauses)
def not_(clause):
"""Return a negation of the given clause, i.e. ``NOT(clause)``.
@@ -391,7 +355,7 @@ def not_(clause):
subclasses to produce the same result.
"""
- return clause._negate()
+ return operator.inv(clause)
def distinct(expr):
"""return a ``DISTINCT`` clause."""
@@ -407,12 +371,9 @@ def between(ctest, cleft, cright):
provides similar functionality.
"""
- return _BinaryExpression(ctest, ClauseList(_literals_as_binds(cleft, type=ctest.type), _literals_as_binds(cright, type=ctest.type), operator='AND', group=False), 'BETWEEN')
+ ctest = _literal_as_binds(ctest)
+ return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operator.and_, group=False), ColumnOperators.between_op)
-def between_(*args, **kwargs):
- """synonym for [sqlalchemy.sql#between()] (deprecated)."""
-
- return between(*args, **kwargs)
def case(whens, value=None, else_=None):
"""Produce a ``CASE`` statement.
@@ -435,7 +396,7 @@ def case(whens, value=None, else_=None):
type = list(whenlist[-1])[-1].type
else:
type = None
- cc = _CalculatedClause(None, 'CASE', value, type=type, operator=None, group_contents=False, *whenlist + ['END'])
+ cc = _CalculatedClause(None, 'CASE', value, type_=type, operator=None, group_contents=False, *whenlist + ['END'])
return cc
def cast(clause, totype, **kwargs):
@@ -457,7 +418,7 @@ def cast(clause, totype, **kwargs):
def extract(field, expr):
"""Return the clause ``extract(field FROM expr)``."""
- expr = _BinaryExpression(text(field), expr, "FROM")
+ expr = _BinaryExpression(text(field), expr, Operators.from_)
return func.extract(expr)
def exists(*args, **kwargs):
@@ -587,7 +548,7 @@ def alias(selectable, alias=None):
return Alias(selectable, alias=alias)
-def literal(value, type=None):
+def literal(value, type_=None):
"""Return a literal clause, bound to a bind parameter.
Literal clauses are created automatically when non-
@@ -603,13 +564,13 @@ def literal(value, type=None):
the underlying DBAPI, or is translatable via the given type
argument.
- type
+ type\_
an optional [sqlalchemy.types#TypeEngine] which will provide
bind-parameter translation for this literal.
"""
- return _BindParamClause('literal', value, type=type, unique=True)
+ return _BindParamClause('literal', value, type_=type_, unique=True)
def label(name, obj):
"""Return a [sqlalchemy.sql#_Label] object for the given [sqlalchemy.sql#ColumnElement].
@@ -630,7 +591,7 @@ def label(name, obj):
return _Label(name, obj)
-def column(text, type=None):
+def column(text, type_=None):
"""Return a textual column clause, as would be in the columns
clause of a ``SELECT`` statement.
@@ -644,15 +605,15 @@ def column(text, type=None):
constructs that are not to be quoted, use the [sqlalchemy.sql#literal_column()]
function.
- type
+ type\_
an optional [sqlalchemy.types#TypeEngine] object which will provide
result-set translation for this column.
"""
- return _ColumnClause(text, type=type)
+ return _ColumnClause(text, type_=type_)
-def literal_column(text, type=None):
+def literal_column(text, type_=None):
"""Return a textual column clause, as would be in the columns
clause of a ``SELECT`` statement.
@@ -674,7 +635,7 @@ def literal_column(text, type=None):
"""
- return _ColumnClause(text, type=type, is_literal=True)
+ return _ColumnClause(text, type_=type_, is_literal=True)
def table(name, *columns):
"""Return a [sqlalchemy.sql#Table] object.
@@ -685,7 +646,7 @@ def table(name, *columns):
return TableClause(name, *columns)
-def bindparam(key, value=None, type=None, shortname=None, unique=False):
+def bindparam(key, value=None, type_=None, shortname=None, unique=False):
"""Create a bind parameter clause with the given key.
value
@@ -707,11 +668,22 @@ def bindparam(key, value=None, type=None, shortname=None, unique=False):
"""
if isinstance(key, _ColumnClause):
- return _BindParamClause(key.name, value, type=key.type, shortname=shortname, unique=unique)
+ return _BindParamClause(key.name, value, type_=key.type, shortname=shortname, unique=unique)
else:
- return _BindParamClause(key, value, type=type, shortname=shortname, unique=unique)
+ return _BindParamClause(key, value, type_=type_, shortname=shortname, unique=unique)
-def text(text, bind=None, engine=None, *args, **kwargs):
+def outparam(key, type_=None):
+ """create an 'OUT' parameter for usage in functions (stored procedures), for databases
+ whith support them.
+
+ The ``outparam`` can be used like a regular function parameter. The "output" value will
+ be available from the [sqlalchemy.engine#ResultProxy] object via its ``out_parameters``
+ attribute, which returns a dictionary containing the values.
+ """
+
+ return _BindParamClause(key, None, type_=type_, unique=False, isoutparam=True)
+
+def text(text, bind=None, *args, **kwargs):
"""Create literal text to be inserted into a query.
When constructing a query from a ``select()``, ``update()``,
@@ -729,9 +701,6 @@ def text(text, bind=None, engine=None, *args, **kwargs):
bind
An optional connection or engine to be used for this text query.
- engine
- deprecated. a synonym for 'bind'.
-
bindparams
A list of ``bindparam()`` instances which can be used to define
the types and/or initial values for the bind parameters within
@@ -748,7 +717,7 @@ def text(text, bind=None, engine=None, *args, **kwargs):
"""
- return _TextClause(text, engine=engine, bind=bind, *args, **kwargs)
+ return _TextClause(text, bind=bind, *args, **kwargs)
def null():
"""Return a ``_Null`` object, which compiles to ``NULL`` in a sql statement."""
@@ -786,30 +755,44 @@ def _compound_select(keyword, *selects, **kwargs):
def _is_literal(element):
return not isinstance(element, ClauseElement)
-def _literals_as_text(element):
- if _is_literal(element):
+def _literal_as_text(element):
+ if isinstance(element, Operators):
+ return element.expression_element()
+ elif _is_literal(element):
return _TextClause(unicode(element))
else:
return element
-def _literals_as_binds(element, name='literal', type=None):
- if _is_literal(element):
+def _literal_as_column(element):
+ if isinstance(element, Operators):
+ return element.clause_element()
+ elif _is_literal(element):
+ return literal_column(str(element))
+ else:
+ return element
+
+def _literal_as_binds(element, name='literal', type_=None):
+ if isinstance(element, Operators):
+ return element.expression_element()
+ elif _is_literal(element):
if element is None:
return null()
else:
- return _BindParamClause(name, element, shortname=name, type=type, unique=True)
+ return _BindParamClause(name, element, shortname=name, type_=type_, unique=True)
else:
return element
+
+def _selectable(element):
+ if hasattr(element, '__selectable__'):
+ return element.__selectable__()
+ elif isinstance(element, Selectable):
+ return element
+ else:
+ raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
def is_column(col):
return isinstance(col, ColumnElement)
-class AbstractDialect(object):
- """Represent the behavior of a particular database.
-
- Used by ``Compiled`` objects."""
- pass
-
class ClauseParameters(object):
"""Represent a dictionary/iterator of bind parameter key names/values.
@@ -822,52 +805,51 @@ class ClauseParameters(object):
def __init__(self, dialect, positional=None):
super(ClauseParameters, self).__init__()
self.dialect = dialect
- self.binds = {}
- self.binds_to_names = {}
- self.binds_to_values = {}
+ self.__binds = {}
self.positional = positional or []
+ def get_parameter(self, key):
+ return self.__binds[key]
+
def set_parameter(self, bindparam, value, name):
- self.binds[bindparam.key] = bindparam
- self.binds[name] = bindparam
- self.binds_to_names[bindparam] = name
- self.binds_to_values[bindparam] = value
+ self.__binds[name] = [bindparam, name, value]
def get_original(self, key):
- """Return the given parameter as it was originally placed in
- this ``ClauseParameters`` object, without any ``Type``
- conversion."""
- return self.binds_to_values[self.binds[key]]
+ return self.__binds[key][2]
+
+ def get_type(self, key):
+ return self.__binds[key][0].type
def get_processed(self, key):
- bind = self.binds[key]
- value = self.binds_to_values[bind]
+ (bind, name, value) = self.__binds[key]
return bind.typeprocess(value, self.dialect)
def keys(self):
- return self.binds_to_names.values()
+ return self.__binds.keys()
+
+ def __iter__(self):
+ return iter(self.keys())
def __getitem__(self, key):
return self.get_processed(key)
def __contains__(self, key):
- return key in self.binds
+ return key in self.__binds
def set_value(self, key, value):
- bind = self.binds[key]
- self.binds_to_values[bind] = value
+ self.__binds[key][2] = value
def get_original_dict(self):
- return dict([(self.binds_to_names[b], self.binds_to_values[b]) for b in self.binds_to_names.keys()])
+ return dict([(name, value) for (b, name, value) in self.__binds.values()])
def get_raw_list(self):
return [self.get_processed(key) for key in self.positional]
- def get_raw_dict(self):
- d = {}
- for k in self.binds_to_names.values():
- d[k] = self.get_processed(k)
- return d
+ def get_raw_dict(self, encode_keys=False):
+ if encode_keys:
+ return dict([(key.encode(self.dialect.encoding), self.get_processed(key)) for key in self.keys()])
+ else:
+ return dict([(key, self.get_processed(key)) for key in self.keys()])
def __repr__(self):
return self.__class__.__name__ + ":" + repr(self.get_original_dict())
@@ -876,8 +858,8 @@ class ClauseVisitor(object):
"""A class that knows how to traverse and visit
``ClauseElements``.
- Each ``ClauseElement``'s accept_visitor() method will call a
- corresponding visit_XXXX() method here. Traversal of a
+ Calls visit_XXX() methods dynamically generated for each particualr
+ ``ClauseElement`` subclass encountered. Traversal of a
hierarchy of ``ClauseElements`` is achieved via the
``traverse()`` method, which is passed the lead
``ClauseElement``.
@@ -889,22 +871,44 @@ class ClauseVisitor(object):
these options can indicate modifications to the set of
elements returned, such as to not return column collections
(column_collections=False) or to return Schema-level items
- (schema_visitor=True)."""
+ (schema_visitor=True).
+
+ ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
+ operation, which will produce a copy of a given ``ClauseElement``
+ structure while at the same time allowing ``ClauseVisitor`` subclasses
+ to modify the new structure in-place.
+
+ """
__traverse_options__ = {}
- def traverse(self, obj, stop_on=None):
- stack = [obj]
- traversal = []
- while len(stack) > 0:
- t = stack.pop()
- if stop_on is None or t not in stop_on:
- traversal.insert(0, t)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
- for target in traversal:
- v = self
- while v is not None:
- target.accept_visitor(v)
- v = getattr(v, '_next', None)
+
+ def traverse_single(self, obj, **kwargs):
+ meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+ if meth:
+ return meth(obj, **kwargs)
+
+ def traverse(self, obj, stop_on=None, clone=False):
+ if clone:
+ obj = obj._clone()
+
+ v = self
+ visitors = []
+ while v is not None:
+ visitors.append(v)
+ v = getattr(v, '_next', None)
+
+ def _trav(obj):
+ if stop_on is not None and obj in stop_on:
+ return
+ if clone:
+ obj._copy_internals()
+ for c in obj.get_children(**self.__traverse_options__):
+ _trav(c)
+
+ for v in visitors:
+ meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
+ if meth:
+ meth(obj)
+ _trav(obj)
return obj
def chain(self, visitor):
@@ -916,78 +920,6 @@ class ClauseVisitor(object):
tail = tail._next
tail._next = visitor
return self
-
- def visit_column(self, column):
- pass
- def visit_table(self, table):
- pass
- def visit_fromclause(self, fromclause):
- pass
- def visit_bindparam(self, bindparam):
- pass
- def visit_textclause(self, textclause):
- pass
- def visit_compound(self, compound):
- pass
- def visit_compound_select(self, compound):
- pass
- def visit_binary(self, binary):
- pass
- def visit_unary(self, unary):
- pass
- def visit_alias(self, alias):
- pass
- def visit_select(self, select):
- pass
- def visit_join(self, join):
- pass
- def visit_null(self, null):
- pass
- def visit_clauselist(self, list):
- pass
- def visit_calculatedclause(self, calcclause):
- pass
- def visit_grouping(self, gr):
- pass
- def visit_function(self, func):
- pass
- def visit_cast(self, cast):
- pass
- def visit_label(self, label):
- pass
- def visit_typeclause(self, typeclause):
- pass
-
-class LoggingClauseVisitor(ClauseVisitor):
- """extends ClauseVisitor to include debug logging of all traversal.
-
- To install this visitor, set logging.DEBUG for
- 'sqlalchemy.sql.ClauseVisitor' **before** you import the
- sqlalchemy.sql module.
- """
-
- def traverse(self, obj, stop_on=None):
- stack = [(obj, "")]
- traversal = []
- while len(stack) > 0:
- (t, indent) = stack.pop()
- if stop_on is None or t not in stop_on:
- traversal.insert(0, (t, indent))
- for c in t.get_children(**self.__traverse_options__):
- stack.append((c, indent + " "))
-
- for (target, indent) in traversal:
- self.logger.debug(indent + repr(target))
- v = self
- while v is not None:
- target.accept_visitor(v)
- v = getattr(v, '_next', None)
- return obj
-
-LoggingClauseVisitor.logger = logging.class_logger(ClauseVisitor)
-
-if logging.is_debug_enabled(LoggingClauseVisitor.logger):
- ClauseVisitor=LoggingClauseVisitor
class NoColumnVisitor(ClauseVisitor):
"""a ClauseVisitor that will not traverse the exported Column
@@ -1000,113 +932,35 @@ class NoColumnVisitor(ClauseVisitor):
"""
__traverse_options__ = {'column_collections':False}
-
-class Executor(object):
- """Interface representing a "thing that can produce Compiled objects
- and execute them"."""
-
- def execute_compiled(self, compiled, parameters, echo=None, **kwargs):
- """Execute a Compiled object."""
-
- raise NotImplementedError()
-
- def compiler(self, statement, parameters, **kwargs):
- """Return a Compiled object for the given statement and parameters."""
-
- raise NotImplementedError()
-
-class Compiled(ClauseVisitor):
- """Represent a compiled SQL expression.
-
- The ``__str__`` method of the ``Compiled`` object should produce
- the actual text of the statement. ``Compiled`` objects are
- specific to their underlying database dialect, and also may
- or may not be specific to the columns referenced within a
- particular set of bind parameters. In no case should the
- ``Compiled`` object be dependent on the actual values of those
- bind parameters, even though it may reference those values as
- defaults.
- """
-
- def __init__(self, dialect, statement, parameters, bind=None, engine=None):
- """Construct a new ``Compiled`` object.
-
- statement
- ``ClauseElement`` to be compiled.
-
- parameters
- Optional dictionary indicating a set of bind parameters
- specified with this ``Compiled`` object. These parameters
- are the *default* values corresponding to the
- ``ClauseElement``'s ``_BindParamClauses`` when the
- ``Compiled`` is executed. In the case of an ``INSERT`` or
- ``UPDATE`` statement, these parameters will also result in
- the creation of new ``_BindParamClause`` objects for each
- key and will also affect the generated column list in an
- ``INSERT`` statement and the ``SET`` clauses of an
- ``UPDATE`` statement. The keys of the parameter dictionary
- can either be the string names of columns or
- ``_ColumnClause`` objects.
-
- bind
- optional engine or connection which will be bound to the
- compiled object.
-
- engine
- deprecated, a synonym for 'bind'
- """
- self.dialect = dialect
- self.statement = statement
- self.parameters = parameters
- self.bind = bind or engine
- self.can_execute = statement.supports_execution()
-
- def compile(self):
- self.traverse(self.statement)
- self.after_compile()
-
- def __str__(self):
- """Return the string text of the generated SQL statement."""
-
- raise NotImplementedError()
- def get_params(self, **params):
- """Deprecated. use construct_params(). (supports unicode names)
- """
- return self.construct_params(params)
-
- def construct_params(self, params):
- """Return the bind params for this compiled object.
-
- Will start with the default parameters specified when this
- ``Compiled`` object was first constructed, and will override
- those values with those sent via `**params`, which are
- key/value pairs. Each key should match one of the
- ``_BindParamClause`` objects compiled into this object; either
- the `key` or `shortname` property of the ``_BindParamClause``.
- """
- raise NotImplementedError()
+class _FigureVisitName(type):
+ def __init__(cls, clsname, bases, dict):
+ if not '__visit_name__' in cls.__dict__:
+ m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
+ x = m.group(1)
+ x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
+ cls.__visit_name__ = x.lower()
+ super(_FigureVisitName, cls).__init__(clsname, bases, dict)
- def execute(self, *multiparams, **params):
- """Execute this compiled object."""
-
- e = self.bind
- if e is None:
- raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.")
- return e.execute_compiled(self, *multiparams, **params)
-
- def scalar(self, *multiparams, **params):
- """Execute this compiled object and return the result's scalar value."""
-
- return self.execute(*multiparams, **params).scalar()
-
class ClauseElement(object):
"""Base class for elements of a programmatically constructed SQL
expression.
"""
+ __metaclass__ = _FigureVisitName
+
+ def _clone(self):
+ """create a shallow copy of this ClauseElement.
+
+ This method may be used by a generative API.
+ Its also used as part of the "deep" copy afforded
+ by a traversal that combines the _copy_internals()
+ method."""
+ c = self.__class__.__new__(self.__class__)
+ c.__dict__ = self.__dict__.copy()
+ return c
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
"""Return objects represented in this ``ClauseElement`` that
should be added to the ``FROM`` list of a query, when this
``ClauseElement`` is placed in the column clause of a
@@ -1115,7 +969,7 @@ class ClauseElement(object):
raise NotImplementedError(repr(self))
- def _hide_froms(self):
+ def _hide_froms(self, **modifiers):
"""Return a list of ``FROM`` clause elements which this
``ClauseElement`` replaces.
"""
@@ -1131,13 +985,14 @@ class ClauseElement(object):
return self is other
- def accept_visitor(self, visitor):
- """Accept a ``ClauseVisitor`` and call the appropriate
- ``visit_xxx`` method.
- """
-
- raise NotImplementedError(repr(self))
-
+ def _copy_internals(self):
+ """reassign internal elements to be clones of themselves.
+
+ called during a copy-and-traverse operation on newly
+ shallow-copied elements to create a deep copy."""
+
+ pass
+
def get_children(self, **kwargs):
"""return immediate child elements of this ``ClauseElement``.
@@ -1160,18 +1015,6 @@ class ClauseElement(object):
return False
- def copy_container(self):
- """Return a copy of this ``ClauseElement``, if this
- ``ClauseElement`` contains other ``ClauseElements``.
-
- If this ``ClauseElement`` is not a container, it should return
- self. This is used to create copies of expression trees that
- still reference the same *leaf nodes*. The new structure can
- then be restructured without affecting the original.
- """
-
- return self
-
def _find_engine(self):
"""Default strategy for locating an engine within the clause element.
@@ -1195,7 +1038,6 @@ class ClauseElement(object):
return None
bind = property(lambda s:s._find_engine(), doc="""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""")
- engine = bind
def execute(self, *multiparams, **params):
"""Compile and execute this ``ClauseElement``."""
@@ -1213,7 +1055,7 @@ class ClauseElement(object):
return self.execute(*multiparams, **params).scalar()
- def compile(self, bind=None, engine=None, parameters=None, compiler=None, dialect=None):
+ def compile(self, bind=None, parameters=None, compiler=None, dialect=None):
"""Compile this SQL expression.
Uses the given ``Compiler``, or the given ``AbstractDialect``
@@ -1236,7 +1078,7 @@ class ClauseElement(object):
``SET`` and ``VALUES`` clause of those statements.
"""
- if (isinstance(parameters, list) or isinstance(parameters, tuple)):
+ if isinstance(parameters, (list, tuple)):
parameters = parameters[0]
if compiler is None:
@@ -1244,8 +1086,6 @@ class ClauseElement(object):
compiler = dialect.compiler(self, parameters)
elif bind is not None:
compiler = bind.compiler(self, parameters)
- elif engine is not None:
- compiler = engine.compiler(self, parameters)
elif self.bind is not None:
compiler = self.bind.compiler(self, parameters)
@@ -1268,49 +1108,257 @@ class ClauseElement(object):
return self._negate()
def _negate(self):
- return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None)
+ if hasattr(self, 'negation_clause'):
+ return self.negation_clause
+ else:
+ return _UnaryExpression(self.self_group(against=operator.inv), operator=operator.inv, negate=None)
+
-class _CompareMixin(object):
- """Defines comparison operations for ``ClauseElement`` instances.
+class Operators(object):
+ def from_():
+ raise NotImplementedError()
+ from_ = staticmethod(from_)
- This is a mixin class that adds the capability to produce ``ClauseElement``
- instances based on regular Python operators.
- These operations are achieved using Python's operator overload methods
- (i.e. ``__eq__()``, ``__ne__()``, etc.
+ def as_():
+ raise NotImplementedError()
+ as_ = staticmethod(as_)
- Overridden operators include all comparison operators (i.e. '==', '!=', '<'),
- math operators ('+', '-', '*', etc), the '&' and '|' operators which evaluate
- to ``AND`` and ``OR`` respectively.
+ def exists():
+ raise NotImplementedError()
+ exists = staticmethod(exists)
- Other methods exist to create additional SQL clauses such as ``IN``, ``LIKE``,
- ``DISTINCT``, etc.
+ def is_():
+ raise NotImplementedError()
+ is_ = staticmethod(is_)
- """
+ def isnot():
+ raise NotImplementedError()
+ isnot = staticmethod(isnot)
+
+ def __and__(self, other):
+ return self.operate(operator.and_, other)
+
+ def __or__(self, other):
+ return self.operate(operator.or_, other)
+
+ def __invert__(self):
+ return self.operate(operator.inv)
+
+ def clause_element(self):
+ raise NotImplementedError()
+
+ def operate(self, op, *other, **kwargs):
+ raise NotImplementedError()
+
+ def reverse_operate(self, op, *other, **kwargs):
+ raise NotImplementedError()
+
+class ColumnOperators(Operators):
+ """defines comparison and math operations"""
+
+ def like_op(a, b):
+ return a.like(b)
+ like_op = staticmethod(like_op)
+
+ def notlike_op(a, b):
+ raise NotImplementedError()
+ notlike_op = staticmethod(notlike_op)
+ def ilike_op(a, b):
+ return a.ilike(b)
+ ilike_op = staticmethod(ilike_op)
+
+ def notilike_op(a, b):
+ raise NotImplementedError()
+ notilike_op = staticmethod(notilike_op)
+
+ def between_op(a, b):
+ return a.between(b)
+ between_op = staticmethod(between_op)
+
+ def in_op(a, b):
+ return a.in_(*b)
+ in_op = staticmethod(in_op)
+
+ def notin_op(a, b):
+ raise NotImplementedError()
+ notin_op = staticmethod(notin_op)
+
+ def startswith_op(a, b):
+ return a.startswith(b)
+ startswith_op = staticmethod(startswith_op)
+
+ def endswith_op(a, b):
+ return a.endswith(b)
+ endswith_op = staticmethod(endswith_op)
+
+ def comma_op(a, b):
+ raise NotImplementedError()
+ comma_op = staticmethod(comma_op)
+
+ def concat_op(a, b):
+ return a.concat(b)
+ concat_op = staticmethod(concat_op)
+
def __lt__(self, other):
- return self._compare('<', other)
+ return self.operate(operator.lt, other)
def __le__(self, other):
- return self._compare('<=', other)
+ return self.operate(operator.le, other)
def __eq__(self, other):
- return self._compare('=', other)
+ return self.operate(operator.eq, other)
def __ne__(self, other):
- return self._compare('!=', other)
+ return self.operate(operator.ne, other)
def __gt__(self, other):
- return self._compare('>', other)
+ return self.operate(operator.gt, other)
def __ge__(self, other):
- return self._compare('>=', other)
+ return self.operate(operator.ge, other)
+ def concat(self, other):
+ return self.operate(ColumnOperators.concat_op, other)
+
def like(self, other):
- """produce a ``LIKE`` clause."""
- return self._compare('LIKE', other)
+ return self.operate(ColumnOperators.like_op, other)
+
+ def in_(self, *other):
+ return self.operate(ColumnOperators.in_op, other)
+
+ def startswith(self, other):
+ return self.operate(ColumnOperators.startswith_op, other)
+
+ def endswith(self, other):
+ return self.operate(ColumnOperators.endswith_op, other)
+
+ def __radd__(self, other):
+ return self.reverse_operate(operator.add, other)
+
+ def __rsub__(self, other):
+ return self.reverse_operate(operator.sub, other)
+
+ def __rmul__(self, other):
+ return self.reverse_operate(operator.mul, other)
+
+ def __rdiv__(self, other):
+ return self.reverse_operate(operator.div, other)
+
+ def between(self, cleft, cright):
+ return self.operate(Operators.between_op, (cleft, cright))
+
+ def __add__(self, other):
+ return self.operate(operator.add, other)
+
+ def __sub__(self, other):
+ return self.operate(operator.sub, other)
+
+ def __mul__(self, other):
+ return self.operate(operator.mul, other)
+
+ def __div__(self, other):
+ return self.operate(operator.div, other)
+
+ def __mod__(self, other):
+ return self.operate(operator.mod, other)
+
+ def __truediv__(self, other):
+ return self.operate(operator.truediv, other)
+
+# precedence ordering for common operators. if an operator is not present in this list,
+# it will be parenthesized when grouped against other operators
+_smallest = object()
+_largest = object()
+
+PRECEDENCE = {
+ Operators.from_:15,
+ operator.mul:7,
+ operator.div:7,
+ operator.mod:7,
+ operator.add:6,
+ operator.sub:6,
+ ColumnOperators.concat_op:6,
+ ColumnOperators.ilike_op:5,
+ ColumnOperators.notilike_op:5,
+ ColumnOperators.like_op:5,
+ ColumnOperators.notlike_op:5,
+ ColumnOperators.in_op:5,
+ ColumnOperators.notin_op:5,
+ Operators.is_:5,
+ Operators.isnot:5,
+ operator.eq:5,
+ operator.ne:5,
+ operator.gt:5,
+ operator.lt:5,
+ operator.ge:5,
+ operator.le:5,
+ ColumnOperators.between_op:5,
+ operator.inv:4,
+ operator.and_:3,
+ operator.or_:2,
+ ColumnOperators.comma_op:-1,
+ Operators.as_:-1,
+ Operators.exists:0,
+ _smallest: -1000,
+ _largest: 1000
+}
+
+class _CompareMixin(ColumnOperators):
+ """Defines comparison and math operations for ``ClauseElement`` instances."""
+
+ def __compare(self, op, obj, negate=None):
+ if obj is None or isinstance(obj, _Null):
+ if op == operator.eq:
+ return _BinaryExpression(self.expression_element(), null(), Operators.is_, negate=Operators.isnot)
+ elif op == operator.ne:
+ return _BinaryExpression(self.expression_element(), null(), Operators.isnot, negate=Operators.is_)
+ else:
+ raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
+ else:
+ obj = self._check_literal(obj)
+
+
+ return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate)
+
+ def __operate(self, op, obj):
+ obj = self._check_literal(obj)
+
+ type_ = self._compare_type(obj)
+
+ # TODO: generalize operator overloading like this out into the types module
+ if op == operator.add and isinstance(type_, (sqltypes.Concatenable)):
+ op = ColumnOperators.concat_op
+
+ return _BinaryExpression(self.expression_element(), obj, op, type_=type_)
+
+ operators = {
+ operator.add : (__operate,),
+ operator.mul : (__operate,),
+ operator.sub : (__operate,),
+ operator.div : (__operate,),
+ operator.mod : (__operate,),
+ operator.truediv : (__operate,),
+ operator.lt : (__compare, operator.ge),
+ operator.le : (__compare, operator.gt),
+ operator.ne : (__compare, operator.eq),
+ operator.gt : (__compare, operator.le),
+ operator.ge : (__compare, operator.lt),
+ operator.eq : (__compare, operator.ne),
+ ColumnOperators.like_op : (__compare, ColumnOperators.notlike_op),
+ }
+
+ def operate(self, op, other):
+ o = _CompareMixin.operators[op]
+ return o[0](self, op, other, *o[1:])
+
+ def reverse_operate(self, op, other):
+ return self._bind_param(other).operate(op, self)
def in_(self, *other):
- """produce an ``IN`` clause."""
+ return self._in_impl(ColumnOperators.in_op, ColumnOperators.notin_op, *other)
+
+ def _in_impl(self, op, negate_op, *other):
if len(other) == 0:
return _Grouping(case([(self.__eq__(None), text('NULL'))], else_=text('0')).__eq__(text('1')))
elif len(other) == 1:
@@ -1318,8 +1366,8 @@ class _CompareMixin(object):
if _is_literal(o) or isinstance( o, _CompareMixin):
return self.__eq__( o) #single item -> ==
else:
- assert hasattr( o, '_selectable') #better check?
- return self._compare( 'IN', o, negate='NOT IN') #single selectable
+ assert isinstance(o, Selectable)
+ return self.__compare( op, o, negate=negate_op) #single selectable
args = []
for o in other:
@@ -1329,29 +1377,22 @@ class _CompareMixin(object):
else:
o = self._bind_param(o)
args.append(o)
- return self._compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN')
+ return self.__compare(op, ClauseList(*args).self_group(against=op), negate=negate_op)
def startswith(self, other):
"""produce the clause ``LIKE '<other>%'``"""
- perc = isinstance(other,(str,unicode)) and '%' or literal('%',type= sqltypes.String)
- return self._compare('LIKE', other + perc)
+
+ perc = isinstance(other,(str,unicode)) and '%' or literal('%',type_= sqltypes.String)
+ return self.__compare(ColumnOperators.like_op, other + perc)
def endswith(self, other):
"""produce the clause ``LIKE '%<other>'``"""
+
if isinstance(other,(str,unicode)): po = '%' + other
else:
- po = literal('%', type= sqltypes.String) + other
- po.type = sqltypes.to_instance( sqltypes.String) #force!
- return self._compare('LIKE', po)
-
- def __radd__(self, other):
- return self._bind_param(other)._operate('+', self)
- def __rsub__(self, other):
- return self._bind_param(other)._operate('-', self)
- def __rmul__(self, other):
- return self._bind_param(other)._operate('*', self)
- def __rdiv__(self, other):
- return self._bind_param(other)._operate('/', self)
+ po = literal('%', type_=sqltypes.String) + other
+ po.type = sqltypes.to_instance(sqltypes.String) #force!
+ return self.__compare(ColumnOperators.like_op, po)
def label(self, name):
"""produce a column label, i.e. ``<columnname> AS <name>``"""
@@ -1363,7 +1404,8 @@ class _CompareMixin(object):
def between(self, cleft, cright):
"""produce a BETWEEN clause, i.e. ``<column> BETWEEN <cleft> AND <cright>``"""
- return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator='AND', group=False), 'BETWEEN')
+
+ return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operator.and_, group=False), ColumnOperators.between_op)
def op(self, operator):
"""produce a generic operator function.
@@ -1382,59 +1424,25 @@ class _CompareMixin(object):
passed to the generated function.
"""
- return lambda other: self._operate(operator, other)
-
- # and here come the math operators:
-
- def __add__(self, other):
- return self._operate('+', other)
-
- def __sub__(self, other):
- return self._operate('-', other)
-
- def __mul__(self, other):
- return self._operate('*', other)
-
- def __div__(self, other):
- return self._operate('/', other)
-
- def __mod__(self, other):
- return self._operate('%', other)
-
- def __truediv__(self, other):
- return self._operate('/', other)
+ return lambda other: self.__operate(operator, other)
def _bind_param(self, obj):
- return _BindParamClause('literal', obj, shortname=None, type=self.type, unique=True)
+ return _BindParamClause('literal', obj, shortname=None, type_=self.type, unique=True)
def _check_literal(self, other):
- if _is_literal(other):
+ if isinstance(other, Operators):
+ return other.expression_element()
+ elif _is_literal(other):
return self._bind_param(other)
else:
return other
+
+ def clause_element(self):
+ """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``."""
+ return self
- def _compare(self, operator, obj, negate=None):
- if obj is None or isinstance(obj, _Null):
- if operator == '=':
- return _BinaryExpression(self._compare_self(), null(), 'IS', negate='IS NOT')
- elif operator == '!=':
- return _BinaryExpression(self._compare_self(), null(), 'IS NOT', negate='IS')
- else:
- raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
- else:
- obj = self._check_literal(obj)
-
- return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj), negate=negate)
-
- def _operate(self, operator, obj):
- if _is_literal(obj):
- obj = self._bind_param(obj)
- return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj))
-
- def _compare_self(self):
- """Allow ``ColumnImpl`` to return its ``Column`` object for
- usage in ``ClauseElements``, all others to just return self.
- """
+ def expression_element(self):
+ """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions."""
return self
@@ -1460,23 +1468,10 @@ class Selectable(ClauseElement):
columns = util.NotImplProperty("""a [sqlalchemy.sql#ColumnCollection] containing ``ColumnElement`` instances.""")
- def _selectable(self):
- return self
-
- def accept_visitor(self, visitor):
- raise NotImplementedError(repr(self))
-
def select(self, whereclauses = None, **params):
return select([self], whereclauses, **params)
- def _group_parenthesized(self):
- """Indicate if this ``Selectable`` requires parenthesis when
- grouped into a compound statement.
- """
- return True
-
-
class ColumnElement(Selectable, _CompareMixin):
"""Represent an element that is useable within the
"column clause" portion of a ``SELECT`` statement.
@@ -1616,8 +1611,10 @@ class ColumnCollection(util.OrderedProperties):
l.append(c==local)
return and_(*l)
- def __contains__(self, col):
- return self.contains_column(col)
+ def __contains__(self, other):
+ if not isinstance(other, basestring):
+ raise exceptions.ArgumentError("__contains__ requires a string argument")
+ return self.has_key(other)
def contains_column(self, col):
# have to use a Set here, because it will compare the identity
@@ -1649,19 +1646,18 @@ class FromClause(Selectable):
clause of a ``SELECT`` statement.
"""
+ __visit_name__ = 'fromclause'
+
def __init__(self, name=None):
self.name = name
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
# this could also be [self], at the moment it doesnt matter to the Select object
return []
def default_order_by(self):
return [self.oid_column]
- def accept_visitor(self, visitor):
- visitor.visit_fromclause(self)
-
def count(self, whereclause=None, **params):
if len(self.primary_key):
col = list(self.primary_key)[0]
@@ -1703,6 +1699,13 @@ class FromClause(Selectable):
FindCols().traverse(self)
return ret
+ def is_derived_from(self, fromclause):
+ """return True if this FromClause is 'derived' from the given FromClause.
+
+ An example would be an Alias of a Table is derived from that Table."""
+
+ return False
+
def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False):
"""Given a ``ColumnElement``, return the exported
``ColumnElement`` object from this ``Selectable`` which
@@ -1730,9 +1733,10 @@ class FromClause(Selectable):
it merely shares a common anscestor with one of
the exported columns of this ``FromClause``.
"""
- if column in self.c:
+
+ if self.c.contains_column(column):
return column
-
+
if require_embedded and column not in util.Set(self._get_all_embedded_columns()):
if not raiseerr:
return None
@@ -1761,6 +1765,15 @@ class FromClause(Selectable):
self._export_columns()
return getattr(self, name)
+ def _clone_from_clause(self):
+ # delete all the "generated" collections of columns for a newly cloned FromClause,
+ # so that they will be re-derived from the item.
+ # this is because FromClause subclasses, when cloned, need to reestablish new "proxied"
+ # columns that are linked to the new item
+ for attr in ('_columns', '_primary_key' '_foreign_keys', '_orig_cols', '_oid_column'):
+ if hasattr(self, attr):
+ delattr(self, attr)
+
columns = property(lambda s:s._get_exported_attribute('_columns'))
c = property(lambda s:s._get_exported_attribute('_columns'))
primary_key = property(lambda s:s._get_exported_attribute('_primary_key'))
@@ -1791,8 +1804,9 @@ class FromClause(Selectable):
self._primary_key = ColumnSet()
self._foreign_keys = util.Set()
self._orig_cols = {}
+
if columns is None:
- columns = self._adjusted_exportable_columns()
+ columns = self._flatten_exportable_columns()
for co in columns:
cp = self._proxy_column(co)
for ci in cp.orig_set:
@@ -1806,13 +1820,14 @@ class FromClause(Selectable):
for ci in self.oid_column.orig_set:
self._orig_cols[ci] = self.oid_column
- def _adjusted_exportable_columns(self):
+ def _flatten_exportable_columns(self):
"""return the list of ColumnElements represented within this FromClause's _exportable_columns"""
export = self._exportable_columns()
for column in export:
- try:
- s = column._selectable()
- except AttributeError:
+ # TODO: is this conditional needed ?
+ if isinstance(column, Selectable):
+ s = column
+ else:
continue
for co in s.columns:
yield co
@@ -1829,7 +1844,9 @@ class _BindParamClause(ClauseElement, _CompareMixin):
Public constructor is the ``bindparam()`` function.
"""
- def __init__(self, key, value, shortname=None, type=None, unique=False):
+ __visit_name__ = 'bindparam'
+
+ def __init__(self, key, value, shortname=None, type_=None, unique=False, isoutparam=False):
"""Construct a _BindParamClause.
key
@@ -1852,7 +1869,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
execution may match either the key or the shortname of the
corresponding ``_BindParamClause`` objects.
- type
+ type\_
A ``TypeEngine`` object that will be used to pre-process the
value corresponding to this ``_BindParamClause`` at
execution time.
@@ -1862,23 +1879,34 @@ class _BindParamClause(ClauseElement, _CompareMixin):
modified if another ``_BindParamClause`` of the same
name already has been located within the containing
``ClauseElement``.
+
+ isoutparam
+ if True, the parameter should be treated like a stored procedure "OUT"
+ parameter.
"""
- self.key = key
+ self.key = key or "{ANON %d param}" % id(self)
self.value = value
self.shortname = shortname or key
self.unique = unique
- self.type = sqltypes.to_instance(type)
-
- def accept_visitor(self, visitor):
- visitor.visit_bindparam(self)
-
- def _get_from_objects(self):
+ self.isoutparam = isoutparam
+ type_ = sqltypes.to_instance(type_)
+ if isinstance(type_, sqltypes.NullType) and type(value) in _BindParamClause.type_map:
+ self.type = sqltypes.to_instance(_BindParamClause.type_map[type(value)])
+ else:
+ self.type = type_
+
+ # TODO: move to types module, obviously
+ type_map = {
+ str : sqltypes.String,
+ unicode : sqltypes.Unicode,
+ int : sqltypes.Integer,
+ float : sqltypes.Numeric
+ }
+
+ def _get_from_objects(self, **modifiers):
return []
- def copy_container(self):
- return _BindParamClause(self.key, self.value, self.shortname, self.type, unique=self.unique)
-
def typeprocess(self, value, dialect):
return self.type.dialect_impl(dialect).convert_bind_param(value, dialect)
@@ -1893,7 +1921,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__
def __repr__(self):
- return "_BindParamClause(%s, %s, type=%s)" % (repr(self.key), repr(self.value), repr(self.type))
+ return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type))
class _TypeClause(ClauseElement):
"""Handle a type keyword in a SQL statement.
@@ -1901,13 +1929,12 @@ class _TypeClause(ClauseElement):
Used by the ``Case`` statement.
"""
+ __visit_name__ = 'typeclause'
+
def __init__(self, type):
self.type = type
- def accept_visitor(self, visitor):
- visitor.visit_typeclause(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return []
class _TextClause(ClauseElement):
@@ -1916,8 +1943,10 @@ class _TextClause(ClauseElement):
Public constructor is the ``text()`` function.
"""
- def __init__(self, text = "", bind=None, engine=None, bindparams=None, typemap=None):
- self._bind = bind or engine
+ __visit_name__ = 'textclause'
+
+ def __init__(self, text = "", bind=None, bindparams=None, typemap=None):
+ self._bind = bind
self.bindparams = {}
self.typemap = typemap
if typemap is not None:
@@ -1930,7 +1959,7 @@ class _TextClause(ClauseElement):
# scan the string and search for bind parameter names, add them
# to the list of bindparams
- self.text = re.compile(r'(?<!:):([\w_]+)', re.S).sub(repl, text)
+ self.text = BIND_PARAMS.sub(repl, text)
if bindparams is not None:
for b in bindparams:
self.bindparams[b.key] = b
@@ -1944,13 +1973,13 @@ class _TextClause(ClauseElement):
columns = property(lambda s:[])
+ def _copy_internals(self):
+ self.bindparams = [b._clone() for b in self.bindparams]
+
def get_children(self, **kwargs):
return self.bindparams.values()
- def accept_visitor(self, visitor):
- visitor.visit_textclause(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return []
def supports_execution(self):
@@ -1965,10 +1994,7 @@ class _Null(ColumnElement):
def __init__(self):
self.type = sqltypes.NULLTYPE
- def accept_visitor(self, visitor):
- visitor.visit_null(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return []
class ClauseList(ClauseElement):
@@ -1976,14 +2002,16 @@ class ClauseList(ClauseElement):
By default, is comma-separated, such as a column listing.
"""
-
+ __visit_name__ = 'clauselist'
+
def __init__(self, *clauses, **kwargs):
self.clauses = []
- self.operator = kwargs.pop('operator', ',')
+ self.operator = kwargs.pop('operator', ColumnOperators.comma_op)
self.group = kwargs.pop('group', True)
self.group_contents = kwargs.pop('group_contents', True)
for c in clauses:
- if c is None: continue
+ if c is None:
+ continue
self.append(c)
def __iter__(self):
@@ -1991,32 +2019,28 @@ class ClauseList(ClauseElement):
def __len__(self):
return len(self.clauses)
- def copy_container(self):
- clauses = [clause.copy_container() for clause in self.clauses]
- return ClauseList(operator=self.operator, *clauses)
-
def append(self, clause):
# TODO: not sure if i like the 'group_contents' flag. need to define the difference between
# a ClauseList of ClauseLists, and a "flattened" ClauseList of ClauseLists. flatten() method ?
if self.group_contents:
- self.clauses.append(_literals_as_text(clause).self_group(against=self.operator))
+ self.clauses.append(_literal_as_text(clause).self_group(against=self.operator))
else:
- self.clauses.append(_literals_as_text(clause))
+ self.clauses.append(_literal_as_text(clause))
+
+ def _copy_internals(self):
+ self.clauses = [clause._clone() for clause in self.clauses]
def get_children(self, **kwargs):
return self.clauses
- def accept_visitor(self, visitor):
- visitor.visit_clauselist(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
f = []
for c in self.clauses:
- f += c._get_from_objects()
+ f += c._get_from_objects(**modifiers)
return f
def self_group(self, against=None):
- if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']):
+ if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]):
return _Grouping(self)
else:
return self
@@ -2043,40 +2067,45 @@ class _CalculatedClause(ColumnElement):
Extends ``ColumnElement`` to provide column-level comparison
operators.
"""
-
+ __visit_name__ = 'calculatedclause'
+
def __init__(self, name, *clauses, **kwargs):
self.name = name
- self.type = sqltypes.to_instance(kwargs.get('type', None))
- self._bind = kwargs.get('bind', kwargs.get('engine', None))
+ self.type = sqltypes.to_instance(kwargs.get('type_', None))
+ self._bind = kwargs.get('bind', None)
self.group = kwargs.pop('group', True)
- self.clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses)
+ clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses)
if self.group:
- self.clause_expr = self.clauses.self_group()
+ self.clause_expr = clauses.self_group()
else:
- self.clause_expr = self.clauses
+ self.clause_expr = clauses
key = property(lambda self:self.name or "_calc_")
- def copy_container(self):
- clauses = [clause.copy_container() for clause in self.clauses]
- return _CalculatedClause(type=self.type, bind=self._bind, *clauses)
-
+ def _copy_internals(self):
+ self.clause_expr = self.clause_expr._clone()
+
+ def clauses(self):
+ if isinstance(self.clause_expr, _Grouping):
+ return self.clause_expr.elem
+ else:
+ return self.clause_expr
+ clauses = property(clauses)
+
def get_children(self, **kwargs):
return self.clause_expr,
- def accept_visitor(self, visitor):
- visitor.visit_calculatedclause(self)
- def _get_from_objects(self):
- return self.clauses._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return self.clauses._get_from_objects(**modifiers)
def _bind_param(self, obj):
- return _BindParamClause(self.name, obj, type=self.type, unique=True)
+ return _BindParamClause(self.name, obj, type_=self.type, unique=True)
def select(self):
return select([self])
def scalar(self):
- return select([self]).scalar()
+ return select([self]).execute().scalar()
def execute(self):
return select([self]).execute()
@@ -2092,28 +2121,26 @@ class _Function(_CalculatedClause, FromClause):
"""
def __init__(self, name, *clauses, **kwargs):
- self.type = sqltypes.to_instance(kwargs.get('type', None))
self.packagenames = kwargs.get('packagenames', None) or []
- kwargs['operator'] = ','
- self._engine = kwargs.get('engine', None)
+ kwargs['operator'] = ColumnOperators.comma_op
_CalculatedClause.__init__(self, name, **kwargs)
for c in clauses:
self.append(c)
key = property(lambda self:self.name)
+ def _copy_internals(self):
+ _CalculatedClause._copy_internals(self)
+ self._clone_from_clause()
- def append(self, clause):
- self.clauses.append(_literals_as_binds(clause, self.name))
-
- def copy_container(self):
- clauses = [clause.copy_container() for clause in self.clauses]
- return _Function(self.name, type=self.type, packagenames=self.packagenames, bind=self._bind, *clauses)
+ def get_children(self, **kwargs):
+ return _CalculatedClause.get_children(self, **kwargs)
- def accept_visitor(self, visitor):
- visitor.visit_function(self)
+ def append(self, clause):
+ self.clauses.append(_literal_as_binds(clause, self.name))
class _Cast(ColumnElement):
+
def __init__(self, clause, totype, **kwargs):
if not hasattr(clause, 'label'):
clause = literal(clause)
@@ -2122,17 +2149,19 @@ class _Cast(ColumnElement):
self.typeclause = _TypeClause(self.type)
self._distance = 0
+ def _copy_internals(self):
+ self.clause = self.clause._clone()
+ self.typeclause = self.typeclause._clone()
+
def get_children(self, **kwargs):
return self.clause, self.typeclause
- def accept_visitor(self, visitor):
- visitor.visit_cast(self)
- def _get_from_objects(self):
- return self.clause._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return self.clause._get_from_objects(**modifiers)
def _make_proxy(self, selectable, name=None):
if name is not None:
- co = _ColumnClause(name, selectable, type=self.type)
+ co = _ColumnClause(name, selectable, type_=self.type)
co._distance = self._distance + 1
co.orig_set = self.orig_set
selectable.columns[name]= co
@@ -2142,26 +2171,23 @@ class _Cast(ColumnElement):
class _UnaryExpression(ColumnElement):
- def __init__(self, element, operator=None, modifier=None, type=None, negate=None):
+ def __init__(self, element, operator=None, modifier=None, type_=None, negate=None):
self.operator = operator
self.modifier = modifier
- self.element = _literals_as_text(element).self_group(against=self.operator or self.modifier)
- self.type = sqltypes.to_instance(type)
+ self.element = _literal_as_text(element).self_group(against=self.operator or self.modifier)
+ self.type = sqltypes.to_instance(type_)
self.negate = negate
- def copy_container(self):
- return self.__class__(self.element.copy_container(), operator=self.operator, modifier=self.modifier, type=self.type, negate=self.negate)
+ def _get_from_objects(self, **modifiers):
+ return self.element._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return self.element._get_from_objects()
+ def _copy_internals(self):
+ self.element = self.element._clone()
def get_children(self, **kwargs):
return self.element,
- def accept_visitor(self, visitor):
- visitor.visit_unary(self)
-
def compare(self, other):
"""Compare this ``_UnaryExpression`` against the given ``ClauseElement``."""
@@ -2170,14 +2196,15 @@ class _UnaryExpression(ColumnElement):
self.modifier == other.modifier and
self.element.compare(other.element)
)
+
def _negate(self):
if self.negate is not None:
- return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type)
+ return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type_=self.type)
else:
return super(_UnaryExpression, self)._negate()
def self_group(self, against):
- if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']):
+ if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]):
return _Grouping(self)
else:
return self
@@ -2186,25 +2213,23 @@ class _UnaryExpression(ColumnElement):
class _BinaryExpression(ColumnElement):
"""Represent an expression that is ``LEFT <operator> RIGHT``."""
- def __init__(self, left, right, operator, type=None, negate=None):
- self.left = _literals_as_text(left).self_group(against=operator)
- self.right = _literals_as_text(right).self_group(against=operator)
+ def __init__(self, left, right, operator, type_=None, negate=None):
+ self.left = _literal_as_text(left).self_group(against=operator)
+ self.right = _literal_as_text(right).self_group(against=operator)
self.operator = operator
- self.type = sqltypes.to_instance(type)
+ self.type = sqltypes.to_instance(type_)
self.negate = negate
- def copy_container(self):
- return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator)
+ def _get_from_objects(self, **modifiers):
+ return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return self.left._get_from_objects() + self.right._get_from_objects()
+ def _copy_internals(self):
+ self.left = self.left._clone()
+ self.right = self.right._clone()
def get_children(self, **kwargs):
return self.left, self.right
- def accept_visitor(self, visitor):
- visitor.visit_binary(self)
-
def compare(self, other):
"""Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``."""
@@ -2213,7 +2238,7 @@ class _BinaryExpression(ColumnElement):
(
self.left.compare(other.left) and self.right.compare(other.right)
or (
- self.operator in ['=', '!=', '+', '*'] and
+ self.operator in [operator.eq, operator.ne, operator.add, operator.mul] and
self.left.compare(other.right) and self.right.compare(other.left)
)
)
@@ -2221,25 +2246,27 @@ class _BinaryExpression(ColumnElement):
def self_group(self, against=None):
# use small/large defaults for comparison so that unknown operators are always parenthesized
- if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest'])):
+ if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest])):
return _Grouping(self)
else:
return self
def _negate(self):
if self.negate is not None:
- return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type)
+ return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type_=self.type)
else:
return super(_BinaryExpression, self)._negate()
class _Exists(_UnaryExpression):
+ __visit_name__ = _UnaryExpression.__visit_name__
+
def __init__(self, *args, **kwargs):
kwargs['correlate'] = True
s = select(*args, **kwargs).self_group()
- _UnaryExpression.__init__(self, s, operator="EXISTS")
+ _UnaryExpression.__init__(self, s, operator=Operators.exists)
- def _hide_froms(self):
- return self._get_from_objects()
+ def _hide_froms(self, **modifiers):
+ return self._get_from_objects(**modifiers)
class Join(FromClause):
"""represent a ``JOIN`` construct between two ``FromClause``
@@ -2251,8 +2278,8 @@ class Join(FromClause):
"""
def __init__(self, left, right, onclause=None, isouter = False):
- self.left = left._selectable()
- self.right = right._selectable()
+ self.left = _selectable(left)
+ self.right = _selectable(right).self_group()
if onclause is None:
self.onclause = self._match_primaries(self.left, self.right)
else:
@@ -2265,8 +2292,8 @@ class Join(FromClause):
encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace'))
def _init_primary_key(self):
- pkcol = util.Set([c for c in self._adjusted_exportable_columns() if c.primary_key])
-
+ pkcol = util.Set([c for c in self._flatten_exportable_columns() if c.primary_key])
+
equivs = {}
def add_equiv(a, b):
for x, y in ((a, b), (b, a)):
@@ -2277,7 +2304,7 @@ class Join(FromClause):
class BinaryVisitor(ClauseVisitor):
def visit_binary(self, binary):
- if binary.operator == '=':
+ if binary.operator == operator.eq:
add_equiv(binary.left, binary.right)
BinaryVisitor().traverse(self.onclause)
@@ -2294,9 +2321,12 @@ class Join(FromClause):
omit.add(p)
p = c
- self.__primary_key = ColumnSet([c for c in self._adjusted_exportable_columns() if c.primary_key and c not in omit])
+ self.__primary_key = ColumnSet([c for c in self._flatten_exportable_columns() if c.primary_key and c not in omit])
primary_key = property(lambda s:s.__primary_key)
+
+ def self_group(self, against=None):
+ return _Grouping(self)
def _locate_oid_column(self):
return self.left.oid_column
@@ -2310,6 +2340,17 @@ class Join(FromClause):
self._foreign_keys.add(f)
return column
+ def _copy_internals(self):
+ self._clone_from_clause()
+ self.left = self.left._clone()
+ self.right = self.right._clone()
+ self.onclause = self.onclause._clone()
+ self.__folded_equivalents = None
+ self._init_primary_key()
+
+ def get_children(self, **kwargs):
+ return self.left, self.right, self.onclause
+
def _match_primaries(self, primary, secondary):
crit = []
constraints = util.Set()
@@ -2338,9 +2379,6 @@ class Join(FromClause):
else:
return and_(*crit)
- def _group_parenthesized(self):
- return True
-
def _get_folded_equivalents(self, equivs=None):
if self.__folded_equivalents is not None:
return self.__folded_equivalents
@@ -2348,7 +2386,7 @@ class Join(FromClause):
equivs = util.Set()
class LocateEquivs(NoColumnVisitor):
def visit_binary(self, binary):
- if binary.operator == '=' and binary.left.name == binary.right.name:
+ if binary.operator == operator.eq and binary.left.name == binary.right.name:
equivs.add(binary.right)
equivs.add(binary.left)
LocateEquivs().traverse(self.onclause)
@@ -2401,13 +2439,7 @@ class Join(FromClause):
return select(collist, whereclause, from_obj=[self], **kwargs)
- def get_children(self, **kwargs):
- return self.left, self.right, self.onclause
-
- def accept_visitor(self, visitor):
- visitor.visit_join(self)
-
- engine = property(lambda s:s.left.engine or s.right.engine)
+ bind = property(lambda s:s.left.bind or s.right.bind)
def alias(self, name=None):
"""Create a ``Select`` out of this ``Join`` clause and return an ``Alias`` of it.
@@ -2417,11 +2449,11 @@ class Join(FromClause):
return self.select(use_labels=True, correlate=False).alias(name)
- def _hide_froms(self):
- return self.left._get_from_objects() + self.right._get_from_objects()
+ def _hide_froms(self, **modifiers):
+ return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
class Alias(FromClause):
"""represent an alias, as typically applied to any
@@ -2443,15 +2475,22 @@ class Alias(FromClause):
if alias is None:
if self.original.named_with_column():
alias = getattr(self.original, 'name', None)
- if alias is None:
- alias = 'anon'
- elif len(alias) > 15:
- alias = alias[0:15]
- alias = alias + "_" + hex(random.randint(0, 65535))[2:]
+ alias = '{ANON %d %s}' % (id(self), alias or 'anon')
self.name = alias
self.encodedname = alias.encode('ascii', 'backslashreplace')
self.case_sensitive = getattr(baseselectable, "case_sensitive", True)
+ def is_derived_from(self, fromclause):
+ x = self.selectable
+ while True:
+ if x is fromclause:
+ return True
+ if isinstance(x, Alias):
+ x = x.selectable
+ else:
+ break
+ return False
+
def supports_execution(self):
return self.original.supports_execution()
@@ -2468,46 +2507,57 @@ class Alias(FromClause):
#return self.selectable._exportable_columns()
return self.selectable.columns
+ def _copy_internals(self):
+ self._clone_from_clause()
+ self.selectable = self.selectable._clone()
+ baseselectable = self.selectable
+ while isinstance(baseselectable, Alias):
+ baseselectable = baseselectable.selectable
+ self.original = baseselectable
+
def get_children(self, **kwargs):
for c in self.c:
yield c
yield self.selectable
- def accept_visitor(self, visitor):
- visitor.visit_alias(self)
-
def _get_from_objects(self):
return [self]
- def _group_parenthesized(self):
- return False
-
bind = property(lambda s: s.selectable.bind)
- engine = bind
-class _Grouping(ColumnElement):
+class _ColumnElementAdapter(ColumnElement):
+ """adapts a ClauseElement which may or may not be a
+ ColumnElement subclass itself into an object which
+ acts like a ColumnElement.
+ """
+
def __init__(self, elem):
self.elem = elem
self.type = getattr(elem, 'type', None)
-
+ self.orig_set = getattr(elem, 'orig_set', util.Set())
+
key = property(lambda s: s.elem.key)
_label = property(lambda s: s.elem._label)
- orig_set = property(lambda s:s.elem.orig_set)
-
- def copy_container(self):
- return _Grouping(self.elem.copy_container())
-
- def accept_visitor(self, visitor):
- visitor.visit_grouping(self)
+ columns = c = property(lambda s:s.elem.columns)
+
+ def _copy_internals(self):
+ self.elem = self.elem._clone()
+
def get_children(self, **kwargs):
return self.elem,
- def _hide_froms(self):
- return self.elem._hide_froms()
- def _get_from_objects(self):
- return self.elem._get_from_objects()
+
+ def _hide_froms(self, **modifiers):
+ return self.elem._hide_froms(**modifiers)
+
+ def _get_from_objects(self, **modifiers):
+ return self.elem._get_from_objects(**modifiers)
+
def __getattr__(self, attr):
return getattr(self.elem, attr)
-
+
+class _Grouping(_ColumnElementAdapter):
+ pass
+
class _Label(ColumnElement):
"""represent a label, as typically applied to any column-level element
using the ``AS`` sql keyword.
@@ -2518,32 +2568,33 @@ class _Label(ColumnElement):
"""
- def __init__(self, name, obj, type=None):
- self.name = name
+ def __init__(self, name, obj, type_=None):
while isinstance(obj, _Label):
obj = obj.obj
- self.obj = obj.self_group(against='AS')
+ self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon'))
+
+ self.obj = obj.self_group(against=Operators.as_)
self.case_sensitive = getattr(obj, "case_sensitive", True)
- self.type = sqltypes.to_instance(type or getattr(obj, 'type', None))
+ self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None))
key = property(lambda s: s.name)
_label = property(lambda s: s.name)
orig_set = property(lambda s:s.obj.orig_set)
- def _compare_self(self):
+ def expression_element(self):
return self.obj
-
+
+ def _copy_internals(self):
+ self.obj = self.obj._clone()
+
def get_children(self, **kwargs):
return self.obj,
- def accept_visitor(self, visitor):
- visitor.visit_label(self)
+ def _get_from_objects(self, **modifiers):
+ return self.obj._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return self.obj._get_from_objects()
-
- def _hide_froms(self):
- return self.obj._hide_froms()
+ def _hide_froms(self, **modifiers):
+ return self.obj._hide_froms(**modifiers)
def _make_proxy(self, selectable, name = None):
if isinstance(self.obj, Selectable):
@@ -2551,8 +2602,6 @@ class _Label(ColumnElement):
else:
return column(self.name)._make_proxy(selectable=selectable)
-legal_characters = util.Set(string.ascii_letters + string.digits + '_')
-
class _ColumnClause(ColumnElement):
"""Represents a generic column expression from any textual string.
This includes columns associated with tables, aliases and select
@@ -2584,17 +2633,21 @@ class _ColumnClause(ColumnElement):
"""
- def __init__(self, text, selectable=None, type=None, _is_oid=False, case_sensitive=True, is_literal=False):
+ def __init__(self, text, selectable=None, type_=None, _is_oid=False, case_sensitive=True, is_literal=False):
self.key = self.name = text
self.encodedname = isinstance(self.name, unicode) and self.name.encode('ascii', 'backslashreplace') or self.name
self.table = selectable
- self.type = sqltypes.to_instance(type)
+ self.type = sqltypes.to_instance(type_)
self._is_oid = _is_oid
self._distance = 0
self.__label = None
self.case_sensitive = case_sensitive
self.is_literal = is_literal
-
+
+ def _clone(self):
+ # ColumnClause is immutable
+ return self
+
def _get_label(self):
"""Generate a 'label' for this column.
@@ -2617,7 +2670,6 @@ class _ColumnClause(ColumnElement):
counter += 1
else:
self.__label = self.name
- self.__label = "".join([x for x in self.__label if x in legal_characters])
return self.__label
is_labeled = property(lambda self:self.name != list(self.orig_set)[0].name)
@@ -2632,23 +2684,20 @@ class _ColumnClause(ColumnElement):
else:
return super(_ColumnClause, self).label(name)
- def accept_visitor(self, visitor):
- visitor.visit_column(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
if self.table is not None:
return [self.table]
else:
return []
def _bind_param(self, obj):
- return _BindParamClause(self._label, obj, shortname = self.name, type=self.type, unique=True)
+ return _BindParamClause(self._label, obj, shortname=self.name, type_=self.type, unique=True)
def _make_proxy(self, selectable, name = None):
# propigate the "is_literal" flag only if we are keeping our name,
# otherwise its considered to be a label
is_literal = self.is_literal and (name is None or name == self.name)
- c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type=self.type, is_literal=is_literal)
+ c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal)
c.orig_set = self.orig_set
c._distance = self._distance + 1
if not self._is_oid:
@@ -2658,9 +2707,6 @@ class _ColumnClause(ColumnElement):
def _compare_type(self, obj):
return self.type
- def _group_parenthesized(self):
- return False
-
class TableClause(FromClause):
"""represents a "table" construct.
@@ -2677,6 +2723,10 @@ class TableClause(FromClause):
self._oid_column = _ColumnClause('oid', self, _is_oid=True)
self._export_columns(columns)
+ def _clone(self):
+ # TableClause is immutable
+ return self
+
def named_with_column(self):
return True
@@ -2709,15 +2759,9 @@ class TableClause(FromClause):
else:
return []
- def accept_visitor(self, visitor):
- visitor.visit_table(self)
-
def _exportable_columns(self):
raise NotImplementedError()
- def _group_parenthesized(self):
- return False
-
def count(self, whereclause=None, **params):
if len(self.primary_key):
col = list(self.primary_key)[0]
@@ -2746,68 +2790,120 @@ class TableClause(FromClause):
def delete(self, whereclause = None):
return delete(self, whereclause)
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return [self]
+
class _SelectBaseMixin(object):
"""Base class for ``Select`` and ``CompoundSelects``."""
+ def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None):
+ self.use_labels = use_labels
+ self.for_update = for_update
+ self._limit = limit
+ self._offset = offset
+ self._bind = bind
+
+ self.append_order_by(*util.to_list(order_by, []))
+ self.append_group_by(*util.to_list(group_by, []))
+
+ def as_scalar(self):
+ return _ScalarSelect(self)
+
+ def label(self, name):
+ return self.as_scalar().label(name)
+
def supports_execution(self):
return True
+ def _generate(self):
+ s = self._clone()
+ s._clone_from_clause()
+ return s
+
+ def limit(self, limit):
+ s = self._generate()
+ s._limit = limit
+ return s
+
+ def offset(self, offset):
+ s = self._generate()
+ s._offset = offset
+ return s
+
def order_by(self, *clauses):
- if len(clauses) == 1 and clauses[0] is None:
- self.order_by_clause = ClauseList()
- elif getattr(self, 'order_by_clause', None):
- self.order_by_clause = ClauseList(*(list(self.order_by_clause.clauses) + list(clauses)))
- else:
- self.order_by_clause = ClauseList(*clauses)
+ s = self._generate()
+ s.append_order_by(*clauses)
+ return s
def group_by(self, *clauses):
- if len(clauses) == 1 and clauses[0] is None:
- self.group_by_clause = ClauseList()
- elif getattr(self, 'group_by_clause', None):
- self.group_by_clause = ClauseList(*(list(clauses)+list(self.group_by_clause.clauses)))
+ s = self._generate()
+ s.append_group_by(*clauses)
+ return s
+
+ def append_order_by(self, *clauses):
+ if clauses == [None]:
+ self._order_by_clause = ClauseList()
else:
- self.group_by_clause = ClauseList(*clauses)
+ if getattr(self, '_order_by_clause', None):
+ clauses = list(self._order_by_clause) + list(clauses)
+ self._order_by_clause = ClauseList(*clauses)
+ def append_group_by(self, *clauses):
+ if clauses == [None]:
+ self._group_by_clause = ClauseList()
+ else:
+ if getattr(self, '_group_by_clause', None):
+ clauses = list(self._group_by_clause) + list(clauses)
+ self._group_by_clause = ClauseList(*clauses)
+
def select(self, whereclauses = None, **params):
return select([self], whereclauses, **params)
- def _get_from_objects(self):
- if self.is_where or self.is_scalar:
+ def _get_from_objects(self, is_where=False, **modifiers):
+ if is_where:
return []
else:
return [self]
+class _ScalarSelect(_Grouping):
+ __visit_name__ = 'grouping'
+
+ def __init__(self, elem):
+ super(_ScalarSelect, self).__init__(elem)
+ self.type = list(elem.inner_columns)[0].type
+
+ columns = property(lambda self:[self])
+
+ def self_group(self, **kwargs):
+ return self
+
+ def _make_proxy(self, selectable, name):
+ return list(self.inner_columns)[0]._make_proxy(selectable, name)
+
+ def _get_from_objects(self, **modifiers):
+ return []
+
class CompoundSelect(_SelectBaseMixin, FromClause):
def __init__(self, keyword, *selects, **kwargs):
- _SelectBaseMixin.__init__(self)
+ self._should_correlate = kwargs.pop('correlate', False)
self.keyword = keyword
- self.use_labels = kwargs.pop('use_labels', False)
- self.should_correlate = kwargs.pop('correlate', False)
- self.for_update = kwargs.pop('for_update', False)
- self.nowait = kwargs.pop('nowait', False)
- self.limit = kwargs.pop('limit', None)
- self.offset = kwargs.pop('offset', None)
- self.is_compound = True
- self.is_where = False
- self.is_scalar = False
- self.is_subquery = False
-
- # unions group from left to right, so don't group first select
- self.selects = [n and select.self_group(self) or select for n,select in enumerate(selects)]
+ self.selects = []
# some DBs do not like ORDER BY in the inner queries of a UNION, etc.
- for s in selects:
- s.order_by(None)
+ for n, s in enumerate(selects):
+ if len(s._order_by_clause):
+ s = s.order_by(None)
+ # unions group from left to right, so don't group first select
+ if n:
+ self.selects.append(s.self_group(self))
+ else:
+ self.selects.append(s)
- self.group_by(*kwargs.pop('group_by', [None]))
- self.order_by(*kwargs.pop('order_by', [None]))
- if len(kwargs):
- raise TypeError("invalid keyword argument(s) for CompoundSelect: %s" % repr(kwargs.keys()))
self._col_map = {}
+ _SelectBaseMixin.__init__(self, **kwargs)
+
name = property(lambda s:s.keyword + " statement")
def self_group(self, against=None):
@@ -2835,12 +2931,18 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
col.orig_set = colset
return col
+ def _copy_internals(self):
+ self._clone_from_clause()
+ self._col_map = {}
+ self.selects = [s._clone() for s in self.selects]
+ for attr in ('_order_by_clause', '_group_by_clause'):
+ if getattr(self, attr) is not None:
+ setattr(self, attr, getattr(self, attr)._clone())
+
def get_children(self, column_collections=True, **kwargs):
return (column_collections and list(self.c) or []) + \
- [self.order_by_clause, self.group_by_clause] + list(self.selects)
- def accept_visitor(self, visitor):
- visitor.visit_compound_select(self)
-
+ [self._order_by_clause, self._group_by_clause] + list(self.selects)
+
def _find_engine(self):
for s in self.selects:
e = s._find_engine()
@@ -2855,160 +2957,287 @@ class Select(_SelectBaseMixin, FromClause):
"""
- def __init__(self, columns=None, whereclause=None, from_obj=[],
- order_by=None, group_by=None, having=None,
- use_labels=False, distinct=False, for_update=False,
- engine=None, bind=None, limit=None, offset=None, scalar=False,
- correlate=True):
+ def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, prefixes=None, **kwargs):
"""construct a Select object.
The public constructor for Select is the [sqlalchemy.sql#select()] function;
see that function for argument descriptions.
"""
- _SelectBaseMixin.__init__(self)
- self.__froms = util.OrderedSet()
- self.__hide_froms = util.Set([self])
- self.use_labels = use_labels
- self.whereclause = None
- self.having = None
- self._bind = bind or engine
- self.limit = limit
- self.offset = offset
- self.for_update = for_update
- self.is_compound = False
-
- # indicates that this select statement should not expand its columns
- # into the column clause of an enclosing select, and should instead
- # act like a single scalar column
- self.is_scalar = scalar
- if scalar:
- # allow corresponding_column to return None
- self.orig_set = util.Set()
-
- # indicates if this select statement, as a subquery, should automatically correlate
- # its FROM clause to that of an enclosing select, update, or delete statement.
- # note that the "correlate" method can be used to explicitly add a value to be correlated.
- self.should_correlate = correlate
-
- # indicates if this select statement is a subquery inside another query
- self.is_subquery = False
-
- # indicates if this select statement is in the from clause of another query
- self.is_selected_from = False
-
- # indicates if this select statement is a subquery as a criterion
- # inside of a WHERE clause
- self.is_where = False
+
+ self._should_correlate = correlate
+ self._distinct = distinct
- self.distinct = distinct
self._raw_columns = []
- self.__correlated = {}
- self.__correlator = Select._CorrelatedVisitor(self, False)
- self.__wherecorrelator = Select._CorrelatedVisitor(self, True)
- self.__fromvisitor = Select._FromVisitor(self)
-
-
- self.order_by_clause = self.group_by_clause = None
+ self.__correlate = util.Set()
+ self._froms = util.OrderedSet()
+ self._whereclause = None
+ self._having = None
+ self._prefixes = []
if columns is not None:
for c in columns:
self.append_column(c)
- if order_by:
- order_by = util.to_list(order_by)
- if group_by:
- group_by = util.to_list(group_by)
- self.order_by(*(order_by or [None]))
- self.group_by(*(group_by or [None]))
- for c in self.order_by_clause:
- self.__correlator.traverse(c)
- for c in self.group_by_clause:
- self.__correlator.traverse(c)
-
- for f in from_obj:
- self.append_from(f)
-
- # whereclauses must be appended after the columns/FROM, since it affects
- # the correlation of subqueries. see test/sql/select.py SelectTest.testwheresubquery
+ if from_obj is not None:
+ for f in from_obj:
+ self.append_from(f)
+
if whereclause is not None:
self.append_whereclause(whereclause)
+
if having is not None:
self.append_having(having)
+ _SelectBaseMixin.__init__(self, **kwargs)
- class _CorrelatedVisitor(NoColumnVisitor):
- """Visit a clause, locate any ``Select`` clauses, and tell
- them that they should correlate their ``FROM`` list to that of
- their parent.
+ def _get_display_froms(self, correlation_state=None):
+ """return the full list of 'from' clauses to be displayed.
+
+ takes into account an optional 'correlation_state'
+ dictionary which contains information about this Select's
+ correlation to an enclosing select, which may cause some 'from'
+ clauses to not display in this Select's FROM clause.
+ this dictionary is generated during compile time by the
+ _calculate_correlations() method.
+
"""
+ froms = util.OrderedSet()
+ hide_froms = util.Set()
+
+ for col in self._raw_columns:
+ for f in col._hide_froms():
+ hide_froms.add(f)
+ for f in col._get_from_objects():
+ froms.add(f)
+
+ if self._whereclause is not None:
+ for f in self._whereclause._get_from_objects(is_where=True):
+ froms.add(f)
+
+ for elem in self._froms:
+ froms.add(elem)
+ for f in elem._get_from_objects():
+ froms.add(f)
+
+ for elem in froms:
+ for f in elem._hide_froms():
+ hide_froms.add(f)
+
+ froms = froms.difference(hide_froms)
+
+ if len(froms) > 1:
+ corr = self.__correlate
+ if correlation_state is not None:
+ corr = correlation_state[self].get('correlate', util.Set()).union(corr)
+ f = froms.difference(corr)
+ if len(f) == 0:
+ raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
+ return f
+ else:
+ return froms
+
+ froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""")
+
+ def locate_all_froms(self):
+ froms = util.Set()
+ for col in self._raw_columns:
+ for f in col._get_from_objects():
+ froms.add(f)
+
+ if self._whereclause is not None:
+ for f in self._whereclause._get_from_objects(is_where=True):
+ froms.add(f)
+
+ for elem in self._froms:
+ froms.add(elem)
+ for f in elem._get_from_objects():
+ froms.add(f)
+ return froms
+
+ def _calculate_correlations(self, correlation_state):
+ """generate a 'correlation_state' dictionary used by the _get_display_froms() method.
+
+ The dictionary is passed in initially empty, or already
+ containing the state information added by an enclosing
+ Select construct. The method will traverse through all
+ embedded Select statements and add information about their
+ position and "from" objects to the dictionary. Those Select
+ statements will later consult the 'correlation_state' dictionary
+ when their list of 'FROM' clauses are generated using their
+ _get_display_froms() method.
+ """
+
+ if self not in correlation_state:
+ correlation_state[self] = {}
- def __init__(self, select, is_where):
- NoColumnVisitor.__init__(self)
- self.select = select
- self.is_where = is_where
-
- def visit_compound_select(self, cs):
- self.visit_select(cs)
-
- def visit_column(self, c):
- pass
-
- def visit_table(self, c):
- pass
-
- def visit_select(self, select):
- if select is self.select:
- return
- select.is_where = self.is_where
- select.is_subquery = True
- if not select.should_correlate:
- return
- [select.correlate(x) for x in self.select._Select__froms]
+ display_froms = self._get_display_froms(correlation_state)
+
+ class CorrelatedVisitor(NoColumnVisitor):
+ def __init__(self, is_where=False, is_column=False, is_from=False):
+ self.is_where = is_where
+ self.is_column = is_column
+ self.is_from = is_from
+
+ def visit_compound_select(self, cs):
+ self.visit_select(cs)
- class _FromVisitor(NoColumnVisitor):
- def __init__(self, select):
- NoColumnVisitor.__init__(self)
- self.select = select
+ def visit_select(s, select):
+ if select not in correlation_state:
+ correlation_state[select] = {}
+
+ if select is self:
+ return
+
+ select_state = correlation_state[select]
+ if s.is_from:
+ select_state['is_selected_from'] = True
+ if s.is_where:
+ select_state['is_where'] = True
+ select_state['is_subquery'] = True
+
+ if select._should_correlate:
+ corr = select_state.setdefault('correlate', util.Set())
+ # not crazy about this part. need to be clearer on what elements in the
+ # subquery correspond to elements in the enclosing query.
+ for f in display_froms:
+ corr.add(f)
+ for f2 in f._get_from_objects():
+ corr.add(f2)
+
+ col_vis = CorrelatedVisitor(is_column=True)
+ where_vis = CorrelatedVisitor(is_where=True)
+ from_vis = CorrelatedVisitor(is_from=True)
+
+ for col in self._raw_columns:
+ col_vis.traverse(col)
+ for f in col._get_from_objects():
+ if f is not self:
+ from_vis.traverse(f)
+
+ for col in list(self._order_by_clause) + list(self._group_by_clause):
+ col_vis.traverse(col)
- def visit_select(self, select):
- if select is self.select:
- return
- select.is_selected_from = True
- select.is_subquery = True
+ if self._whereclause is not None:
+ where_vis.traverse(self._whereclause)
+ for f in self._whereclause._get_from_objects(is_where=True):
+ if f is not self:
+ from_vis.traverse(f)
+
+ for elem in self._froms:
+ from_vis.traverse(elem)
+
+ def _get_inner_columns(self):
+ for c in self._raw_columns:
+ if isinstance(c, Selectable):
+ for co in c.columns:
+ yield co
+ else:
+ yield c
+
+ inner_columns = property(_get_inner_columns)
+
+ def _copy_internals(self):
+ self._clone_from_clause()
+ self._raw_columns = [c._clone() for c in self._raw_columns]
+ self._recorrelate_froms([(f, f._clone()) for f in self._froms])
+ for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
+ if getattr(self, attr) is not None:
+ setattr(self, attr, getattr(self, attr)._clone())
+ def get_children(self, column_collections=True, **kwargs):
+ return (column_collections and list(self.columns) or []) + \
+ list(self._froms) + \
+ [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
+
+ def _recorrelate_froms(self, froms):
+ newcorrelate = util.Set()
+ newfroms = util.Set()
+ oldfroms = util.Set(self._froms)
+ for old, new in froms:
+ if old in self.__correlate:
+ newcorrelate.add(new)
+ self.__correlate.remove(old)
+ if old in oldfroms:
+ newfroms.add(new)
+ oldfroms.remove(old)
+ self.__correlate = self.__correlate.union(newcorrelate)
+ self._froms = [f for f in oldfroms.union(newfroms)]
+
+ def column(self, column):
+ s = self._generate()
+ s.append_column(column)
+ return s
+
+ def where(self, whereclause):
+ s = self._generate()
+ s.append_whereclause(whereclause)
+ return s
+
+ def having(self, having):
+ s = self._generate()
+ s.append_having(having)
+ return s
+
+ def distinct(self):
+ s = self._generate()
+ s.distinct = True
+ return s
+
+ def prefix_with(self, clause):
+ s = self._generate()
+ s.append_prefix(clause)
+ return s
+
+ def select_from(self, fromclause):
+ s = self._generate()
+ s.append_from(fromclause)
+ return s
+
+ def __dont_correlate(self):
+ s = self._generate()
+ s._should_correlate = False
+ return s
+
+ def correlate(self, fromclause):
+ s = self._generate()
+ s._should_correlate=False
+ if fromclause is None:
+ s.__correlate = util.Set()
+ else:
+ s.append_correlation(fromclause)
+ return s
+
+ def append_correlation(self, fromclause):
+ self.__correlate.add(fromclause)
+
def append_column(self, column):
- if _is_literal(column):
- column = literal_column(str(column))
+ column = _literal_as_column(column)
- if isinstance(column, Select) and column.is_scalar:
- column = column.self_group(against=',')
+ if isinstance(column, _ScalarSelect):
+ column = column.self_group(against=ColumnOperators.comma_op)
self._raw_columns.append(column)
-
- if self.is_scalar and not hasattr(self, 'type'):
- self.type = column.type
+
+ def append_prefix(self, clause):
+ clause = _literal_as_text(clause)
+ self._prefixes.append(clause)
- # if the column is a Select statement itself,
- # accept visitor
- self.__correlator.traverse(column)
-
- # visit the FROM objects of the column looking for more Selects
- for f in column._get_from_objects():
- if f is not self:
- self.__correlator.traverse(f)
- self._process_froms(column, False)
-
- def _make_proxy(self, selectable, name):
- if self.is_scalar:
- return self._raw_columns[0]._make_proxy(selectable, name)
+ def append_whereclause(self, whereclause):
+ if self._whereclause is not None:
+ self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
else:
- raise exceptions.InvalidRequestError("Not a scalar select statement")
-
- def label(self, name):
- if not self.is_scalar:
- raise exceptions.InvalidRequestError("Not a scalar select statement")
+ self._whereclause = _literal_as_text(whereclause)
+
+ def append_having(self, having):
+ if self._having is not None:
+ self._having = and_(self._having, _literal_as_text(having))
else:
- return label(name, self)
+ self._having = _literal_as_text(having)
+
+ def append_from(self, fromclause):
+ if _is_literal(fromclause):
+ fromclause = FromClause(fromclause)
+ self._froms.add(fromclause)
def _exportable_columns(self):
return [c for c in self._raw_columns if isinstance(c, Selectable)]
@@ -3019,53 +3248,13 @@ class Select(_SelectBaseMixin, FromClause):
else:
return column._make_proxy(self)
- def _process_froms(self, elem, asfrom):
- for f in elem._get_from_objects():
- self.__fromvisitor.traverse(f)
- self.__froms.add(f)
- if asfrom:
- self.__froms.add(elem)
- for f in elem._hide_froms():
- self.__hide_froms.add(f)
-
def self_group(self, against=None):
if isinstance(against, CompoundSelect):
return self
return _Grouping(self)
-
- def append_whereclause(self, whereclause):
- self._append_condition('whereclause', whereclause)
-
- def append_having(self, having):
- self._append_condition('having', having)
-
- def _append_condition(self, attribute, condition):
- if isinstance(condition, basestring):
- condition = _TextClause(condition)
- self.__wherecorrelator.traverse(condition)
- self._process_froms(condition, False)
- if getattr(self, attribute) is not None:
- setattr(self, attribute, and_(getattr(self, attribute), condition))
- else:
- setattr(self, attribute, condition)
-
- def correlate(self, from_obj):
- """Given a ``FROM`` object, correlate this ``SELECT`` statement to it.
-
- This basically means the given from object will not come out
- in this select statement's ``FROM`` clause when printed.
- """
-
- self.__correlated[from_obj] = from_obj
-
- def append_from(self, fromclause):
- if isinstance(fromclause, basestring):
- fromclause = FromClause(fromclause)
- self.__correlator.traverse(fromclause)
- self._process_froms(fromclause, True)
def _locate_oid_column(self):
- for f in self.__froms:
+ for f in self.locate_all_froms():
if f is self:
# we might be in our own _froms list if a column with us as the parent is attached,
# which includes textual columns.
@@ -3076,25 +3265,6 @@ class Select(_SelectBaseMixin, FromClause):
else:
return None
- def _calc_froms(self):
- f = self.__froms.difference(self.__hide_froms)
- if (len(f) > 1):
- return f.difference(self.__correlated)
- else:
- return f
-
- froms = property(_calc_froms,
- doc="""A collection containing all elements
- of the ``FROM`` clause.""")
-
- def get_children(self, column_collections=True, **kwargs):
- return (column_collections and list(self.columns) or []) + \
- list(self.froms) + \
- [x for x in (self.whereclause, self.having, self.order_by_clause, self.group_by_clause) if x is not None]
-
- def accept_visitor(self, visitor):
- visitor.visit_select(self)
-
def union(self, other, **kwargs):
return union(self, other, **kwargs)
@@ -3108,7 +3278,7 @@ class Select(_SelectBaseMixin, FromClause):
if self._bind is not None:
return self._bind
- for f in self.__froms:
+ for f in self._froms:
if f is self:
continue
e = f.bind
@@ -3133,20 +3303,24 @@ class _UpdateBase(ClauseElement):
def supports_execution(self):
return True
- class _SelectCorrelator(NoColumnVisitor):
- def __init__(self, table):
- NoColumnVisitor.__init__(self)
- self.table = table
-
- def visit_select(self, select):
- if select.should_correlate:
- select.correlate(self.table)
-
- def _process_whereclause(self, whereclause):
- if whereclause is not None:
- _UpdateBase._SelectCorrelator(self.table).traverse(whereclause)
- return whereclause
+ def _calculate_correlations(self, correlate_state):
+ class SelectCorrelator(NoColumnVisitor):
+ def visit_select(s, select):
+ if select._should_correlate:
+ select_state = correlate_state.setdefault(select, {})
+ corr = select_state.setdefault('correlate', util.Set())
+ corr.add(self.table)
+
+ vis = SelectCorrelator()
+ if self._whereclause is not None:
+ vis.traverse(self._whereclause)
+
+ if getattr(self, 'parameters', None) is not None:
+ for key, value in self.parameters.items():
+ if isinstance(value, ClauseElement):
+ vis.traverse(value)
+
def _process_colparams(self, parameters):
"""Receive the *values* of an ``INSERT`` or ``UPDATE``
statement and construct appropriate bind parameters.
@@ -3155,7 +3329,7 @@ class _UpdateBase(ClauseElement):
if parameters is None:
return None
- if isinstance(parameters, list) or isinstance(parameters, tuple):
+ if isinstance(parameters, (list, tuple)):
pp = {}
i = 0
for c in self.table.c:
@@ -3163,11 +3337,10 @@ class _UpdateBase(ClauseElement):
i +=1
parameters = pp
- correlator = _UpdateBase._SelectCorrelator(self.table)
for key in parameters.keys():
value = parameters[key]
if isinstance(value, ClauseElement):
- correlator.traverse(value)
+ parameters[key] = value.self_group()
elif _is_literal(value):
if _is_literal(key):
col = self.table.c[key]
@@ -3182,7 +3355,7 @@ class _UpdateBase(ClauseElement):
def _find_engine(self):
return self.table.bind
-class _Insert(_UpdateBase):
+class Insert(_UpdateBase):
def __init__(self, table, values=None):
self.table = table
self.select = None
@@ -3193,32 +3366,41 @@ class _Insert(_UpdateBase):
return self.select,
else:
return ()
- def accept_visitor(self, visitor):
- visitor.visit_insert(self)
-class _Update(_UpdateBase):
+class Update(_UpdateBase):
def __init__(self, table, whereclause, values=None):
self.table = table
- self.whereclause = self._process_whereclause(whereclause)
+ self._whereclause = whereclause
self.parameters = self._process_colparams(values)
def get_children(self, **kwargs):
- if self.whereclause is not None:
- return self.whereclause,
+ if self._whereclause is not None:
+ return self._whereclause,
else:
return ()
- def accept_visitor(self, visitor):
- visitor.visit_update(self)
-class _Delete(_UpdateBase):
+class Delete(_UpdateBase):
def __init__(self, table, whereclause):
self.table = table
- self.whereclause = self._process_whereclause(whereclause)
+ self._whereclause = whereclause
def get_children(self, **kwargs):
- if self.whereclause is not None:
- return self.whereclause,
+ if self._whereclause is not None:
+ return self._whereclause,
else:
return ()
- def accept_visitor(self, visitor):
- visitor.visit_delete(self)
+
+class _IdentifiedClause(ClauseElement):
+ def __init__(self, ident):
+ self.ident = ident
+ def supports_execution(self):
+ return True
+
+class SavepointClause(_IdentifiedClause):
+ pass
+
+class RollbackToSavepointClause(_IdentifiedClause):
+ pass
+
+class ReleaseSavepointClause(_IdentifiedClause):
+ pass