diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/schema.py | 24 |
1 files changed, 17 insertions, 7 deletions
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index b18a6e365..34bedbc6a 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -92,6 +92,9 @@ def _get_table_key(name, schema): # this should really be in sql/util.py but we'd have to # break an import cycle def _copy_expression(expression, source_table, target_table): + if source_table is None or target_table is None: + return expression + def replace(col): if ( isinstance(col, Column) @@ -3272,7 +3275,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): def __contains__(self, x): return x in self.columns - def copy(self, **kw): + def copy(self, target_table=None, **kw): # ticket #5276 constraint_kwargs = {} for dialect_name in self.dialect_options: @@ -3289,7 +3292,10 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): name=self.name, deferrable=self.deferrable, initially=self.initially, - *self.columns.keys(), + *[ + _copy_expression(expr, self.parent, target_table) + for expr in self.columns + ], **constraint_kwargs ) return self._schema_item_copy(c) @@ -3393,6 +3399,9 @@ class CheckConstraint(ColumnCollectionConstraint): def copy(self, target_table=None, **kw): if target_table is not None: + # note that target_table is None for the copy process of + # a column-bound CheckConstraint, so this path is not reached + # in that case. sqltext = _copy_expression(self.sqltext, self.table, target_table) else: sqltext = self.sqltext @@ -4864,10 +4873,11 @@ class Computed(FetchedValue, SchemaItem): return self def copy(self, target_table=None, **kw): - if target_table is not None: - sqltext = _copy_expression(self.sqltext, self.table, target_table) - else: - sqltext = self.sqltext + sqltext = _copy_expression( + self.sqltext, + self.column.table if self.column is not None else None, + target_table, + ) g = Computed(sqltext, persisted=self.persisted) return self._schema_item_copy(g) @@ -4998,7 +5008,7 @@ class Identity(IdentityOptions, FetchedValue, SchemaItem): def _as_for_update(self, for_update): return self - def copy(self, target_table=None, **kw): + def copy(self, **kw): i = Identity( always=self.always, on_null=self.on_null, |