summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql/base.py
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2020-08-19 17:38:07 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2020-08-19 17:38:07 +0000
commit62347923754754f93adbf0c3888208be77f26e70 (patch)
tree59f0a84b28397d5ba3d0095807a6bf97b2ecc892 /lib/sqlalchemy/dialects/postgresql/base.py
parent348afaf742d0df017f9ae0c71c981de0fb967780 (diff)
parentd1005e130558b33fd455be6d994e8e863799a318 (diff)
downloadsqlalchemy-62347923754754f93adbf0c3888208be77f26e70.tar.gz
Merge "Implement DDL visitor for PG ENUM with schema translate support"
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/base.py')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py57
1 files changed, 47 insertions, 10 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 92e48f1f3..2db079799 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1011,6 +1011,7 @@ from ...sql import expression
from ...sql import roles
from ...sql import sqltypes
from ...sql import util as sql_util
+from ...sql.ddl import DDLBase
from ...types import BIGINT
from ...types import BOOLEAN
from ...types import CHAR
@@ -1499,10 +1500,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
if not bind.dialect.supports_native_enum:
return
- if not checkfirst or not bind.dialect.has_type(
- bind, self.name, schema=self.schema
- ):
- bind.execute(CreateEnumType(self))
+ bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
def drop(self, bind=None, checkfirst=True):
"""Emit ``DROP TYPE`` for this
@@ -1522,10 +1520,49 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
if not bind.dialect.supports_native_enum:
return
- if not checkfirst or bind.dialect.has_type(
- bind, self.name, schema=self.schema
- ):
- bind.execute(DropEnumType(self))
+ bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
+
+ class EnumGenerator(DDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_create_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return not self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_create_enum(enum):
+ return
+
+ self.connection.execute(CreateEnumType(enum))
+
+ class EnumDropper(DDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_drop_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_drop_enum(enum):
+ return
+
+ self.connection.execute(DropEnumType(enum))
def _check_for_name_in_memos(self, checkfirst, kw):
"""Look in the 'ddl runner' for 'memos', then
@@ -1551,14 +1588,14 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
return False
def _on_table_create(self, target, bind, checkfirst=False, **kw):
+
if (
checkfirst
or (
not self.metadata
and not kw.get("_is_metadata_operation", False)
)
- and not self._check_for_name_in_memos(checkfirst, kw)
- ):
+ ) and not self._check_for_name_in_memos(checkfirst, kw):
self.create(bind=bind, checkfirst=checkfirst)
def _on_table_drop(self, target, bind, checkfirst=False, **kw):