diff options
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 116 |
1 files changed, 62 insertions, 54 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 071bb3c50..fa0586e2d 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1216,7 +1216,7 @@ def _string_or_unprintable(element): except: return "unprintable element %r" % element -def _clone(element): +def _clone(element, **kw): return element._clone() def _expand_cloned(elements): @@ -1522,12 +1522,16 @@ class ClauseElement(Visitable): """ return self is other - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): """Reassign internal elements to be clones of themselves. Called during a copy-and-traverse operation on newly shallow-copied elements to create a deep copy. + The given clone function should be used, which may be applying + additional transformations to the element (i.e. replacement + traversal, cloned traversal, annotations). + """ pass @@ -2755,8 +2759,8 @@ class _TextClause(Executable, ClauseElement): else: return self - def _copy_internals(self, clone=_clone): - self.bindparams = dict((b.key, clone(b)) + def _copy_internals(self, clone=_clone, **kw): + self.bindparams = dict((b.key, clone(b, **kw)) for b in self.bindparams.values()) def get_children(self, **kwargs): @@ -2846,8 +2850,8 @@ class ClauseList(ClauseElement): else: self.clauses.append(_literal_as_text(clause)) - def _copy_internals(self, clone=_clone): - self.clauses = [clone(clause) for clause in self.clauses] + def _copy_internals(self, clone=_clone, **kw): + self.clauses = [clone(clause, **kw) for clause in self.clauses] def get_children(self, **kwargs): return self.clauses @@ -2947,12 +2951,13 @@ class _Case(ColumnElement): else: self.else_ = None - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): if self.value is not None: - self.value = clone(self.value) - self.whens = [(clone(x), clone(y)) for x, y in self.whens] + self.value = clone(self.value, **kw) + self.whens = [(clone(x, **kw), clone(y, **kw)) + for x, y in self.whens] if self.else_ is not None: - self.else_ = clone(self.else_) + self.else_ = clone(self.else_, **kw) def get_children(self, **kwargs): if self.value is not None: @@ -3028,8 +3033,8 @@ class FunctionElement(Executable, ColumnElement, FromClause): def get_children(self, **kwargs): return self.clause_expr, - def _copy_internals(self, clone=_clone): - self.clause_expr = clone(self.clause_expr) + def _copy_internals(self, clone=_clone, **kw): + self.clause_expr = clone(self.clause_expr, **kw) self._reset_exported() util.reset_memoized(self, 'clauses') @@ -3120,9 +3125,9 @@ class _Cast(ColumnElement): self.clause = _literal_as_binds(clause, None) self.typeclause = _TypeClause(self.type) - def _copy_internals(self, clone=_clone): - self.clause = clone(self.clause) - self.typeclause = clone(self.typeclause) + def _copy_internals(self, clone=_clone, **kw): + self.clause = clone(self.clause, **kw) + self.typeclause = clone(self.typeclause, **kw) def get_children(self, **kwargs): return self.clause, self.typeclause @@ -3141,8 +3146,8 @@ class _Extract(ColumnElement): self.field = field self.expr = _literal_as_binds(expr, None) - def _copy_internals(self, clone=_clone): - self.expr = clone(self.expr) + def _copy_internals(self, clone=_clone, **kw): + self.expr = clone(self.expr, **kw) def get_children(self, **kwargs): return self.expr, @@ -3170,8 +3175,8 @@ class _UnaryExpression(ColumnElement): def _from_objects(self): return self.element._from_objects - def _copy_internals(self, clone=_clone): - self.element = clone(self.element) + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) def get_children(self, **kwargs): return self.element, @@ -3233,9 +3238,9 @@ class _BinaryExpression(ColumnElement): def _from_objects(self): return self.left._from_objects + self.right._from_objects - def _copy_internals(self, clone=_clone): - self.left = clone(self.left) - self.right = clone(self.right) + def _copy_internals(self, clone=_clone, **kw): + self.left = clone(self.left, **kw) + self.right = clone(self.right, **kw) def get_children(self, **kwargs): return self.left, self.right @@ -3373,11 +3378,11 @@ class Join(FromClause): self.foreign_keys.update(itertools.chain( *[col.foreign_keys for col in columns])) - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): self._reset_exported() - self.left = clone(self.left) - self.right = clone(self.right) - self.onclause = clone(self.onclause) + self.left = clone(self.left, **kw) + self.right = clone(self.right, **kw) + self.onclause = clone(self.onclause, **kw) self.__folded_equivalents = None def get_children(self, **kwargs): @@ -3525,21 +3530,24 @@ class Alias(FromClause): for col in self.element.columns: col._make_proxy(self) - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): + # don't apply anything to an aliased Table + # for now. May want to drive this from + # the given **kw. + if isinstance(self.element, TableClause): + return self._reset_exported() - self.element = _clone(self.element) + self.element = clone(self.element, **kw) baseselectable = self.element while isinstance(baseselectable, Alias): baseselectable = baseselectable.element self.original = baseselectable - def get_children(self, column_collections=True, - aliased_selectables=True, **kwargs): + def get_children(self, column_collections=True, **kw): if column_collections: for c in self.c: yield c - if aliased_selectables: - yield self.element + yield self.element @property def _from_objects(self): @@ -3563,8 +3571,8 @@ class _Grouping(ColumnElement): def _label(self): return getattr(self.element, '_label', None) or self.anon_label - def _copy_internals(self, clone=_clone): - self.element = clone(self.element) + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) def get_children(self, **kwargs): return self.element, @@ -3615,8 +3623,8 @@ class _FromGrouping(FromClause): def get_children(self, **kwargs): return self.element, - def _copy_internals(self, clone=_clone): - self.element = clone(self.element) + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) @property def _from_objects(self): @@ -3662,12 +3670,12 @@ class _Over(ColumnElement): (self.func, self.partition_by, self.order_by) if c is not None] - def _copy_internals(self, clone=_clone): - self.func = clone(self.func) + def _copy_internals(self, clone=_clone, **kw): + self.func = clone(self.func, **kw) if self.partition_by is not None: - self.partition_by = clone(self.partition_by) + self.partition_by = clone(self.partition_by, **kw) if self.order_by is not None: - self.order_by = clone(self.order_by) + self.order_by = clone(self.order_by, **kw) @property def _from_objects(self): @@ -3732,8 +3740,8 @@ class _Label(ColumnElement): def get_children(self, **kwargs): return self.element, - def _copy_internals(self, clone=_clone): - self.element = clone(self.element) + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) @property def _from_objects(self): @@ -4244,14 +4252,14 @@ class CompoundSelect(_SelectBase): proxy.proxies = [c._annotate({'weight': i + 1}) for (i, c) in enumerate(cols)] - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): self._reset_exported() - self.selects = [clone(s) for s in self.selects] + self.selects = [clone(s, **kw) for s in self.selects] if hasattr(self, '_col_map'): del self._col_map for attr in ('_order_by_clause', '_group_by_clause'): if getattr(self, attr) is not None: - setattr(self, attr, clone(getattr(self, attr))) + setattr(self, attr, clone(getattr(self, attr), **kw)) def get_children(self, column_collections=True, **kwargs): return (column_collections and list(self.c) or []) \ @@ -4477,17 +4485,17 @@ class Select(_SelectBase): return True return False - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): self._reset_exported() - from_cloned = dict((f, clone(f)) + from_cloned = dict((f, clone(f, **kw)) for f in self._froms.union(self._correlate)) self._froms = util.OrderedSet(from_cloned[f] for f in self._froms) self._correlate = set(from_cloned[f] for f in self._correlate) - self._raw_columns = [clone(c) for c in self._raw_columns] + self._raw_columns = [clone(c, **kw) for c in self._raw_columns] for attr in '_whereclause', '_having', '_order_by_clause', \ '_group_by_clause': if getattr(self, attr) is not None: - setattr(self, attr, clone(getattr(self, attr))) + setattr(self, attr, clone(getattr(self, attr), **kw)) def get_children(self, column_collections=True, **kwargs): """return child elements as per the ClauseElement specification.""" @@ -4910,7 +4918,7 @@ class Insert(ValuesBase): else: return () - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): # TODO: coverage self.parameters = self.parameters.copy() @@ -4959,9 +4967,9 @@ class Update(ValuesBase): else: return () - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): # TODO: coverage - self._whereclause = clone(self._whereclause) + self._whereclause = clone(self._whereclause, **kw) self.parameters = self.parameters.copy() @_generative @@ -5020,9 +5028,9 @@ class Delete(UpdateBase): else: self._whereclause = _literal_as_text(whereclause) - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): # TODO: coverage - self._whereclause = clone(self._whereclause) + self._whereclause = clone(self._whereclause, **kw) class _IdentifiedClause(Executable, ClauseElement): |