diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 96 |
1 files changed, 78 insertions, 18 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 3bd7e62d5..c56cccd8d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1011,6 +1011,7 @@ from ...sql import expression from ...sql import roles from ...sql import sqltypes from ...sql import util as sql_util +from ...sql.ddl import DDLBase from ...types import BIGINT from ...types import BOOLEAN from ...types import CHAR @@ -1299,6 +1300,14 @@ class UUID(sqltypes.TypeEngine): """ self.as_uuid = as_uuid + def coerce_compared_value(self, op, value): + """See :meth:`.TypeEngine.coerce_compared_value` for a description.""" + + if isinstance(value, util.string_types): + return self + else: + return super(UUID, self).coerce_compared_value(op, value) + def bind_processor(self, dialect): if self.as_uuid: @@ -1491,10 +1500,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not bind.dialect.supports_native_enum: return - if not checkfirst or not bind.dialect.has_type( - bind, self.name, schema=self.schema - ): - bind.execute(CreateEnumType(self)) + bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst) def drop(self, bind=None, checkfirst=True): """Emit ``DROP TYPE`` for this @@ -1514,10 +1520,49 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not bind.dialect.supports_native_enum: return - if not checkfirst or bind.dialect.has_type( - bind, self.name, schema=self.schema - ): - bind.execute(DropEnumType(self)) + bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst) + + class EnumGenerator(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_create_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return not self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_create_enum(enum): + return + + self.connection.execute(CreateEnumType(enum)) + + class EnumDropper(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_drop_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_drop_enum(enum): + return + + self.connection.execute(DropEnumType(enum)) def _check_for_name_in_memos(self, checkfirst, kw): """Look in the 'ddl runner' for 'memos', then @@ -1543,14 +1588,14 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): return False def _on_table_create(self, target, bind, checkfirst=False, **kw): + if ( checkfirst or ( not self.metadata and not kw.get("_is_metadata_operation", False) ) - and not self._check_for_name_in_memos(checkfirst, kw) - ): + ) and not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) def _on_table_drop(self, target, bind, checkfirst=False, **kw): @@ -2176,6 +2221,17 @@ class PGDDLCompiler(compiler.DDLCompiler): generated.sqltext, include_table=False, literal_binds=True ) + def visit_create_sequence(self, create, **kw): + prefix = None + if create.element.data_type is not None: + prefix = " AS %s" % self.type_compiler.process( + create.element.data_type + ) + + return super(PGDDLCompiler, self).visit_create_sequence( + create, prefix=prefix, **kw + ) + class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_TSVECTOR(self, type_, **kw): @@ -2847,7 +2903,11 @@ class PGDialect(default.DefaultDialect): "JOIN pg_namespace n ON n.oid = c.relnamespace " "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')" ).columns(relname=sqltypes.Unicode), - schema=schema if schema is not None else self.default_schema_name, + dict( + schema=schema + if schema is not None + else self.default_schema_name + ), ) return [name for name, in result] @@ -2962,7 +3022,7 @@ class PGDialect(default.DefaultDialect): .bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer)) .columns(attname=sqltypes.Unicode, default=sqltypes.Unicode) ) - c = connection.execute(s, table_oid=table_oid) + c = connection.execute(s, dict(table_oid=table_oid)) rows = c.fetchall() # dictionary with (name, ) if default search path or (schema, name) @@ -3204,7 +3264,7 @@ class PGDialect(default.DefaultDialect): ORDER BY k.ord """ t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode) - c = connection.execute(t, table_oid=table_oid) + c = connection.execute(t, dict(table_oid=table_oid)) cols = [r[0] for r in c.fetchall()] PK_CONS_SQL = """ @@ -3214,7 +3274,7 @@ class PGDialect(default.DefaultDialect): ORDER BY 1 """ t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode) - c = connection.execute(t, table_oid=table_oid) + c = connection.execute(t, dict(table_oid=table_oid)) name = c.scalar() return {"constrained_columns": cols, "name": name} @@ -3262,7 +3322,7 @@ class PGDialect(default.DefaultDialect): t = sql.text(FK_SQL).columns( conname=sqltypes.Unicode, condef=sqltypes.Unicode ) - c = connection.execute(t, table=table_oid) + c = connection.execute(t, dict(table=table_oid)) fkeys = [] for conname, condef, conschema in c.fetchall(): m = re.search(FK_REGEX, condef).groups() @@ -3434,7 +3494,7 @@ class PGDialect(default.DefaultDialect): t = sql.text(IDX_SQL).columns( relname=sqltypes.Unicode, attname=sqltypes.Unicode ) - c = connection.execute(t, table_oid=table_oid) + c = connection.execute(t, dict(table_oid=table_oid)) indexes = defaultdict(lambda: defaultdict(dict)) @@ -3576,7 +3636,7 @@ class PGDialect(default.DefaultDialect): """ t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode) - c = connection.execute(t, table_oid=table_oid) + c = connection.execute(t, dict(table_oid=table_oid)) uniques = defaultdict(lambda: defaultdict(dict)) for row in c.fetchall(): @@ -3627,7 +3687,7 @@ class PGDialect(default.DefaultDialect): cons.contype = 'c' """ - c = connection.execute(sql.text(CHECK_SQL), table_oid=table_oid) + c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid)) ret = [] for name, src in c: |