diff options
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 490 |
1 files changed, 219 insertions, 271 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 867fdd69c..7ce637701 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -26,12 +26,12 @@ to stay the same in future releases. """ import itertools, re -from sqlalchemy import util, exceptions +from sqlalchemy import util, exc from sqlalchemy.sql import operators, visitors from sqlalchemy import types as sqltypes functions, schema, sql_util = None, None, None -DefaultDialect, ClauseAdapter = None, None +DefaultDialect, ClauseAdapter, Annotated = None, None, None __all__ = [ 'Alias', 'ClauseElement', @@ -503,15 +503,21 @@ def collate(expression, collation): def exists(*args, **kwargs): """Return an ``EXISTS`` clause as applied to a [sqlalchemy.sql.expression#Select] object. + + Calling styles are of the following forms:: + + # use on an existing select() + s = select([<columns>]).where(<criterion>) + s = exists(s) + + # construct a select() at once + exists(['*'], **select_arguments).where(<criterion>) + + # columns argument is optional, generates "EXISTS (SELECT *)" + # by default. + exists().where(<criterion>) - The resulting [sqlalchemy.sql.expression#_Exists] object can be executed by - itself or used as a subquery within an enclosing select. - - \*args, \**kwargs - all arguments are sent directly to the [sqlalchemy.sql.expression#select()] - function to produce a ``SELECT`` statement. """ - return _Exists(*args, **kwargs) def union(*selects, **kwargs): @@ -872,27 +878,36 @@ def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) def _is_literal(element): - return not isinstance(element, ClauseElement) + return not isinstance(element, (ClauseElement, Operators)) + +def _from_objects(*elements, **kwargs): + return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements]) +def _labeled(element): + if not hasattr(element, 'name'): + return element.label(None) + else: + return element + def _literal_as_text(element): - if isinstance(element, Operators): - return element.expression_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): return _TextClause(unicode(element)) else: return element def _literal_as_column(element): - if isinstance(element, Operators): - return element.clause_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): return literal_column(str(element)) else: return element def _literal_as_binds(element, name=None, type_=None): - if isinstance(element, Operators): - return element.expression_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): if element is None: return null() @@ -902,17 +917,17 @@ def _literal_as_binds(element, name=None, type_=None): return element def _no_literals(element): - if isinstance(element, Operators): - return element.expression_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): - raise exceptions.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) + raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) else: return element def _corresponding_column_or_error(fromclause, column, require_embedded=False): c = fromclause.corresponding_column(column, require_embedded=require_embedded) if not c: - raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description)) + raise exc.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description)) return c def _selectable(element): @@ -921,9 +936,8 @@ def _selectable(element): elif isinstance(element, Selectable): return element else: - raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) + raise exc.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) - def is_column(col): """True if ``col`` is an instance of ``ColumnElement``.""" return isinstance(col, ColumnElement) @@ -941,7 +955,9 @@ class _FigureVisitName(type): class ClauseElement(object): """Base class for elements of a programmatically constructed SQL expression.""" __metaclass__ = _FigureVisitName - + _annotations = {} + supports_execution = False + def _clone(self): """Create a shallow copy of this ClauseElement. @@ -976,6 +992,14 @@ class ClauseElement(object): """ raise NotImplementedError(repr(self)) + + def _annotate(self, values): + """return a copy of this ClauseElement with the given annotations dictionary.""" + + global Annotated + if Annotated is None: + from sqlalchemy.sql.util import Annotated + return Annotated(self, values) def unique_params(self, *optionaldict, **kwargs): """Return a copy with ``bindparam()`` elments replaced. @@ -1006,14 +1030,14 @@ class ClauseElement(object): if len(optionaldict) == 1: kwargs.update(optionaldict[0]) elif len(optionaldict) > 1: - raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument") + raise exc.ArgumentError("params() takes zero or one positional dictionary argument") def visit_bindparam(bind): if bind.key in kwargs: bind.value = kwargs[bind.key] if unique: bind._convert_to_unique() - return visitors.traverse(self, visit_bindparam=visit_bindparam, clone=True) + return visitors.cloned_traverse(self, {}, {'bindparam':visit_bindparam}) def compare(self, other): """Compare this ClauseElement to the given ClauseElement. @@ -1049,11 +1073,6 @@ class ClauseElement(object): def self_group(self, against=None): return self - def supports_execution(self): - """Return True if this clause element represents a complete executable statement.""" - - return False - def bind(self): """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""" @@ -1062,7 +1081,7 @@ class ClauseElement(object): return self._bind except AttributeError: pass - for f in self._get_from_objects(): + for f in _from_objects(self): if f is self: continue engine = f.bind @@ -1083,7 +1102,7 @@ class ClauseElement(object): 'Engine for execution. Or, assign a bind to the statement ' 'or the Metadata of its underlying tables to enable ' 'implicit execution via this method.' % label) - raise exceptions.UnboundExecutionError(msg) + raise exc.UnboundExecutionError(msg) return e.execute_clauseelement(self, multiparams, params) def scalar(self, *multiparams, **params): @@ -1159,6 +1178,12 @@ class ClauseElement(object): self.__module__, self.__class__.__name__, id(self), friendly) +class _Immutable(object): + """mark a ClauseElement as 'immutable' when expressions are cloned.""" + + def _clone(self): + return self + class Operators(object): def __and__(self, other): return self.operate(operators.and_, other) @@ -1174,9 +1199,6 @@ class Operators(object): return self.operate(operators.op, opstring, b) return op - def clause_element(self): - raise NotImplementedError() - def operate(self, op, *other, **kwargs): raise NotImplementedError() @@ -1216,7 +1238,7 @@ class ColumnOperators(Operators): def ilike(self, other, escape=None): return self.operate(operators.ilike_op, other, escape=escape) - def in_(self, *other): + def in_(self, other): return self.operate(operators.in_op, other) def startswith(self, other, **kwargs): @@ -1279,18 +1301,18 @@ class _CompareMixin(ColumnOperators): def __compare(self, op, obj, negate=None, reverse=False, **kwargs): if obj is None or isinstance(obj, _Null): if op == operators.eq: - return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot) + return _BinaryExpression(self, null(), operators.is_, negate=operators.isnot) elif op == operators.ne: - return _BinaryExpression(self.expression_element(), null(), operators.isnot, negate=operators.is_) + return _BinaryExpression(self, null(), operators.isnot, negate=operators.is_) else: - raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") + raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL") else: obj = self._check_literal(obj) if reverse: - return _BinaryExpression(obj, self.expression_element(), op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + return _BinaryExpression(obj, self, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) else: - return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + return _BinaryExpression(self, obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) def __operate(self, op, obj, reverse=False): obj = self._check_literal(obj) @@ -1298,9 +1320,9 @@ class _CompareMixin(ColumnOperators): type_ = self._compare_type(obj) if reverse: - return _BinaryExpression(obj, self.expression_element(), type_.adapt_operator(op), type_=type_) + return _BinaryExpression(obj, self, type_.adapt_operator(op), type_=type_) else: - return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_) + return _BinaryExpression(self, obj, type_.adapt_operator(op), type_=type_) # a mapping of operators with the method they use, along with their negated # operator for comparison operators @@ -1329,17 +1351,10 @@ class _CompareMixin(ColumnOperators): o = _CompareMixin.operators[op] return o[0](self, op, other, reverse=True, *o[1:], **kwargs) - def in_(self, *other): - return self._in_impl(operators.in_op, operators.notin_op, *other) - - def _in_impl(self, op, negate_op, *other): - # Handle old style *args argument passing - if len(other) != 1 or not isinstance(other[0], Selectable) and (not hasattr(other[0], '__iter__') or isinstance(other[0], basestring)): - util.warn_deprecated('passing in_ arguments as varargs is deprecated, in_ takes a single argument that is a sequence or a selectable') - seq_or_selectable = other - else: - seq_or_selectable = other[0] + def in_(self, other): + return self._in_impl(operators.in_op, operators.notin_op, other) + def _in_impl(self, op, negate_op, seq_or_selectable): if isinstance(seq_or_selectable, Selectable): return self.__compare( op, seq_or_selectable, negate=negate_op) @@ -1348,7 +1363,7 @@ class _CompareMixin(ColumnOperators): for o in seq_or_selectable: if not _is_literal(o): if not isinstance( o, _CompareMixin): - raise exceptions.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) ) + raise exc.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) ) else: o = self._bind_param(o) args.append(o) @@ -1433,22 +1448,13 @@ class _CompareMixin(ColumnOperators): if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType): other.type = self.type return other - elif isinstance(other, Operators): - return other.expression_element() + elif hasattr(other, '__clause_element__'): + return other.__clause_element__() elif _is_literal(other): return self._bind_param(other) else: return other - def clause_element(self): - """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``.""" - return self - - def expression_element(self): - """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions.""" - - return self - def _compare_type(self, obj): """Allow subclasses to override the type used in constructing ``_BinaryExpression`` objects. @@ -1480,23 +1486,22 @@ class ColumnElement(ClauseElement, _CompareMixin): primary_key = False foreign_keys = [] - + quote = None + def base_columns(self): - if hasattr(self, '_base_columns'): - return self._base_columns - self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')]) + if not hasattr(self, '_base_columns'): + self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')]) return self._base_columns base_columns = property(base_columns) def proxy_set(self): - if hasattr(self, '_proxy_set'): - return self._proxy_set - s = util.Set([self]) - if hasattr(self, 'proxies'): - for c in self.proxies: - s = s.union(c.proxy_set) - self._proxy_set = s - return s + if not hasattr(self, '_proxy_set'): + s = util.Set([self]) + if hasattr(self, 'proxies'): + for c in self.proxies: + s.update(c.proxy_set) + self._proxy_set = s + return self._proxy_set proxy_set = property(proxy_set) def shares_lineage(self, othercolumn): @@ -1518,7 +1523,7 @@ class ColumnElement(ClauseElement, _CompareMixin): co = _ColumnClause(self.anon_label, selectable, type_=getattr(self, 'type', None)) co.proxies = [self] - selectable.columns[name]= co + selectable.columns[name] = co return co def anon_label(self): @@ -1613,7 +1618,7 @@ class ColumnCollection(util.OrderedProperties): def __contains__(self, other): if not isinstance(other, basestring): - raise exceptions.ArgumentError("__contains__ requires a string argument") + raise exc.ArgumentError("__contains__ requires a string argument") return util.OrderedProperties.__contains__(self, other) def contains_column(self, col): @@ -1641,6 +1646,9 @@ class ColumnSet(util.OrderedSet): l.append(c==local) return and_(*l) + def __hash__(self): + return hash(tuple(self._list)) + class Selectable(ClauseElement): """mark a class as being selectable""" @@ -1648,8 +1656,9 @@ class FromClause(Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement.""" __visit_name__ = 'fromclause' - named_with_column=False + named_with_column = False _hide_froms = [] + quote = None def _get_from_objects(self, **modifiers): return [] @@ -1694,12 +1703,12 @@ class FromClause(Selectable): return fromclause in util.Set(self._cloned_set) def replace_selectable(self, old, alias): - """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" + """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" - global ClauseAdapter - if ClauseAdapter is None: - from sqlalchemy.sql.util import ClauseAdapter - return ClauseAdapter(alias).traverse(self, clone=True) + global ClauseAdapter + if ClauseAdapter is None: + from sqlalchemy.sql.util import ClauseAdapter + return ClauseAdapter(alias).traverse(self) def correspond_on_equivalents(self, column, equivalents): col = self.corresponding_column(column, require_embedded=True) @@ -1859,7 +1868,7 @@ class _BindParamClause(ClauseElement, _CompareMixin): def _convert_to_unique(self): if not self.unique: - self.unique=True + self.unique = True self.key = "{ANON %d %s}" % (id(self), self._orig_key or 'param') def _get_from_objects(self, **modifiers): @@ -1910,6 +1919,7 @@ class _TextClause(ClauseElement): __visit_name__ = 'textclause' _bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE) + supports_execution = True _hide_froms = [] oid_column = None @@ -1950,12 +1960,6 @@ class _TextClause(ClauseElement): def _get_from_objects(self, **modifiers): return [] - def supports_execution(self): - return True - - def _table_iterator(self): - return iter([]) - class _Null(ColumnElement): """Represent the NULL keyword in a SQL statement. @@ -2042,6 +2046,7 @@ class _CalculatedClause(ColumnElement): __visit_name__ = 'calculatedclause' def __init__(self, name, *clauses, **kwargs): + ColumnElement.__init__(self) self.name = name self.type = sqltypes.to_instance(kwargs.get('type_', None)) self._bind = kwargs.get('bind', None) @@ -2061,7 +2066,7 @@ class _CalculatedClause(ColumnElement): def clauses(self): if isinstance(self.clause_expr, _Grouping): - return self.clause_expr.elem + return self.clause_expr.element else: return self.clause_expr clauses = property(clauses) @@ -2239,8 +2244,13 @@ class _Exists(_UnaryExpression): __visit_name__ = _UnaryExpression.__visit_name__ def __init__(self, *args, **kwargs): - kwargs['correlate'] = True - s = select(*args, **kwargs).as_scalar().self_group() + if args and isinstance(args[0], _SelectBaseMixin): + s = args[0] + else: + if not args: + args = ([literal_column('*')],) + s = select(*args, **kwargs).as_scalar().self_group() + _UnaryExpression.__init__(self, s, operator=operators.exists) def select(self, whereclause=None, **params): @@ -2272,7 +2282,7 @@ class Join(FromClause): self.right = _selectable(right).self_group() if onclause is None: - self.onclause = self.__match_primaries(self.left, self.right) + self.onclause = self._match_primaries(self.left, self.right) else: self.onclause = onclause @@ -2310,7 +2320,7 @@ class Join(FromClause): def get_children(self, **kwargs): return self.left, self.right, self.onclause - def __match_primaries(self, primary, secondary): + def _match_primaries(self, primary, secondary): global sql_util if not sql_util: from sqlalchemy.sql import util as sql_util @@ -2359,7 +2369,7 @@ class Join(FromClause): return self.select(use_labels=True, correlate=False).alias(name) def _hide_froms(self): - return itertools.chain(*[x.left._get_from_objects() + x.right._get_from_objects() for x in self._cloned_set]) + return itertools.chain(*[_from_objects(x.left, x.right) for x in self._cloned_set]) _hide_froms = property(_hide_froms) def _get_from_objects(self, **modifiers): @@ -2382,9 +2392,10 @@ class Alias(FromClause): def __init__(self, selectable, alias=None): baseselectable = selectable while isinstance(baseselectable, Alias): - baseselectable = baseselectable.selectable + baseselectable = baseselectable.element self.original = baseselectable - self.selectable = selectable + self.supports_execution = baseselectable.supports_execution + self.element = selectable if alias is None: if self.original.named_with_column: alias = getattr(self.original, 'name', None) @@ -2398,112 +2409,100 @@ class Alias(FromClause): def is_derived_from(self, fromclause): if fromclause in util.Set(self._cloned_set): return True - return self.selectable.is_derived_from(fromclause) - - def supports_execution(self): - return self.original.supports_execution() - - def _table_iterator(self): - return self.original._table_iterator() + return self.element.is_derived_from(fromclause) def _populate_column_collection(self): - for col in self.selectable.columns: + for col in self.element.columns: col._make_proxy(self) - if self.selectable.oid_column is not None: - self._oid_column = self.selectable.oid_column._make_proxy(self) + if self.element.oid_column is not None: + self._oid_column = self.element.oid_column._make_proxy(self) def _copy_internals(self, clone=_clone): - self._reset_exported() - self.selectable = _clone(self.selectable) - baseselectable = self.selectable - while isinstance(baseselectable, Alias): - baseselectable = baseselectable.selectable - self.original = baseselectable + self._reset_exported() + self.element = _clone(self.element) + baseselectable = self.element + while isinstance(baseselectable, Alias): + baseselectable = baseselectable.selectable + self.original = baseselectable def get_children(self, column_collections=True, aliased_selectables=True, **kwargs): if column_collections: for c in self.c: yield c if aliased_selectables: - yield self.selectable + yield self.element def _get_from_objects(self, **modifiers): return [self] def bind(self): - return self.selectable.bind + return self.element.bind bind = property(bind) -class _ColumnElementAdapter(ColumnElement): - """Adapts a ClauseElement which may or may not be a - ColumnElement subclass itself into an object which - acts like a ColumnElement. - """ +class _Grouping(ColumnElement): + """Represent a grouping within a column expression""" - def __init__(self, elem): - self.elem = elem - self.type = getattr(elem, 'type', None) + def __init__(self, element): + ColumnElement.__init__(self) + self.element = element + self.type = getattr(element, 'type', None) def key(self): - return self.elem.key + return self.element.key key = property(key) def _label(self): try: - return self.elem._label + return self.element._label except AttributeError: return self.anon_label _label = property(_label) def _copy_internals(self, clone=_clone): - self.elem = clone(self.elem) + self.element = clone(self.element) def get_children(self, **kwargs): - return self.elem, + return self.element, def _get_from_objects(self, **modifiers): - return self.elem._get_from_objects(**modifiers) + return self.element._get_from_objects(**modifiers) def __getattr__(self, attr): - return getattr(self.elem, attr) + return getattr(self.element, attr) def __getstate__(self): - return {'elem':self.elem, 'type':self.type} + return {'element':self.element, 'type':self.type} def __setstate__(self, state): - self.elem = state['elem'] + self.element = state['element'] self.type = state['type'] -class _Grouping(_ColumnElementAdapter): - """Represent a grouping within a column expression""" - pass - class _FromGrouping(FromClause): """Represent a grouping of a FROM clause""" __visit_name__ = 'grouping' - def __init__(self, elem): - self.elem = elem + def __init__(self, element): + self.element = element def columns(self): - return self.elem.columns + return self.element.columns columns = c = property(columns) def _hide_froms(self): - return self.elem._hide_froms + return self.element._hide_froms _hide_froms = property(_hide_froms) def get_children(self, **kwargs): - return self.elem, + return self.element, def _copy_internals(self, clone=_clone): - self.elem = clone(self.elem) + self.element = clone(self.element) def _get_from_objects(self, **modifiers): - return self.elem._get_from_objects(**modifiers) + return self.element._get_from_objects(**modifiers) def __getattr__(self, attr): - return getattr(self.elem, attr) + return getattr(self.element, attr) class _Label(ColumnElement): """Represents a column label (AS). @@ -2516,12 +2515,12 @@ class _Label(ColumnElement): ``ColumnElement`` subclasses. """ - def __init__(self, name, obj, type_=None): - while isinstance(obj, _Label): - obj = obj.obj - self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon')) - self.obj = obj.self_group(against=operators.as_) - self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None)) + def __init__(self, name, element, type_=None): + while isinstance(element, _Label): + element = element.element + self.name = name or "{ANON %d %s}" % (id(self), getattr(element, 'name', 'anon')) + self.element = element.self_group(against=operators.as_) + self.type = sqltypes.to_instance(type_ or getattr(element, 'type', None)) def key(self): return self.name @@ -2532,8 +2531,9 @@ class _Label(ColumnElement): _label = property(_label) def _proxy_attr(name): + get = util.attrgetter(name) def attr(self): - return getattr(self.obj, name) + return get(self.element) return property(attr) proxies = _proxy_attr('proxies') @@ -2542,27 +2542,24 @@ class _Label(ColumnElement): primary_key = _proxy_attr('primary_key') foreign_keys = _proxy_attr('foreign_keys') - def expression_element(self): - return self.obj - def get_children(self, **kwargs): - return self.obj, + return self.element, def _copy_internals(self, clone=_clone): - self.obj = clone(self.obj) + self.element = clone(self.element) def _get_from_objects(self, **modifiers): - return self.obj._get_from_objects(**modifiers) + return self.element._get_from_objects(**modifiers) def _make_proxy(self, selectable, name = None): - if isinstance(self.obj, (Selectable, ColumnElement)): - e = self.obj._make_proxy(selectable, name=self.name) + if isinstance(self.element, (Selectable, ColumnElement)): + e = self.element._make_proxy(selectable, name=self.name) else: e = column(self.name)._make_proxy(selectable=selectable) e.proxies.append(self) return e -class _ColumnClause(ColumnElement): +class _ColumnClause(_Immutable, ColumnElement): """Represents a generic column expression from any textual string. This includes columns associated with tables, aliases and select @@ -2602,16 +2599,7 @@ class _ColumnClause(ColumnElement): return self.name.encode('ascii', 'backslashreplace') description = property(description) - def _clone(self): - # ColumnClause is immutable - return self - def _label(self): - """Generate a 'label' string for this column. - """ - - # for a "literal" column, we've no idea what the text is - # therefore no 'label' can be automatically generated if self.is_literal: return None if not self.__label: @@ -2626,24 +2614,21 @@ class _ColumnClause(ColumnElement): counter = 1 while label in self.table.c: label = self.__label + "_" + str(counter) - counter +=1 + counter += 1 self.__label = label else: self.__label = self.name return self.__label - _label = property(_label) def label(self, name): - # if going off the "__label" property and its None, we have - # no label; return self if name is None: return self else: return super(_ColumnClause, self).label(name) def _get_from_objects(self, **modifiers): - if self.table is not None: + if self.table: return [self.table] else: return [] @@ -2651,20 +2636,20 @@ class _ColumnClause(ColumnElement): def _bind_param(self, obj): return _BindParamClause(self.name, obj, type_=self.type, unique=True) - def _make_proxy(self, selectable, name = None): + def _make_proxy(self, selectable, name=None, attach=True): # propigate the "is_literal" flag only if we are keeping our name, # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal) c.proxies = [self] - if not self._is_oid: + if attach and not self._is_oid: selectable.columns[c.name] = c return c def _compare_type(self, obj): return self.type -class TableClause(FromClause): +class TableClause(_Immutable, FromClause): """Represents a "table" construct. Note that this represents tables only as another syntactical @@ -2691,10 +2676,6 @@ class TableClause(FromClause): return self.name.encode('ascii', 'backslashreplace') description = property(description) - def _clone(self): - # TableClause is immutable - return self - def append_column(self, c): self._columns[c.name] = c c.table = self @@ -2724,10 +2705,11 @@ class TableClause(FromClause): def _get_from_objects(self, **modifiers): return [self] - class _SelectBaseMixin(object): """Base class for ``Select`` and ``CompoundSelects``.""" + supports_execution = True + def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, autocommit=False): self.use_labels = use_labels self.for_update = for_update @@ -2773,11 +2755,6 @@ class _SelectBaseMixin(object): """ return self.as_scalar().label(name) - def supports_execution(self): - """part of the ClauseElement contract; returns ``True`` in all cases for this class.""" - - return True - def autocommit(self): """return a new selectable with the 'autocommit' flag set to True.""" @@ -2860,15 +2837,15 @@ class _SelectBaseMixin(object): class _ScalarSelect(_Grouping): __visit_name__ = 'grouping' - def __init__(self, elem): - self.elem = elem - cols = list(elem.inner_columns) + def __init__(self, element): + self.element = element + cols = list(element.inner_columns) if len(cols) != 1: - raise exceptions.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.") + raise exc.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.") self.type = cols[0].type def columns(self): - raise exceptions.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.") + raise exc.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.") columns = c = property(columns) def self_group(self, **kwargs): @@ -2893,7 +2870,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause): if not numcols: numcols = len(s.c) elif len(s.c) != numcols: - raise exceptions.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" % + raise exc.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" % (1, len(self.selects[0].c), n+1, len(s.c)) ) if s._order_by_clause: @@ -2936,11 +2913,6 @@ class CompoundSelect(_SelectBaseMixin, FromClause): return (column_collections and list(self.c) or []) + \ [self._order_by_clause, self._group_by_clause] + list(self.selects) - def _table_iterator(self): - for s in self.selects: - for t in s._table_iterator(): - yield t - def bind(self): if self._bind: return self._bind @@ -2976,6 +2948,7 @@ class Select(_SelectBaseMixin, FromClause): self._distinct = distinct self._correlate = util.Set() + self._froms = util.OrderedSet() if columns: self._raw_columns = [ @@ -2983,22 +2956,23 @@ class Select(_SelectBaseMixin, FromClause): for c in [_literal_as_column(c) for c in columns] ] + + self._froms.update(_from_objects(*self._raw_columns)) else: self._raw_columns = [] - - if from_obj: - self._froms = util.Set([ - _is_literal(f) and _TextClause(f) or f - for f in util.to_list(from_obj) - ]) - else: - self._froms = util.Set() - + if whereclause: self._whereclause = _literal_as_text(whereclause) + self._froms.update(_from_objects(self._whereclause, is_where=True)) else: self._whereclause = None + if from_obj: + self._froms.update([ + _is_literal(f) and _TextClause(f) or f + for f in util.to_list(from_obj) + ]) + if having: self._having = _literal_as_text(having) else: @@ -3020,36 +2994,28 @@ class Select(_SelectBaseMixin, FromClause): correlating. """ - froms = util.OrderedSet() - - for col in self._raw_columns: - froms.update(col._get_from_objects()) - - if self._whereclause is not None: - froms.update(self._whereclause._get_from_objects(is_where=True)) - - if self._froms: - froms.update(self._froms) + froms = self._froms toremove = itertools.chain(*[f._hide_froms for f in froms]) - froms.difference_update(toremove) + if toremove: + froms = froms.difference(toremove) if len(froms) > 1 or self._correlate: if self._correlate: - froms.difference_update(_cloned_intersection(froms, self._correlate)) + froms = froms.difference(_cloned_intersection(froms, self._correlate)) if self._should_correlate and existing_froms: - froms.difference_update(_cloned_intersection(froms, existing_froms)) + froms = froms.difference(_cloned_intersection(froms, existing_froms)) if not len(froms): - raise exceptions.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self) + raise exc.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self) return froms froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""") def type(self): - raise exceptions.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.") + raise exc.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.") type = property(type) def locate_all_froms(self): @@ -3059,22 +3025,10 @@ class Select(_SelectBaseMixin, FromClause): is specifically for those FromClause elements that would actually be rendered. """ - if hasattr(self, '_all_froms'): - return self._all_froms - - froms = util.Set( - itertools.chain(* - [self._froms] + - [f._get_from_objects() for f in self._froms] + - [col._get_from_objects() for col in self._raw_columns] - ) - ) + if not hasattr(self, '_all_froms'): + self._all_froms = self._froms.union(_from_objects(*list(self._froms))) - if self._whereclause: - froms.update(self._whereclause._get_from_objects(is_where=True)) - - self._all_froms = froms - return froms + return self._all_froms def inner_columns(self): """an iteratorof all ColumnElement expressions which would @@ -3092,7 +3046,7 @@ class Select(_SelectBaseMixin, FromClause): def is_derived_from(self, fromclause): if self in util.Set(fromclause._cloned_set): return True - + for f in self.locate_all_froms(): if f.is_derived_from(fromclause): return True @@ -3112,7 +3066,7 @@ class Select(_SelectBaseMixin, FromClause): """return child elements as per the ClauseElement specification.""" return (column_collections and list(self.columns) or []) + \ - list(self.locate_all_froms()) + \ + list(self._froms) + \ [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] def column(self, column): @@ -3125,6 +3079,7 @@ class Select(_SelectBaseMixin, FromClause): column = column.self_group(against=operators.comma_op) s._raw_columns = s._raw_columns + [column] + s._froms = s._froms.union(_from_objects(column)) return s def where(self, whereclause): @@ -3185,7 +3140,7 @@ class Select(_SelectBaseMixin, FromClause): """ s = self._generate() - s._should_correlate=False + s._should_correlate = False if fromclauses == (None,): s._correlate = util.Set() else: @@ -3195,7 +3150,7 @@ class Select(_SelectBaseMixin, FromClause): def append_correlation(self, fromclause): """append the given correlation expression to this select() construct.""" - self._should_correlate=False + self._should_correlate = False self._correlate = self._correlate.union([fromclause]) def append_column(self, column): @@ -3207,6 +3162,7 @@ class Select(_SelectBaseMixin, FromClause): column = column.self_group(against=operators.comma_op) self._raw_columns = self._raw_columns + [column] + self._froms = self._froms.union(_from_objects(column)) self._reset_exported() def append_prefix(self, clause): @@ -3221,10 +3177,13 @@ class Select(_SelectBaseMixin, FromClause): The expression will be joined to existing WHERE criterion via AND. """ + whereclause = _literal_as_text(whereclause) + self._froms = self._froms.union(_from_objects(whereclause, is_where=True)) + if self._whereclause is not None: - self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) + self._whereclause = and_(self._whereclause, whereclause) else: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = whereclause def append_having(self, having): """append the given expression to this select() construct's HAVING criterion. @@ -3311,31 +3270,23 @@ class Select(_SelectBaseMixin, FromClause): return intersect_all(self, other, **kwargs) - def _table_iterator(self): - for t in visitors.NoColumnVisitor().iterate(self): - if isinstance(t, TableClause): - yield t - def bind(self): if self._bind: return self._bind - for f in self._froms: - if f is self: - continue - e = f.bind - if e: - self._bind = e - return e - # look through the columns (largely synomous with looking - # through the FROMs except in the case of _CalculatedClause/_Function) - for c in self._raw_columns: - if getattr(c, 'table', None) is self: - continue - e = c.bind + if not self._froms: + for c in self._raw_columns: + e = c.bind + if e: + self._bind = e + return e + else: + e = list(self._froms)[0].bind if e: self._bind = e return e + return None + def _set_bind(self, bind): self._bind = bind bind = property(bind, _set_bind) @@ -3343,11 +3294,7 @@ class Select(_SelectBaseMixin, FromClause): class _UpdateBase(ClauseElement): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" - def supports_execution(self): - return True - - def _table_iterator(self): - return iter([self.table]) + supports_execution = True def _generate(self): s = self.__class__.__new__(self.__class__) @@ -3407,7 +3354,7 @@ class Insert(_ValuesBase): self._bind = bind self.table = table self.select = None - self.inline=inline + self.inline = inline if prefixes: self._prefixes = [_literal_as_text(p) for p in prefixes] else: @@ -3502,10 +3449,11 @@ class Delete(_UpdateBase): self._whereclause = clone(self._whereclause) class _IdentifiedClause(ClauseElement): + supports_execution = True + quote = None + def __init__(self, ident): self.ident = ident - def supports_execution(self): - return True class SavepointClause(_IdentifiedClause): pass |