summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/base.py')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py96
1 files changed, 78 insertions, 18 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 3bd7e62d5..c56cccd8d 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1011,6 +1011,7 @@ from ...sql import expression
from ...sql import roles
from ...sql import sqltypes
from ...sql import util as sql_util
+from ...sql.ddl import DDLBase
from ...types import BIGINT
from ...types import BOOLEAN
from ...types import CHAR
@@ -1299,6 +1300,14 @@ class UUID(sqltypes.TypeEngine):
"""
self.as_uuid = as_uuid
+ def coerce_compared_value(self, op, value):
+ """See :meth:`.TypeEngine.coerce_compared_value` for a description."""
+
+ if isinstance(value, util.string_types):
+ return self
+ else:
+ return super(UUID, self).coerce_compared_value(op, value)
+
def bind_processor(self, dialect):
if self.as_uuid:
@@ -1491,10 +1500,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
if not bind.dialect.supports_native_enum:
return
- if not checkfirst or not bind.dialect.has_type(
- bind, self.name, schema=self.schema
- ):
- bind.execute(CreateEnumType(self))
+ bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
def drop(self, bind=None, checkfirst=True):
"""Emit ``DROP TYPE`` for this
@@ -1514,10 +1520,49 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
if not bind.dialect.supports_native_enum:
return
- if not checkfirst or bind.dialect.has_type(
- bind, self.name, schema=self.schema
- ):
- bind.execute(DropEnumType(self))
+ bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
+
+ class EnumGenerator(DDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_create_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return not self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_create_enum(enum):
+ return
+
+ self.connection.execute(CreateEnumType(enum))
+
+ class EnumDropper(DDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_drop_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_drop_enum(enum):
+ return
+
+ self.connection.execute(DropEnumType(enum))
def _check_for_name_in_memos(self, checkfirst, kw):
"""Look in the 'ddl runner' for 'memos', then
@@ -1543,14 +1588,14 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
return False
def _on_table_create(self, target, bind, checkfirst=False, **kw):
+
if (
checkfirst
or (
not self.metadata
and not kw.get("_is_metadata_operation", False)
)
- and not self._check_for_name_in_memos(checkfirst, kw)
- ):
+ ) and not self._check_for_name_in_memos(checkfirst, kw):
self.create(bind=bind, checkfirst=checkfirst)
def _on_table_drop(self, target, bind, checkfirst=False, **kw):
@@ -2176,6 +2221,17 @@ class PGDDLCompiler(compiler.DDLCompiler):
generated.sqltext, include_table=False, literal_binds=True
)
+ def visit_create_sequence(self, create, **kw):
+ prefix = None
+ if create.element.data_type is not None:
+ prefix = " AS %s" % self.type_compiler.process(
+ create.element.data_type
+ )
+
+ return super(PGDDLCompiler, self).visit_create_sequence(
+ create, prefix=prefix, **kw
+ )
+
class PGTypeCompiler(compiler.GenericTypeCompiler):
def visit_TSVECTOR(self, type_, **kw):
@@ -2847,7 +2903,11 @@ class PGDialect(default.DefaultDialect):
"JOIN pg_namespace n ON n.oid = c.relnamespace "
"WHERE n.nspname = :schema AND c.relkind in ('r', 'p')"
).columns(relname=sqltypes.Unicode),
- schema=schema if schema is not None else self.default_schema_name,
+ dict(
+ schema=schema
+ if schema is not None
+ else self.default_schema_name
+ ),
)
return [name for name, in result]
@@ -2962,7 +3022,7 @@ class PGDialect(default.DefaultDialect):
.bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer))
.columns(attname=sqltypes.Unicode, default=sqltypes.Unicode)
)
- c = connection.execute(s, table_oid=table_oid)
+ c = connection.execute(s, dict(table_oid=table_oid))
rows = c.fetchall()
# dictionary with (name, ) if default search path or (schema, name)
@@ -3204,7 +3264,7 @@ class PGDialect(default.DefaultDialect):
ORDER BY k.ord
"""
t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode)
- c = connection.execute(t, table_oid=table_oid)
+ c = connection.execute(t, dict(table_oid=table_oid))
cols = [r[0] for r in c.fetchall()]
PK_CONS_SQL = """
@@ -3214,7 +3274,7 @@ class PGDialect(default.DefaultDialect):
ORDER BY 1
"""
t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode)
- c = connection.execute(t, table_oid=table_oid)
+ c = connection.execute(t, dict(table_oid=table_oid))
name = c.scalar()
return {"constrained_columns": cols, "name": name}
@@ -3262,7 +3322,7 @@ class PGDialect(default.DefaultDialect):
t = sql.text(FK_SQL).columns(
conname=sqltypes.Unicode, condef=sqltypes.Unicode
)
- c = connection.execute(t, table=table_oid)
+ c = connection.execute(t, dict(table=table_oid))
fkeys = []
for conname, condef, conschema in c.fetchall():
m = re.search(FK_REGEX, condef).groups()
@@ -3434,7 +3494,7 @@ class PGDialect(default.DefaultDialect):
t = sql.text(IDX_SQL).columns(
relname=sqltypes.Unicode, attname=sqltypes.Unicode
)
- c = connection.execute(t, table_oid=table_oid)
+ c = connection.execute(t, dict(table_oid=table_oid))
indexes = defaultdict(lambda: defaultdict(dict))
@@ -3576,7 +3636,7 @@ class PGDialect(default.DefaultDialect):
"""
t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode)
- c = connection.execute(t, table_oid=table_oid)
+ c = connection.execute(t, dict(table_oid=table_oid))
uniques = defaultdict(lambda: defaultdict(dict))
for row in c.fetchall():
@@ -3627,7 +3687,7 @@ class PGDialect(default.DefaultDialect):
cons.contype = 'c'
"""
- c = connection.execute(sql.text(CHECK_SQL), table_oid=table_oid)
+ c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid))
ret = []
for name, src in c: