diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2020-08-19 17:38:07 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-08-19 17:38:07 +0000 |
commit | 62347923754754f93adbf0c3888208be77f26e70 (patch) | |
tree | 59f0a84b28397d5ba3d0095807a6bf97b2ecc892 /lib/sqlalchemy/dialects/postgresql/base.py | |
parent | 348afaf742d0df017f9ae0c71c981de0fb967780 (diff) | |
parent | d1005e130558b33fd455be6d994e8e863799a318 (diff) | |
download | sqlalchemy-62347923754754f93adbf0c3888208be77f26e70.tar.gz |
Merge "Implement DDL visitor for PG ENUM with schema translate support"
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 57 |
1 files changed, 47 insertions, 10 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 92e48f1f3..2db079799 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1011,6 +1011,7 @@ from ...sql import expression from ...sql import roles from ...sql import sqltypes from ...sql import util as sql_util +from ...sql.ddl import DDLBase from ...types import BIGINT from ...types import BOOLEAN from ...types import CHAR @@ -1499,10 +1500,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not bind.dialect.supports_native_enum: return - if not checkfirst or not bind.dialect.has_type( - bind, self.name, schema=self.schema - ): - bind.execute(CreateEnumType(self)) + bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst) def drop(self, bind=None, checkfirst=True): """Emit ``DROP TYPE`` for this @@ -1522,10 +1520,49 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not bind.dialect.supports_native_enum: return - if not checkfirst or bind.dialect.has_type( - bind, self.name, schema=self.schema - ): - bind.execute(DropEnumType(self)) + bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst) + + class EnumGenerator(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_create_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return not self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_create_enum(enum): + return + + self.connection.execute(CreateEnumType(enum)) + + class EnumDropper(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_drop_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_drop_enum(enum): + return + + self.connection.execute(DropEnumType(enum)) def _check_for_name_in_memos(self, checkfirst, kw): """Look in the 'ddl runner' for 'memos', then @@ -1551,14 +1588,14 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): return False def _on_table_create(self, target, bind, checkfirst=False, **kw): + if ( checkfirst or ( not self.metadata and not kw.get("_is_metadata_operation", False) ) - and not self._check_for_name_in_memos(checkfirst, kw) - ): + ) and not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) def _on_table_drop(self, target, bind, checkfirst=False, **kw): |