summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/expression.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r--lib/sqlalchemy/sql/expression.py490
1 files changed, 219 insertions, 271 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 867fdd69c..7ce637701 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -26,12 +26,12 @@ to stay the same in future releases.
"""
import itertools, re
-from sqlalchemy import util, exceptions
+from sqlalchemy import util, exc
from sqlalchemy.sql import operators, visitors
from sqlalchemy import types as sqltypes
functions, schema, sql_util = None, None, None
-DefaultDialect, ClauseAdapter = None, None
+DefaultDialect, ClauseAdapter, Annotated = None, None, None
__all__ = [
'Alias', 'ClauseElement',
@@ -503,15 +503,21 @@ def collate(expression, collation):
def exists(*args, **kwargs):
"""Return an ``EXISTS`` clause as applied to a [sqlalchemy.sql.expression#Select] object.
+
+ Calling styles are of the following forms::
+
+ # use on an existing select()
+ s = select([<columns>]).where(<criterion>)
+ s = exists(s)
+
+ # construct a select() at once
+ exists(['*'], **select_arguments).where(<criterion>)
+
+ # columns argument is optional, generates "EXISTS (SELECT *)"
+ # by default.
+ exists().where(<criterion>)
- The resulting [sqlalchemy.sql.expression#_Exists] object can be executed by
- itself or used as a subquery within an enclosing select.
-
- \*args, \**kwargs
- all arguments are sent directly to the [sqlalchemy.sql.expression#select()]
- function to produce a ``SELECT`` statement.
"""
-
return _Exists(*args, **kwargs)
def union(*selects, **kwargs):
@@ -872,27 +878,36 @@ def _compound_select(keyword, *selects, **kwargs):
return CompoundSelect(keyword, *selects, **kwargs)
def _is_literal(element):
- return not isinstance(element, ClauseElement)
+ return not isinstance(element, (ClauseElement, Operators))
+
+def _from_objects(*elements, **kwargs):
+ return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements])
+def _labeled(element):
+ if not hasattr(element, 'name'):
+ return element.label(None)
+ else:
+ return element
+
def _literal_as_text(element):
- if isinstance(element, Operators):
- return element.expression_element()
+ if hasattr(element, '__clause_element__'):
+ return element.__clause_element__()
elif _is_literal(element):
return _TextClause(unicode(element))
else:
return element
def _literal_as_column(element):
- if isinstance(element, Operators):
- return element.clause_element()
+ if hasattr(element, '__clause_element__'):
+ return element.__clause_element__()
elif _is_literal(element):
return literal_column(str(element))
else:
return element
def _literal_as_binds(element, name=None, type_=None):
- if isinstance(element, Operators):
- return element.expression_element()
+ if hasattr(element, '__clause_element__'):
+ return element.__clause_element__()
elif _is_literal(element):
if element is None:
return null()
@@ -902,17 +917,17 @@ def _literal_as_binds(element, name=None, type_=None):
return element
def _no_literals(element):
- if isinstance(element, Operators):
- return element.expression_element()
+ if hasattr(element, '__clause_element__'):
+ return element.__clause_element__()
elif _is_literal(element):
- raise exceptions.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
+ raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
else:
return element
def _corresponding_column_or_error(fromclause, column, require_embedded=False):
c = fromclause.corresponding_column(column, require_embedded=require_embedded)
if not c:
- raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description))
+ raise exc.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description))
return c
def _selectable(element):
@@ -921,9 +936,8 @@ def _selectable(element):
elif isinstance(element, Selectable):
return element
else:
- raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
+ raise exc.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
-
def is_column(col):
"""True if ``col`` is an instance of ``ColumnElement``."""
return isinstance(col, ColumnElement)
@@ -941,7 +955,9 @@ class _FigureVisitName(type):
class ClauseElement(object):
"""Base class for elements of a programmatically constructed SQL expression."""
__metaclass__ = _FigureVisitName
-
+ _annotations = {}
+ supports_execution = False
+
def _clone(self):
"""Create a shallow copy of this ClauseElement.
@@ -976,6 +992,14 @@ class ClauseElement(object):
"""
raise NotImplementedError(repr(self))
+
+ def _annotate(self, values):
+ """return a copy of this ClauseElement with the given annotations dictionary."""
+
+ global Annotated
+ if Annotated is None:
+ from sqlalchemy.sql.util import Annotated
+ return Annotated(self, values)
def unique_params(self, *optionaldict, **kwargs):
"""Return a copy with ``bindparam()`` elments replaced.
@@ -1006,14 +1030,14 @@ class ClauseElement(object):
if len(optionaldict) == 1:
kwargs.update(optionaldict[0])
elif len(optionaldict) > 1:
- raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument")
+ raise exc.ArgumentError("params() takes zero or one positional dictionary argument")
def visit_bindparam(bind):
if bind.key in kwargs:
bind.value = kwargs[bind.key]
if unique:
bind._convert_to_unique()
- return visitors.traverse(self, visit_bindparam=visit_bindparam, clone=True)
+ return visitors.cloned_traverse(self, {}, {'bindparam':visit_bindparam})
def compare(self, other):
"""Compare this ClauseElement to the given ClauseElement.
@@ -1049,11 +1073,6 @@ class ClauseElement(object):
def self_group(self, against=None):
return self
- def supports_execution(self):
- """Return True if this clause element represents a complete executable statement."""
-
- return False
-
def bind(self):
"""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found."""
@@ -1062,7 +1081,7 @@ class ClauseElement(object):
return self._bind
except AttributeError:
pass
- for f in self._get_from_objects():
+ for f in _from_objects(self):
if f is self:
continue
engine = f.bind
@@ -1083,7 +1102,7 @@ class ClauseElement(object):
'Engine for execution. Or, assign a bind to the statement '
'or the Metadata of its underlying tables to enable '
'implicit execution via this method.' % label)
- raise exceptions.UnboundExecutionError(msg)
+ raise exc.UnboundExecutionError(msg)
return e.execute_clauseelement(self, multiparams, params)
def scalar(self, *multiparams, **params):
@@ -1159,6 +1178,12 @@ class ClauseElement(object):
self.__module__, self.__class__.__name__, id(self), friendly)
+class _Immutable(object):
+ """mark a ClauseElement as 'immutable' when expressions are cloned."""
+
+ def _clone(self):
+ return self
+
class Operators(object):
def __and__(self, other):
return self.operate(operators.and_, other)
@@ -1174,9 +1199,6 @@ class Operators(object):
return self.operate(operators.op, opstring, b)
return op
- def clause_element(self):
- raise NotImplementedError()
-
def operate(self, op, *other, **kwargs):
raise NotImplementedError()
@@ -1216,7 +1238,7 @@ class ColumnOperators(Operators):
def ilike(self, other, escape=None):
return self.operate(operators.ilike_op, other, escape=escape)
- def in_(self, *other):
+ def in_(self, other):
return self.operate(operators.in_op, other)
def startswith(self, other, **kwargs):
@@ -1279,18 +1301,18 @@ class _CompareMixin(ColumnOperators):
def __compare(self, op, obj, negate=None, reverse=False, **kwargs):
if obj is None or isinstance(obj, _Null):
if op == operators.eq:
- return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot)
+ return _BinaryExpression(self, null(), operators.is_, negate=operators.isnot)
elif op == operators.ne:
- return _BinaryExpression(self.expression_element(), null(), operators.isnot, negate=operators.is_)
+ return _BinaryExpression(self, null(), operators.isnot, negate=operators.is_)
else:
- raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
+ raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL")
else:
obj = self._check_literal(obj)
if reverse:
- return _BinaryExpression(obj, self.expression_element(), op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
+ return _BinaryExpression(obj, self, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
else:
- return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
+ return _BinaryExpression(self, obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
def __operate(self, op, obj, reverse=False):
obj = self._check_literal(obj)
@@ -1298,9 +1320,9 @@ class _CompareMixin(ColumnOperators):
type_ = self._compare_type(obj)
if reverse:
- return _BinaryExpression(obj, self.expression_element(), type_.adapt_operator(op), type_=type_)
+ return _BinaryExpression(obj, self, type_.adapt_operator(op), type_=type_)
else:
- return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_)
+ return _BinaryExpression(self, obj, type_.adapt_operator(op), type_=type_)
# a mapping of operators with the method they use, along with their negated
# operator for comparison operators
@@ -1329,17 +1351,10 @@ class _CompareMixin(ColumnOperators):
o = _CompareMixin.operators[op]
return o[0](self, op, other, reverse=True, *o[1:], **kwargs)
- def in_(self, *other):
- return self._in_impl(operators.in_op, operators.notin_op, *other)
-
- def _in_impl(self, op, negate_op, *other):
- # Handle old style *args argument passing
- if len(other) != 1 or not isinstance(other[0], Selectable) and (not hasattr(other[0], '__iter__') or isinstance(other[0], basestring)):
- util.warn_deprecated('passing in_ arguments as varargs is deprecated, in_ takes a single argument that is a sequence or a selectable')
- seq_or_selectable = other
- else:
- seq_or_selectable = other[0]
+ def in_(self, other):
+ return self._in_impl(operators.in_op, operators.notin_op, other)
+ def _in_impl(self, op, negate_op, seq_or_selectable):
if isinstance(seq_or_selectable, Selectable):
return self.__compare( op, seq_or_selectable, negate=negate_op)
@@ -1348,7 +1363,7 @@ class _CompareMixin(ColumnOperators):
for o in seq_or_selectable:
if not _is_literal(o):
if not isinstance( o, _CompareMixin):
- raise exceptions.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) )
+ raise exc.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) )
else:
o = self._bind_param(o)
args.append(o)
@@ -1433,22 +1448,13 @@ class _CompareMixin(ColumnOperators):
if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType):
other.type = self.type
return other
- elif isinstance(other, Operators):
- return other.expression_element()
+ elif hasattr(other, '__clause_element__'):
+ return other.__clause_element__()
elif _is_literal(other):
return self._bind_param(other)
else:
return other
- def clause_element(self):
- """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``."""
- return self
-
- def expression_element(self):
- """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions."""
-
- return self
-
def _compare_type(self, obj):
"""Allow subclasses to override the type used in constructing
``_BinaryExpression`` objects.
@@ -1480,23 +1486,22 @@ class ColumnElement(ClauseElement, _CompareMixin):
primary_key = False
foreign_keys = []
-
+ quote = None
+
def base_columns(self):
- if hasattr(self, '_base_columns'):
- return self._base_columns
- self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')])
+ if not hasattr(self, '_base_columns'):
+ self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')])
return self._base_columns
base_columns = property(base_columns)
def proxy_set(self):
- if hasattr(self, '_proxy_set'):
- return self._proxy_set
- s = util.Set([self])
- if hasattr(self, 'proxies'):
- for c in self.proxies:
- s = s.union(c.proxy_set)
- self._proxy_set = s
- return s
+ if not hasattr(self, '_proxy_set'):
+ s = util.Set([self])
+ if hasattr(self, 'proxies'):
+ for c in self.proxies:
+ s.update(c.proxy_set)
+ self._proxy_set = s
+ return self._proxy_set
proxy_set = property(proxy_set)
def shares_lineage(self, othercolumn):
@@ -1518,7 +1523,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
co = _ColumnClause(self.anon_label, selectable, type_=getattr(self, 'type', None))
co.proxies = [self]
- selectable.columns[name]= co
+ selectable.columns[name] = co
return co
def anon_label(self):
@@ -1613,7 +1618,7 @@ class ColumnCollection(util.OrderedProperties):
def __contains__(self, other):
if not isinstance(other, basestring):
- raise exceptions.ArgumentError("__contains__ requires a string argument")
+ raise exc.ArgumentError("__contains__ requires a string argument")
return util.OrderedProperties.__contains__(self, other)
def contains_column(self, col):
@@ -1641,6 +1646,9 @@ class ColumnSet(util.OrderedSet):
l.append(c==local)
return and_(*l)
+ def __hash__(self):
+ return hash(tuple(self._list))
+
class Selectable(ClauseElement):
"""mark a class as being selectable"""
@@ -1648,8 +1656,9 @@ class FromClause(Selectable):
"""Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement."""
__visit_name__ = 'fromclause'
- named_with_column=False
+ named_with_column = False
_hide_froms = []
+ quote = None
def _get_from_objects(self, **modifiers):
return []
@@ -1694,12 +1703,12 @@ class FromClause(Selectable):
return fromclause in util.Set(self._cloned_set)
def replace_selectable(self, old, alias):
- """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
+ """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
- global ClauseAdapter
- if ClauseAdapter is None:
- from sqlalchemy.sql.util import ClauseAdapter
- return ClauseAdapter(alias).traverse(self, clone=True)
+ global ClauseAdapter
+ if ClauseAdapter is None:
+ from sqlalchemy.sql.util import ClauseAdapter
+ return ClauseAdapter(alias).traverse(self)
def correspond_on_equivalents(self, column, equivalents):
col = self.corresponding_column(column, require_embedded=True)
@@ -1859,7 +1868,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
def _convert_to_unique(self):
if not self.unique:
- self.unique=True
+ self.unique = True
self.key = "{ANON %d %s}" % (id(self), self._orig_key or 'param')
def _get_from_objects(self, **modifiers):
@@ -1910,6 +1919,7 @@ class _TextClause(ClauseElement):
__visit_name__ = 'textclause'
_bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
+ supports_execution = True
_hide_froms = []
oid_column = None
@@ -1950,12 +1960,6 @@ class _TextClause(ClauseElement):
def _get_from_objects(self, **modifiers):
return []
- def supports_execution(self):
- return True
-
- def _table_iterator(self):
- return iter([])
-
class _Null(ColumnElement):
"""Represent the NULL keyword in a SQL statement.
@@ -2042,6 +2046,7 @@ class _CalculatedClause(ColumnElement):
__visit_name__ = 'calculatedclause'
def __init__(self, name, *clauses, **kwargs):
+ ColumnElement.__init__(self)
self.name = name
self.type = sqltypes.to_instance(kwargs.get('type_', None))
self._bind = kwargs.get('bind', None)
@@ -2061,7 +2066,7 @@ class _CalculatedClause(ColumnElement):
def clauses(self):
if isinstance(self.clause_expr, _Grouping):
- return self.clause_expr.elem
+ return self.clause_expr.element
else:
return self.clause_expr
clauses = property(clauses)
@@ -2239,8 +2244,13 @@ class _Exists(_UnaryExpression):
__visit_name__ = _UnaryExpression.__visit_name__
def __init__(self, *args, **kwargs):
- kwargs['correlate'] = True
- s = select(*args, **kwargs).as_scalar().self_group()
+ if args and isinstance(args[0], _SelectBaseMixin):
+ s = args[0]
+ else:
+ if not args:
+ args = ([literal_column('*')],)
+ s = select(*args, **kwargs).as_scalar().self_group()
+
_UnaryExpression.__init__(self, s, operator=operators.exists)
def select(self, whereclause=None, **params):
@@ -2272,7 +2282,7 @@ class Join(FromClause):
self.right = _selectable(right).self_group()
if onclause is None:
- self.onclause = self.__match_primaries(self.left, self.right)
+ self.onclause = self._match_primaries(self.left, self.right)
else:
self.onclause = onclause
@@ -2310,7 +2320,7 @@ class Join(FromClause):
def get_children(self, **kwargs):
return self.left, self.right, self.onclause
- def __match_primaries(self, primary, secondary):
+ def _match_primaries(self, primary, secondary):
global sql_util
if not sql_util:
from sqlalchemy.sql import util as sql_util
@@ -2359,7 +2369,7 @@ class Join(FromClause):
return self.select(use_labels=True, correlate=False).alias(name)
def _hide_froms(self):
- return itertools.chain(*[x.left._get_from_objects() + x.right._get_from_objects() for x in self._cloned_set])
+ return itertools.chain(*[_from_objects(x.left, x.right) for x in self._cloned_set])
_hide_froms = property(_hide_froms)
def _get_from_objects(self, **modifiers):
@@ -2382,9 +2392,10 @@ class Alias(FromClause):
def __init__(self, selectable, alias=None):
baseselectable = selectable
while isinstance(baseselectable, Alias):
- baseselectable = baseselectable.selectable
+ baseselectable = baseselectable.element
self.original = baseselectable
- self.selectable = selectable
+ self.supports_execution = baseselectable.supports_execution
+ self.element = selectable
if alias is None:
if self.original.named_with_column:
alias = getattr(self.original, 'name', None)
@@ -2398,112 +2409,100 @@ class Alias(FromClause):
def is_derived_from(self, fromclause):
if fromclause in util.Set(self._cloned_set):
return True
- return self.selectable.is_derived_from(fromclause)
-
- def supports_execution(self):
- return self.original.supports_execution()
-
- def _table_iterator(self):
- return self.original._table_iterator()
+ return self.element.is_derived_from(fromclause)
def _populate_column_collection(self):
- for col in self.selectable.columns:
+ for col in self.element.columns:
col._make_proxy(self)
- if self.selectable.oid_column is not None:
- self._oid_column = self.selectable.oid_column._make_proxy(self)
+ if self.element.oid_column is not None:
+ self._oid_column = self.element.oid_column._make_proxy(self)
def _copy_internals(self, clone=_clone):
- self._reset_exported()
- self.selectable = _clone(self.selectable)
- baseselectable = self.selectable
- while isinstance(baseselectable, Alias):
- baseselectable = baseselectable.selectable
- self.original = baseselectable
+ self._reset_exported()
+ self.element = _clone(self.element)
+ baseselectable = self.element
+ while isinstance(baseselectable, Alias):
+ baseselectable = baseselectable.selectable
+ self.original = baseselectable
def get_children(self, column_collections=True, aliased_selectables=True, **kwargs):
if column_collections:
for c in self.c:
yield c
if aliased_selectables:
- yield self.selectable
+ yield self.element
def _get_from_objects(self, **modifiers):
return [self]
def bind(self):
- return self.selectable.bind
+ return self.element.bind
bind = property(bind)
-class _ColumnElementAdapter(ColumnElement):
- """Adapts a ClauseElement which may or may not be a
- ColumnElement subclass itself into an object which
- acts like a ColumnElement.
- """
+class _Grouping(ColumnElement):
+ """Represent a grouping within a column expression"""
- def __init__(self, elem):
- self.elem = elem
- self.type = getattr(elem, 'type', None)
+ def __init__(self, element):
+ ColumnElement.__init__(self)
+ self.element = element
+ self.type = getattr(element, 'type', None)
def key(self):
- return self.elem.key
+ return self.element.key
key = property(key)
def _label(self):
try:
- return self.elem._label
+ return self.element._label
except AttributeError:
return self.anon_label
_label = property(_label)
def _copy_internals(self, clone=_clone):
- self.elem = clone(self.elem)
+ self.element = clone(self.element)
def get_children(self, **kwargs):
- return self.elem,
+ return self.element,
def _get_from_objects(self, **modifiers):
- return self.elem._get_from_objects(**modifiers)
+ return self.element._get_from_objects(**modifiers)
def __getattr__(self, attr):
- return getattr(self.elem, attr)
+ return getattr(self.element, attr)
def __getstate__(self):
- return {'elem':self.elem, 'type':self.type}
+ return {'element':self.element, 'type':self.type}
def __setstate__(self, state):
- self.elem = state['elem']
+ self.element = state['element']
self.type = state['type']
-class _Grouping(_ColumnElementAdapter):
- """Represent a grouping within a column expression"""
- pass
-
class _FromGrouping(FromClause):
"""Represent a grouping of a FROM clause"""
__visit_name__ = 'grouping'
- def __init__(self, elem):
- self.elem = elem
+ def __init__(self, element):
+ self.element = element
def columns(self):
- return self.elem.columns
+ return self.element.columns
columns = c = property(columns)
def _hide_froms(self):
- return self.elem._hide_froms
+ return self.element._hide_froms
_hide_froms = property(_hide_froms)
def get_children(self, **kwargs):
- return self.elem,
+ return self.element,
def _copy_internals(self, clone=_clone):
- self.elem = clone(self.elem)
+ self.element = clone(self.element)
def _get_from_objects(self, **modifiers):
- return self.elem._get_from_objects(**modifiers)
+ return self.element._get_from_objects(**modifiers)
def __getattr__(self, attr):
- return getattr(self.elem, attr)
+ return getattr(self.element, attr)
class _Label(ColumnElement):
"""Represents a column label (AS).
@@ -2516,12 +2515,12 @@ class _Label(ColumnElement):
``ColumnElement`` subclasses.
"""
- def __init__(self, name, obj, type_=None):
- while isinstance(obj, _Label):
- obj = obj.obj
- self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon'))
- self.obj = obj.self_group(against=operators.as_)
- self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None))
+ def __init__(self, name, element, type_=None):
+ while isinstance(element, _Label):
+ element = element.element
+ self.name = name or "{ANON %d %s}" % (id(self), getattr(element, 'name', 'anon'))
+ self.element = element.self_group(against=operators.as_)
+ self.type = sqltypes.to_instance(type_ or getattr(element, 'type', None))
def key(self):
return self.name
@@ -2532,8 +2531,9 @@ class _Label(ColumnElement):
_label = property(_label)
def _proxy_attr(name):
+ get = util.attrgetter(name)
def attr(self):
- return getattr(self.obj, name)
+ return get(self.element)
return property(attr)
proxies = _proxy_attr('proxies')
@@ -2542,27 +2542,24 @@ class _Label(ColumnElement):
primary_key = _proxy_attr('primary_key')
foreign_keys = _proxy_attr('foreign_keys')
- def expression_element(self):
- return self.obj
-
def get_children(self, **kwargs):
- return self.obj,
+ return self.element,
def _copy_internals(self, clone=_clone):
- self.obj = clone(self.obj)
+ self.element = clone(self.element)
def _get_from_objects(self, **modifiers):
- return self.obj._get_from_objects(**modifiers)
+ return self.element._get_from_objects(**modifiers)
def _make_proxy(self, selectable, name = None):
- if isinstance(self.obj, (Selectable, ColumnElement)):
- e = self.obj._make_proxy(selectable, name=self.name)
+ if isinstance(self.element, (Selectable, ColumnElement)):
+ e = self.element._make_proxy(selectable, name=self.name)
else:
e = column(self.name)._make_proxy(selectable=selectable)
e.proxies.append(self)
return e
-class _ColumnClause(ColumnElement):
+class _ColumnClause(_Immutable, ColumnElement):
"""Represents a generic column expression from any textual string.
This includes columns associated with tables, aliases and select
@@ -2602,16 +2599,7 @@ class _ColumnClause(ColumnElement):
return self.name.encode('ascii', 'backslashreplace')
description = property(description)
- def _clone(self):
- # ColumnClause is immutable
- return self
-
def _label(self):
- """Generate a 'label' string for this column.
- """
-
- # for a "literal" column, we've no idea what the text is
- # therefore no 'label' can be automatically generated
if self.is_literal:
return None
if not self.__label:
@@ -2626,24 +2614,21 @@ class _ColumnClause(ColumnElement):
counter = 1
while label in self.table.c:
label = self.__label + "_" + str(counter)
- counter +=1
+ counter += 1
self.__label = label
else:
self.__label = self.name
return self.__label
-
_label = property(_label)
def label(self, name):
- # if going off the "__label" property and its None, we have
- # no label; return self
if name is None:
return self
else:
return super(_ColumnClause, self).label(name)
def _get_from_objects(self, **modifiers):
- if self.table is not None:
+ if self.table:
return [self.table]
else:
return []
@@ -2651,20 +2636,20 @@ class _ColumnClause(ColumnElement):
def _bind_param(self, obj):
return _BindParamClause(self.name, obj, type_=self.type, unique=True)
- def _make_proxy(self, selectable, name = None):
+ def _make_proxy(self, selectable, name=None, attach=True):
# propigate the "is_literal" flag only if we are keeping our name,
# otherwise its considered to be a label
is_literal = self.is_literal and (name is None or name == self.name)
c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal)
c.proxies = [self]
- if not self._is_oid:
+ if attach and not self._is_oid:
selectable.columns[c.name] = c
return c
def _compare_type(self, obj):
return self.type
-class TableClause(FromClause):
+class TableClause(_Immutable, FromClause):
"""Represents a "table" construct.
Note that this represents tables only as another syntactical
@@ -2691,10 +2676,6 @@ class TableClause(FromClause):
return self.name.encode('ascii', 'backslashreplace')
description = property(description)
- def _clone(self):
- # TableClause is immutable
- return self
-
def append_column(self, c):
self._columns[c.name] = c
c.table = self
@@ -2724,10 +2705,11 @@ class TableClause(FromClause):
def _get_from_objects(self, **modifiers):
return [self]
-
class _SelectBaseMixin(object):
"""Base class for ``Select`` and ``CompoundSelects``."""
+ supports_execution = True
+
def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, autocommit=False):
self.use_labels = use_labels
self.for_update = for_update
@@ -2773,11 +2755,6 @@ class _SelectBaseMixin(object):
"""
return self.as_scalar().label(name)
- def supports_execution(self):
- """part of the ClauseElement contract; returns ``True`` in all cases for this class."""
-
- return True
-
def autocommit(self):
"""return a new selectable with the 'autocommit' flag set to True."""
@@ -2860,15 +2837,15 @@ class _SelectBaseMixin(object):
class _ScalarSelect(_Grouping):
__visit_name__ = 'grouping'
- def __init__(self, elem):
- self.elem = elem
- cols = list(elem.inner_columns)
+ def __init__(self, element):
+ self.element = element
+ cols = list(element.inner_columns)
if len(cols) != 1:
- raise exceptions.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.")
+ raise exc.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.")
self.type = cols[0].type
def columns(self):
- raise exceptions.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.")
+ raise exc.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.")
columns = c = property(columns)
def self_group(self, **kwargs):
@@ -2893,7 +2870,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
if not numcols:
numcols = len(s.c)
elif len(s.c) != numcols:
- raise exceptions.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" %
+ raise exc.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" %
(1, len(self.selects[0].c), n+1, len(s.c))
)
if s._order_by_clause:
@@ -2936,11 +2913,6 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
return (column_collections and list(self.c) or []) + \
[self._order_by_clause, self._group_by_clause] + list(self.selects)
- def _table_iterator(self):
- for s in self.selects:
- for t in s._table_iterator():
- yield t
-
def bind(self):
if self._bind:
return self._bind
@@ -2976,6 +2948,7 @@ class Select(_SelectBaseMixin, FromClause):
self._distinct = distinct
self._correlate = util.Set()
+ self._froms = util.OrderedSet()
if columns:
self._raw_columns = [
@@ -2983,22 +2956,23 @@ class Select(_SelectBaseMixin, FromClause):
for c in
[_literal_as_column(c) for c in columns]
]
+
+ self._froms.update(_from_objects(*self._raw_columns))
else:
self._raw_columns = []
-
- if from_obj:
- self._froms = util.Set([
- _is_literal(f) and _TextClause(f) or f
- for f in util.to_list(from_obj)
- ])
- else:
- self._froms = util.Set()
-
+
if whereclause:
self._whereclause = _literal_as_text(whereclause)
+ self._froms.update(_from_objects(self._whereclause, is_where=True))
else:
self._whereclause = None
+ if from_obj:
+ self._froms.update([
+ _is_literal(f) and _TextClause(f) or f
+ for f in util.to_list(from_obj)
+ ])
+
if having:
self._having = _literal_as_text(having)
else:
@@ -3020,36 +2994,28 @@ class Select(_SelectBaseMixin, FromClause):
correlating.
"""
- froms = util.OrderedSet()
-
- for col in self._raw_columns:
- froms.update(col._get_from_objects())
-
- if self._whereclause is not None:
- froms.update(self._whereclause._get_from_objects(is_where=True))
-
- if self._froms:
- froms.update(self._froms)
+ froms = self._froms
toremove = itertools.chain(*[f._hide_froms for f in froms])
- froms.difference_update(toremove)
+ if toremove:
+ froms = froms.difference(toremove)
if len(froms) > 1 or self._correlate:
if self._correlate:
- froms.difference_update(_cloned_intersection(froms, self._correlate))
+ froms = froms.difference(_cloned_intersection(froms, self._correlate))
if self._should_correlate and existing_froms:
- froms.difference_update(_cloned_intersection(froms, existing_froms))
+ froms = froms.difference(_cloned_intersection(froms, existing_froms))
if not len(froms):
- raise exceptions.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self)
+ raise exc.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self)
return froms
froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""")
def type(self):
- raise exceptions.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.")
+ raise exc.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.")
type = property(type)
def locate_all_froms(self):
@@ -3059,22 +3025,10 @@ class Select(_SelectBaseMixin, FromClause):
is specifically for those FromClause elements that would actually be rendered.
"""
- if hasattr(self, '_all_froms'):
- return self._all_froms
-
- froms = util.Set(
- itertools.chain(*
- [self._froms] +
- [f._get_from_objects() for f in self._froms] +
- [col._get_from_objects() for col in self._raw_columns]
- )
- )
+ if not hasattr(self, '_all_froms'):
+ self._all_froms = self._froms.union(_from_objects(*list(self._froms)))
- if self._whereclause:
- froms.update(self._whereclause._get_from_objects(is_where=True))
-
- self._all_froms = froms
- return froms
+ return self._all_froms
def inner_columns(self):
"""an iteratorof all ColumnElement expressions which would
@@ -3092,7 +3046,7 @@ class Select(_SelectBaseMixin, FromClause):
def is_derived_from(self, fromclause):
if self in util.Set(fromclause._cloned_set):
return True
-
+
for f in self.locate_all_froms():
if f.is_derived_from(fromclause):
return True
@@ -3112,7 +3066,7 @@ class Select(_SelectBaseMixin, FromClause):
"""return child elements as per the ClauseElement specification."""
return (column_collections and list(self.columns) or []) + \
- list(self.locate_all_froms()) + \
+ list(self._froms) + \
[x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
def column(self, column):
@@ -3125,6 +3079,7 @@ class Select(_SelectBaseMixin, FromClause):
column = column.self_group(against=operators.comma_op)
s._raw_columns = s._raw_columns + [column]
+ s._froms = s._froms.union(_from_objects(column))
return s
def where(self, whereclause):
@@ -3185,7 +3140,7 @@ class Select(_SelectBaseMixin, FromClause):
"""
s = self._generate()
- s._should_correlate=False
+ s._should_correlate = False
if fromclauses == (None,):
s._correlate = util.Set()
else:
@@ -3195,7 +3150,7 @@ class Select(_SelectBaseMixin, FromClause):
def append_correlation(self, fromclause):
"""append the given correlation expression to this select() construct."""
- self._should_correlate=False
+ self._should_correlate = False
self._correlate = self._correlate.union([fromclause])
def append_column(self, column):
@@ -3207,6 +3162,7 @@ class Select(_SelectBaseMixin, FromClause):
column = column.self_group(against=operators.comma_op)
self._raw_columns = self._raw_columns + [column]
+ self._froms = self._froms.union(_from_objects(column))
self._reset_exported()
def append_prefix(self, clause):
@@ -3221,10 +3177,13 @@ class Select(_SelectBaseMixin, FromClause):
The expression will be joined to existing WHERE criterion via AND.
"""
+ whereclause = _literal_as_text(whereclause)
+ self._froms = self._froms.union(_from_objects(whereclause, is_where=True))
+
if self._whereclause is not None:
- self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
+ self._whereclause = and_(self._whereclause, whereclause)
else:
- self._whereclause = _literal_as_text(whereclause)
+ self._whereclause = whereclause
def append_having(self, having):
"""append the given expression to this select() construct's HAVING criterion.
@@ -3311,31 +3270,23 @@ class Select(_SelectBaseMixin, FromClause):
return intersect_all(self, other, **kwargs)
- def _table_iterator(self):
- for t in visitors.NoColumnVisitor().iterate(self):
- if isinstance(t, TableClause):
- yield t
-
def bind(self):
if self._bind:
return self._bind
- for f in self._froms:
- if f is self:
- continue
- e = f.bind
- if e:
- self._bind = e
- return e
- # look through the columns (largely synomous with looking
- # through the FROMs except in the case of _CalculatedClause/_Function)
- for c in self._raw_columns:
- if getattr(c, 'table', None) is self:
- continue
- e = c.bind
+ if not self._froms:
+ for c in self._raw_columns:
+ e = c.bind
+ if e:
+ self._bind = e
+ return e
+ else:
+ e = list(self._froms)[0].bind
if e:
self._bind = e
return e
+
return None
+
def _set_bind(self, bind):
self._bind = bind
bind = property(bind, _set_bind)
@@ -3343,11 +3294,7 @@ class Select(_SelectBaseMixin, FromClause):
class _UpdateBase(ClauseElement):
"""Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
- def supports_execution(self):
- return True
-
- def _table_iterator(self):
- return iter([self.table])
+ supports_execution = True
def _generate(self):
s = self.__class__.__new__(self.__class__)
@@ -3407,7 +3354,7 @@ class Insert(_ValuesBase):
self._bind = bind
self.table = table
self.select = None
- self.inline=inline
+ self.inline = inline
if prefixes:
self._prefixes = [_literal_as_text(p) for p in prefixes]
else:
@@ -3502,10 +3449,11 @@ class Delete(_UpdateBase):
self._whereclause = clone(self._whereclause)
class _IdentifiedClause(ClauseElement):
+ supports_execution = True
+ quote = None
+
def __init__(self, ident):
self.ident = ident
- def supports_execution(self):
- return True
class SavepointClause(_IdentifiedClause):
pass