diff options
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 187 |
1 files changed, 75 insertions, 112 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 22c296e98..51bd176c3 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1392,36 +1392,32 @@ class ColumnElement(ClauseElement, _CompareMixin): return None foreign_key = property(_one_fkey) - - def _get_orig_set(self): - try: - return self.__orig_set - except AttributeError: - self.__orig_set = util.Set([self]) - return self.__orig_set - - def _set_orig_set(self, s): - if len(s) == 0: - s.add(self) - self.__orig_set = s - - orig_set = property(_get_orig_set, _set_orig_set, - doc=\ - """A Set containing TableClause-bound, non-proxied ColumnElements - for which this ColumnElement is a proxy. In all cases except - for a column proxied from a Union (i.e. CompoundSelect), this - set will be just one element. - """) - + + def base_column(self): + if hasattr(self, '_base_column'): + return self._base_column + p = self + while hasattr(p, 'proxies'): + p = p.proxies[0] + self._base_column = p + return p + base_column = property(base_column) + + def proxy_set(self): + if hasattr(self, '_proxy_set'): + return self._proxy_set + s = util.Set([self]) + if hasattr(self, 'proxies'): + for c in self.proxies: + s = s.union(c.proxy_set) + self._proxy_set = s + return s + proxy_set = property(proxy_set) + def shares_lineage(self, othercolumn): """Return True if the given ``ColumnElement`` has a common ancestor to this ``ColumnElement``. """ - - for c in self.orig_set: - if c in othercolumn.orig_set: - return True - else: - return False + return len(self.proxy_set.intersection(othercolumn.proxy_set)) > 0 def _make_proxy(self, selectable, name=None): """Create a new ``ColumnElement`` representing this @@ -1434,7 +1430,7 @@ class ColumnElement(ClauseElement, _CompareMixin): if name is not None: co = _ColumnClause(name, selectable) - co.orig_set = self.orig_set + co.proxies = [self] selectable.columns[name]= co return co else: @@ -1569,14 +1565,6 @@ class FromClause(Selectable): return False - def _get_all_embedded_columns(self): - ret = [] - class FindCols(visitors.ClauseVisitor): - def visit_column(self, col): - ret.append(col) - FindCols().traverse(self) - return ret - def is_derived_from(self, fromclause): """Return True if this FromClause is 'derived' from the given FromClause. @@ -1616,19 +1604,20 @@ class FromClause(Selectable): of this ``FromClause``. """ - if self.c.contains_column(column): - return column - - if require_embedded and column not in util.Set(self._get_all_embedded_columns()): + if require_embedded and column not in self._get_all_embedded_columns(): if not raiseerr: return None else: - raise exceptions.InvalidRequestError("Column instance '%s' is not directly present within selectable '%s'" % (str(column), column.table)) - for c in column.orig_set: - try: - return self.original_columns[c] - except KeyError: - pass + raise exceptions.InvalidRequestError("Column instance '%s' is not directly present within selectable '%s'" % (str(column), column.table.description)) + + col, intersect = None, None + target_set = column.proxy_set + for c in self.c + [self.oid_column]: + i = c.proxy_set.intersection(target_set) + if i and (intersect is None or len(i) > len(intersect)): + col, intersect = c, i + if col: + return col else: if keys_ok: try: @@ -1638,18 +1627,33 @@ class FromClause(Selectable): if not raiseerr: return None else: - raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), self.name)) + raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), self.description)) + def description(self): + return getattr(self, 'name', self.__class__.__name__ + " object") + description = property(description) + def _clone_from_clause(self): # delete all the "generated" collections of columns for a # newly cloned FromClause, so that they will be re-derived # from the item. this is because FromClause subclasses, when # cloned, need to reestablish new "proxied" columns that are # linked to the new item - for attr in ('_columns', '_primary_key' '_foreign_keys', '_orig_cols', '_oid_column'): + for attr in ('_columns', '_primary_key' '_foreign_keys', '_oid_column', '_embedded_columns'): if hasattr(self, attr): delattr(self, attr) + def _get_all_embedded_columns(self): + if hasattr(self, '_embedded_columns'): + return self._embedded_columns + ret = util.Set() + class FindCols(visitors.ClauseVisitor): + def visit_column(self, col): + ret.add(col) + FindCols().traverse(self) + self._embedded_columns = ret + return ret + def _expr_attr_func(name): def attr(self): try: @@ -1663,22 +1667,10 @@ class FromClause(Selectable): c = property(_expr_attr_func('_columns')) primary_key = property(_expr_attr_func('_primary_key')) foreign_keys = property(_expr_attr_func('_foreign_keys')) - original_columns = property(_expr_attr_func('_orig_cols'), doc=\ - """A dictionary mapping an original Table-bound - column to a proxied column in this FromClause. - """) def _export_columns(self, columns=None): """Initialize column collections. - The collections include the primary key, foreign keys, list of - all columns, as well as the *_orig_cols* collection which is a - dictionary used to match Table-bound columns to proxied - columns in this ``FromClause``. The columns in each - collection are *proxied* from the columns returned by the - _exportable_columns method, where a *proxied* column maintains - most or all of the properties of its original column, except - its parent ``Selectable`` is this ``FromClause``. """ if hasattr(self, '_columns') and columns is None: @@ -1686,24 +1678,11 @@ class FromClause(Selectable): self._columns = ColumnCollection() self._primary_key = ColumnSet() self._foreign_keys = util.Set() - self._orig_cols = {} if columns is None: columns = self._flatten_exportable_columns() for co in columns: cp = self._proxy_column(co) - for ci in cp.orig_set: - cx = self._orig_cols.get(ci) - # TODO: the '=' thing here relates to the order of - # columns as they are placed in the "columns" - # collection of a CompositeSelect, illustrated in - # test/sql/selectable.SelectableTest.testunion make - # this relationship less brittle - if cx is None or cp._distance <= cx._distance: - self._orig_cols[ci] = cp - if self.oid_column is not None: - for ci in self.oid_column.orig_set: - self._orig_cols[ci] = self.oid_column def _flatten_exportable_columns(self): """Return the list of ColumnElements represented within this FromClause's _exportable_columns""" @@ -2058,7 +2037,6 @@ class _Cast(ColumnElement): self.type = sqltypes.to_instance(totype) self.clause = clause self.typeclause = _TypeClause(self.type) - self._distance = 0 def _copy_internals(self, clone=_clone): self.clause = clone(self.clause) @@ -2073,8 +2051,7 @@ class _Cast(ColumnElement): def _make_proxy(self, selectable, name=None): if name is not None: co = _ColumnClause(name, selectable, type_=self.type) - co._distance = self._distance + 1 - co.orig_set = self.orig_set + co.proxies = [self] selectable.columns[name]= co return co else: @@ -2251,6 +2228,10 @@ class Join(FromClause): self.__primary_key = ColumnSet([c for c in self._flatten_exportable_columns() if c.primary_key and c not in omit]) + def description(self): + return "Join object on %s and %s" % (self.left.description, self.right.description) + description = property(description) + primary_key = property(lambda s:s.__primary_key) def self_group(self, against=None): @@ -2294,14 +2275,14 @@ class Join(FromClause): if len(crit) == 0: raise exceptions.ArgumentError( "Can't find any foreign key relationships " - "between '%s' and '%s'" % (primary.name, secondary.name)) + "between '%s' and '%s'" % (primary.description, secondary.description)) elif len(constraints) > 1: raise exceptions.ArgumentError( "Can't determine join between '%s' and '%s'; " "tables have more than one foreign key " "constraint relationship between them. " "Please specify the 'onclause' of this " - "join explicitly." % (primary.name, secondary.name)) + "join explicitly." % (primary.description, secondary.description)) elif len(crit) == 1: return (crit[0]) else: @@ -2456,7 +2437,6 @@ class _ColumnElementAdapter(ColumnElement): def __init__(self, elem): self.elem = elem self.type = getattr(elem, 'type', None) - self.orig_set = getattr(elem, 'orig_set', util.Set()) key = property(lambda s: s.elem.key) _label = property(lambda s: s.elem._label) @@ -2477,12 +2457,11 @@ class _ColumnElementAdapter(ColumnElement): return getattr(self.elem, attr) def __getstate__(self): - return {'elem':self.elem, 'type':self.type, 'orig_set':self.orig_set} + return {'elem':self.elem, 'type':self.type} def __setstate__(self, state): self.elem = state['elem'] self.type = state['type'] - self.orig_set = state['orig_set'] class _Grouping(_ColumnElementAdapter): """Represent a grouping within a column expression""" @@ -2527,14 +2506,15 @@ class _Label(ColumnElement): while isinstance(obj, _Label): obj = obj.obj self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon')) - self.obj = obj.self_group(against=operators.as_) self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None)) key = property(lambda s: s.name) _label = property(lambda s: s.name) - orig_set = property(lambda s:s.obj.orig_set) - + proxies = property(lambda s:s.obj.proxies) + base_column = property(lambda s:s.obj.base_column) + proxy_set = property(lambda s:s.obj.proxy_set) + def expression_element(self): return self.obj @@ -2589,7 +2569,6 @@ class _ColumnClause(ColumnElement): self.table = selectable self.type = sqltypes.to_instance(type_) self._is_oid = _is_oid - self._distance = 0 self.__label = None self.is_literal = is_literal @@ -2621,8 +2600,6 @@ class _ColumnClause(ColumnElement): self.__label = self.name return self.__label - is_labeled = property(lambda self:self.name != list(self.orig_set)[0].name) - _label = property(_get_label) def label(self, name): @@ -2647,8 +2624,7 @@ class _ColumnClause(ColumnElement): # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal) - c.orig_set = self.orig_set - c._distance = self._distance + 1 + c.proxies = [self] if not self._is_oid: selectable.columns[c.name] = c return c @@ -2686,18 +2662,6 @@ class TableClause(FromClause): self.append_column(c) return c - def _orig_columns(self): - try: - return self._orig_cols - except AttributeError: - self._orig_cols= {} - for c in self.columns: - for ci in c.orig_set: - self._orig_cols[ci] = c - return self._orig_cols - - original_columns = property(_orig_columns) - def get_children(self, column_collections=True, **kwargs): if column_collections: return [c for c in self.c] @@ -2922,18 +2886,17 @@ class CompoundSelect(_SelectBaseMixin, FromClause): yield c def _proxy_column(self, column): - if self.use_labels: - col = column._make_proxy(self, name=column._label) + existing = self._col_map.get(column.name, None) + if existing is not None: + existing.proxies.append(column) + return existing else: - col = column._make_proxy(self) - try: - colset = self._col_map[col.name] - except KeyError: - colset = util.Set() - self._col_map[col.name] = colset - [colset.add(c) for c in col.orig_set] - col.orig_set = colset - return col + if self.use_labels: + col = column._make_proxy(self, name=column._label) + else: + col = column._make_proxy(self) + self._col_map[col.name] = col + return col def _copy_internals(self, clone=_clone): self._clone_from_clause() |