diff options
Diffstat (limited to 'lib/sqlalchemy/sql/schema.py')
-rw-r--r-- | lib/sqlalchemy/sql/schema.py | 113 |
1 files changed, 112 insertions, 1 deletions
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 979b8319e..4ed5b9e6b 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -2233,6 +2233,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): server_default = self.server_default server_onupdate = self.server_onupdate if isinstance(server_default, (Computed, Identity)): + # TODO: likely should be copied in all cases args.append(server_default._copy(**kw)) server_default = server_onupdate = None @@ -2243,6 +2244,10 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): if self._user_defined_nullable is not NULL_UNSPECIFIED: column_kwargs["nullable"] = self._user_defined_nullable + # TODO: DefaultGenerator is not copied here! it's just used again + # with _set_parent() pointing to the old column. see the new + # use of _copy() in the new _merge() method + c = self._constructor( name=self.name, type_=type_, @@ -2264,6 +2269,69 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): ) return self._schema_item_copy(c) + def _merge(self, other: Column[Any]) -> None: + """merge the elements of another column into this one. + + this is used by ORM pep-593 merge and will likely need a lot + of fixes. + + + """ + + if self.primary_key: + other.primary_key = True + + type_ = self.type + if not type_._isnull and other.type._isnull: + if isinstance(type_, SchemaEventTarget): + type_ = type_.copy() + + other.type = type_ + + if isinstance(type_, SchemaEventTarget): + type_._set_parent_with_dispatch(other) + + for impl in type_._variant_mapping.values(): + if isinstance(impl, SchemaEventTarget): + impl._set_parent_with_dispatch(other) + + if ( + self._user_defined_nullable is not NULL_UNSPECIFIED + and other._user_defined_nullable is NULL_UNSPECIFIED + ): + other.nullable = self.nullable + + if self.default is not None and other.default is None: + new_default = self.default._copy() + new_default._set_parent(other) + + if self.server_default and other.server_default is None: + new_server_default = self.server_default + if isinstance(new_server_default, FetchedValue): + new_server_default = new_server_default._copy() + new_server_default._set_parent(other) + else: + other.server_default = new_server_default + + if self.server_onupdate and other.server_onupdate is None: + new_server_onupdate = self.server_onupdate + new_server_onupdate = new_server_onupdate._copy() + new_server_onupdate._set_parent(other) + + if self.onupdate and other.onupdate is None: + new_onupdate = self.onupdate._copy() + new_onupdate._set_parent(other) + + for const in self.constraints: + if not const._type_bound: + new_const = const._copy() + new_const._set_parent(other) + + for fk in self.foreign_keys: + if not fk.constraint: + new_fk = fk._copy() + new_fk._set_parent(other) + def _make_proxy( self, selectable: FromClause, @@ -2948,6 +3016,9 @@ class DefaultGenerator(Executable, SchemaItem): else: self.column.default = self + def _copy(self) -> DefaultGenerator: + raise NotImplementedError() + def _execute_on_connection( self, connection: Connection, @@ -3077,6 +3148,11 @@ class ScalarElementColumnDefault(ColumnDefault): self.for_update = for_update self.arg = arg + def _copy(self) -> ScalarElementColumnDefault: + return ScalarElementColumnDefault( + arg=self.arg, for_update=self.for_update + ) + # _SQLExprDefault = Union["ColumnElement[Any]", "TextClause", "SelectBase"] _SQLExprDefault = Union["ColumnElement[Any]", "TextClause"] @@ -3101,6 +3177,11 @@ class ColumnElementColumnDefault(ColumnDefault): self.for_update = for_update self.arg = arg + def _copy(self) -> ColumnElementColumnDefault: + return ColumnElementColumnDefault( + arg=self.arg, for_update=self.for_update + ) + @util.memoized_property @util.preload_module("sqlalchemy.sql.sqltypes") def _arg_is_typed(self) -> bool: @@ -3132,6 +3213,9 @@ class CallableColumnDefault(ColumnDefault): self.for_update = for_update self.arg = self._maybe_wrap_callable(arg) + def _copy(self) -> CallableColumnDefault: + return CallableColumnDefault(arg=self.arg, for_update=self.for_update) + def _maybe_wrap_callable( self, fn: Union[_CallableColumnDefaultProtocol, Callable[[], Any]] ) -> _CallableColumnDefaultProtocol: @@ -3266,7 +3350,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): nomaxvalue: Optional[bool] = None, cycle: Optional[bool] = None, schema: Optional[Union[str, Literal[SchemaConst.BLANK_SCHEMA]]] = None, - cache: Optional[bool] = None, + cache: Optional[int] = None, order: Optional[bool] = None, data_type: Optional[_TypeEngineArgument[int]] = None, optional: bool = False, @@ -3459,6 +3543,25 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): super(Sequence, self)._set_parent(column) column._on_table_attach(self._set_table) + def _copy(self) -> Sequence: + return Sequence( + name=self.name, + start=self.start, + increment=self.increment, + minvalue=self.minvalue, + maxvalue=self.maxvalue, + nominvalue=self.nominvalue, + nomaxvalue=self.nomaxvalue, + cycle=self.cycle, + schema=self.schema, + cache=self.cache, + order=self.order, + data_type=self.data_type, + optional=self.optional, + metadata=self.metadata, + for_update=self.for_update, + ) + def _set_table(self, column: Column[Any], table: Table) -> None: self._set_metadata(table.metadata) @@ -3522,6 +3625,9 @@ class FetchedValue(SchemaEventTarget): else: return self._clone(for_update) # type: ignore + def _copy(self) -> FetchedValue: + return FetchedValue(self.for_update) + def _clone(self, for_update: bool) -> Any: n = self.__class__.__new__(self.__class__) n.__dict__.update(self.__dict__) @@ -3577,6 +3683,11 @@ class DefaultClause(FetchedValue): self.arg = arg self.reflected = _reflected + def _copy(self) -> DefaultClause: + return DefaultClause( + arg=self.arg, for_update=self.for_update, _reflected=self.reflected + ) + def __repr__(self) -> str: return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update) |