summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/schema.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/schema.py')
-rw-r--r--lib/sqlalchemy/sql/schema.py113
1 files changed, 112 insertions, 1 deletions
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 979b8319e..4ed5b9e6b 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -2233,6 +2233,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
server_default = self.server_default
server_onupdate = self.server_onupdate
if isinstance(server_default, (Computed, Identity)):
+ # TODO: likely should be copied in all cases
args.append(server_default._copy(**kw))
server_default = server_onupdate = None
@@ -2243,6 +2244,10 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
if self._user_defined_nullable is not NULL_UNSPECIFIED:
column_kwargs["nullable"] = self._user_defined_nullable
+ # TODO: DefaultGenerator is not copied here! it's just used again
+ # with _set_parent() pointing to the old column. see the new
+ # use of _copy() in the new _merge() method
+
c = self._constructor(
name=self.name,
type_=type_,
@@ -2264,6 +2269,69 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
)
return self._schema_item_copy(c)
+ def _merge(self, other: Column[Any]) -> None:
+ """merge the elements of another column into this one.
+
+ this is used by ORM pep-593 merge and will likely need a lot
+ of fixes.
+
+
+ """
+
+ if self.primary_key:
+ other.primary_key = True
+
+ type_ = self.type
+ if not type_._isnull and other.type._isnull:
+ if isinstance(type_, SchemaEventTarget):
+ type_ = type_.copy()
+
+ other.type = type_
+
+ if isinstance(type_, SchemaEventTarget):
+ type_._set_parent_with_dispatch(other)
+
+ for impl in type_._variant_mapping.values():
+ if isinstance(impl, SchemaEventTarget):
+ impl._set_parent_with_dispatch(other)
+
+ if (
+ self._user_defined_nullable is not NULL_UNSPECIFIED
+ and other._user_defined_nullable is NULL_UNSPECIFIED
+ ):
+ other.nullable = self.nullable
+
+ if self.default is not None and other.default is None:
+ new_default = self.default._copy()
+ new_default._set_parent(other)
+
+ if self.server_default and other.server_default is None:
+ new_server_default = self.server_default
+ if isinstance(new_server_default, FetchedValue):
+ new_server_default = new_server_default._copy()
+ new_server_default._set_parent(other)
+ else:
+ other.server_default = new_server_default
+
+ if self.server_onupdate and other.server_onupdate is None:
+ new_server_onupdate = self.server_onupdate
+ new_server_onupdate = new_server_onupdate._copy()
+ new_server_onupdate._set_parent(other)
+
+ if self.onupdate and other.onupdate is None:
+ new_onupdate = self.onupdate._copy()
+ new_onupdate._set_parent(other)
+
+ for const in self.constraints:
+ if not const._type_bound:
+ new_const = const._copy()
+ new_const._set_parent(other)
+
+ for fk in self.foreign_keys:
+ if not fk.constraint:
+ new_fk = fk._copy()
+ new_fk._set_parent(other)
+
def _make_proxy(
self,
selectable: FromClause,
@@ -2948,6 +3016,9 @@ class DefaultGenerator(Executable, SchemaItem):
else:
self.column.default = self
+ def _copy(self) -> DefaultGenerator:
+ raise NotImplementedError()
+
def _execute_on_connection(
self,
connection: Connection,
@@ -3077,6 +3148,11 @@ class ScalarElementColumnDefault(ColumnDefault):
self.for_update = for_update
self.arg = arg
+ def _copy(self) -> ScalarElementColumnDefault:
+ return ScalarElementColumnDefault(
+ arg=self.arg, for_update=self.for_update
+ )
+
# _SQLExprDefault = Union["ColumnElement[Any]", "TextClause", "SelectBase"]
_SQLExprDefault = Union["ColumnElement[Any]", "TextClause"]
@@ -3101,6 +3177,11 @@ class ColumnElementColumnDefault(ColumnDefault):
self.for_update = for_update
self.arg = arg
+ def _copy(self) -> ColumnElementColumnDefault:
+ return ColumnElementColumnDefault(
+ arg=self.arg, for_update=self.for_update
+ )
+
@util.memoized_property
@util.preload_module("sqlalchemy.sql.sqltypes")
def _arg_is_typed(self) -> bool:
@@ -3132,6 +3213,9 @@ class CallableColumnDefault(ColumnDefault):
self.for_update = for_update
self.arg = self._maybe_wrap_callable(arg)
+ def _copy(self) -> CallableColumnDefault:
+ return CallableColumnDefault(arg=self.arg, for_update=self.for_update)
+
def _maybe_wrap_callable(
self, fn: Union[_CallableColumnDefaultProtocol, Callable[[], Any]]
) -> _CallableColumnDefaultProtocol:
@@ -3266,7 +3350,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
nomaxvalue: Optional[bool] = None,
cycle: Optional[bool] = None,
schema: Optional[Union[str, Literal[SchemaConst.BLANK_SCHEMA]]] = None,
- cache: Optional[bool] = None,
+ cache: Optional[int] = None,
order: Optional[bool] = None,
data_type: Optional[_TypeEngineArgument[int]] = None,
optional: bool = False,
@@ -3459,6 +3543,25 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
super(Sequence, self)._set_parent(column)
column._on_table_attach(self._set_table)
+ def _copy(self) -> Sequence:
+ return Sequence(
+ name=self.name,
+ start=self.start,
+ increment=self.increment,
+ minvalue=self.minvalue,
+ maxvalue=self.maxvalue,
+ nominvalue=self.nominvalue,
+ nomaxvalue=self.nomaxvalue,
+ cycle=self.cycle,
+ schema=self.schema,
+ cache=self.cache,
+ order=self.order,
+ data_type=self.data_type,
+ optional=self.optional,
+ metadata=self.metadata,
+ for_update=self.for_update,
+ )
+
def _set_table(self, column: Column[Any], table: Table) -> None:
self._set_metadata(table.metadata)
@@ -3522,6 +3625,9 @@ class FetchedValue(SchemaEventTarget):
else:
return self._clone(for_update) # type: ignore
+ def _copy(self) -> FetchedValue:
+ return FetchedValue(self.for_update)
+
def _clone(self, for_update: bool) -> Any:
n = self.__class__.__new__(self.__class__)
n.__dict__.update(self.__dict__)
@@ -3577,6 +3683,11 @@ class DefaultClause(FetchedValue):
self.arg = arg
self.reflected = _reflected
+ def _copy(self) -> DefaultClause:
+ return DefaultClause(
+ arg=self.arg, for_update=self.for_update, _reflected=self.reflected
+ )
+
def __repr__(self) -> str:
return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update)