diff options
author | David Baumgold <david@davidbaumgold.com> | 2022-02-11 12:30:24 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-06-21 10:17:40 -0400 |
commit | 017fd9ae0645eaf2a0fbdd067d10c721505b018c (patch) | |
tree | 80adc525448f11b11bb34d0cf3b1a0e708725542 /lib/sqlalchemy/dialects/postgresql/base.py | |
parent | 4e2a89c41b0bb423891767d10bdc3cb1b75eaa5e (diff) | |
download | sqlalchemy-017fd9ae0645eaf2a0fbdd067d10c721505b018c.tar.gz |
Domain type
Added a new Postgresql :class:`_postgresql.DOMAIN` datatype, which follows
the same CREATE TYPE / DROP TYPE behaviors as that of PostgreSQL
:class:`_postgresql.ENUM`. Much thanks to David Baumgold for the efforts on
this.
Fixes: #7316
Closes: #7317
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7317
Pull-request-sha: bc9a82f010e6ca2f70a6e8a7620b748e483c26c3
Change-Id: Id8d7e48843a896de17d20cc466b115b3cc065132
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 303 |
1 files changed, 238 insertions, 65 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 8402341f6..8fc24c933 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1450,6 +1450,9 @@ from __future__ import annotations from collections import defaultdict from functools import lru_cache import re +from typing import Any +from typing import List +from typing import Optional from . import array as _array from . import dml @@ -1457,30 +1460,34 @@ from . import hstore as _hstore from . import json as _json from . import pg_catalog from . import ranges as _ranges -from .types import _DECIMAL_TYPES # noqa -from .types import _FLOAT_TYPES # noqa -from .types import _INT_TYPES # noqa -from .types import BIT -from .types import BYTEA -from .types import CIDR -from .types import CreateEnumType # noqa -from .types import DropEnumType # noqa -from .types import ENUM -from .types import INET -from .types import INTERVAL -from .types import MACADDR -from .types import MONEY -from .types import OID -from .types import PGBit # noqa -from .types import PGCidr # noqa -from .types import PGInet # noqa -from .types import PGInterval # noqa -from .types import PGMacAddr # noqa -from .types import PGUuid -from .types import REGCLASS -from .types import TIME -from .types import TIMESTAMP -from .types import TSVECTOR +from .named_types import CreateDomainType as CreateDomainType # noqa: F401 +from .named_types import CreateEnumType as CreateEnumType # noqa: F401 +from .named_types import DOMAIN as DOMAIN # noqa: F401 +from .named_types import DropDomainType as DropDomainType # noqa: F401 +from .named_types import DropEnumType as DropEnumType # noqa: F401 +from .named_types import ENUM as ENUM # noqa: F401 +from .named_types import NamedType as NamedType # noqa: F401 +from .types import _DECIMAL_TYPES # noqa: F401 +from .types import _FLOAT_TYPES # noqa: F401 +from .types import _INT_TYPES # noqa: F401 +from .types import BIT as BIT +from .types import BYTEA as BYTEA +from .types import CIDR as CIDR +from .types import INET as INET +from .types import INTERVAL as INTERVAL +from .types import MACADDR as MACADDR +from .types import MONEY as MONEY +from .types import OID as OID +from .types import PGBit as PGBit # noqa: F401 +from .types import PGCidr as PGCidr # noqa: F401 +from .types import PGInet as PGInet # noqa: F401 +from .types import PGInterval as PGInterval # noqa: F401 +from .types import PGMacAddr as PGMacAddr # noqa: F401 +from .types import PGUuid as PGUuid +from .types import REGCLASS as REGCLASS +from .types import TIME as TIME +from .types import TIMESTAMP as TIMESTAMP +from .types import TSVECTOR as TSVECTOR from ... import exc from ... import schema from ... import select @@ -1515,6 +1522,7 @@ from ...types import SMALLINT from ...types import TEXT from ...types import UUID as UUID from ...types import VARCHAR +from ...util.typing import TypedDict IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I) @@ -2198,6 +2206,38 @@ class PGDDLCompiler(compiler.DDLCompiler): return "DROP TYPE %s" % (self.preparer.format_type(type_)) + def visit_create_domain_type(self, create): + domain: DOMAIN = create.element + + options = [] + if domain.collation is not None: + options.append(f"COLLATE {self.preparer.quote(domain.collation)}") + if domain.default is not None: + default = self.render_default_string(domain.default) + options.append(f"DEFAULT {default}") + if domain.constraint_name is not None: + name = self.preparer.truncate_and_render_constraint_name( + domain.constraint_name + ) + options.append(f"CONSTRAINT {name}") + if domain.not_null: + options.append("NOT NULL") + if domain.check is not None: + check = self.sql_compiler.process( + domain.check, include_table=False, literal_binds=True + ) + options.append(f"CHECK ({check})") + + return ( + f"CREATE DOMAIN {self.preparer.format_type(domain)} AS " + f"{self.type_compiler.process(domain.data_type)} " + f"{' '.join(options)}" + ) + + def visit_drop_domain_type(self, drop): + domain = drop.element + return f"DROP DOMAIN {self.preparer.format_type(domain)}" + def visit_create_index(self, create): preparer = self.preparer index = create.element @@ -2470,6 +2510,11 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): identifier_preparer = self.dialect.identifier_preparer return identifier_preparer.format_type(type_) + def visit_DOMAIN(self, type_, identifier_preparer=None, **kw): + if identifier_preparer is None: + identifier_preparer = self.dialect.identifier_preparer + return identifier_preparer.format_type(type_) + def visit_TIMESTAMP(self, type_, **kw): return "TIMESTAMP%s %s" % ( "(%d)" % type_.precision @@ -2548,7 +2593,9 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): def format_type(self, type_, use_schema=True): if not type_.name: - raise exc.CompileError("PostgreSQL ENUM type requires a name.") + raise exc.CompileError( + f"PostgreSQL {type_.__class__.__name__} type requires a name." + ) name = self.quote(type_.name) effective_schema = self.schema_for_object(type_) @@ -2558,14 +2605,60 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): and use_schema and effective_schema is not None ): - name = self.quote_schema(effective_schema) + "." + name + name = f"{self.quote_schema(effective_schema)}.{name}" return name +class ReflectedNamedType(TypedDict): + """Represents a reflected named type.""" + + name: str + """Name of the type.""" + schema: str + """The schema of the type.""" + visible: bool + """Indicates if this type is in the current search path.""" + + +class ReflectedDomainConstraint(TypedDict): + """Represents a reflect check constraint of a domain.""" + + name: str + """Name of the constraint.""" + check: str + """The check constraint text.""" + + +class ReflectedDomain(ReflectedNamedType): + """Represents a reflected enum.""" + + type: str + """The string name of the underlying data type of the domain.""" + nullable: bool + """Indicates if the domain allows null or not.""" + default: Optional[str] + """The string representation of the default value of this domain + or ``None`` if none present. + """ + constraints: List[ReflectedDomainConstraint] + """The constraints defined in the domain, if any. + The constraint are in order of evaluation by postgresql. + """ + + +class ReflectedEnum(ReflectedNamedType): + """Represents a reflected enum.""" + + labels: List[str] + """The labels that compose the enum.""" + + class PGInspector(reflection.Inspector): dialect: PGDialect - def get_table_oid(self, table_name, schema=None): + def get_table_oid( + self, table_name: str, schema: Optional[str] = None + ) -> int: """Return the OID for the given table name. :param table_name: string name of the table. For special quoting, @@ -2582,7 +2675,38 @@ class PGInspector(reflection.Inspector): conn, table_name, schema, info_cache=self.info_cache ) - def get_enums(self, schema=None): + def get_domains( + self, schema: Optional[str] = None + ) -> List[ReflectedDomain]: + """Return a list of DOMAIN objects. + + Each member is a dictionary containing these fields: + + * name - name of the domain + * schema - the schema name for the domain. + * visible - boolean, whether or not this domain is visible + in the default search path. + * type - the type defined by this domain. + * nullable - Indicates if this domain can be ``NULL``. + * default - The default value of the domain or ``None`` if the + domain has no default. + * constraints - A list of dict wit the constraint defined by this + domain. Each element constaints two keys: ``name`` of the + constraint and ``check`` with the constraint text. + + :param schema: schema name. If None, the default schema + (typically 'public') is used. May also be set to ``'*'`` to + indicate load domains for all schemas. + + .. versionadded:: 2.0 + + """ + with self._operation_context() as conn: + return self.dialect._load_domains( + conn, schema, info_cache=self.info_cache + ) + + def get_enums(self, schema: Optional[str] = None) -> List[ReflectedEnum]: """Return a list of ENUM objects. Each member is a dictionary containing these fields: @@ -2594,7 +2718,7 @@ class PGInspector(reflection.Inspector): * labels - a list of string labels that apply to the enum. :param schema: schema name. If None, the default schema - (typically 'public') is used. May also be set to '*' to + (typically 'public') is used. May also be set to ``'*'`` to indicate load enums for all schemas. .. versionadded:: 1.0.0 @@ -2605,7 +2729,9 @@ class PGInspector(reflection.Inspector): conn, schema, info_cache=self.info_cache ) - def get_foreign_table_names(self, schema=None): + def get_foreign_table_names( + self, schema: Optional[str] = None + ) -> List[str]: """Return a list of FOREIGN TABLE names. Behavior is similar to that of @@ -2621,13 +2747,15 @@ class PGInspector(reflection.Inspector): conn, schema, info_cache=self.info_cache ) - def has_type(self, type_name, schema=None, **kw): + def has_type( + self, type_name: str, schema: Optional[str] = None, **kw: Any + ) -> bool: """Return if the database has the specified type in the provided schema. :param type_name: the type to check. :param schema: schema name. If None, the default schema - (typically 'public') is used. May also be set to '*' to + (typically 'public') is used. May also be set to ``'*'`` to check in all schemas. .. versionadded:: 2.0 @@ -2941,10 +3069,12 @@ class PGDialect(default.DefaultDialect): pg_catalog.pg_namespace, pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace, ) + if scope is ObjectScope.DEFAULT: query = query.where(pg_class_table.c.relpersistence != "t") elif scope is ObjectScope.TEMPORARY: query = query.where(pg_class_table.c.relpersistence == "t") + if schema is None: query = query.where( pg_catalog.pg_table_is_visible(pg_class_table.c.oid), @@ -3319,9 +3449,12 @@ class PGDialect(default.DefaultDialect): # dictionary with (name, ) if default search path or (schema, name) # as keys - domains = self._load_domains( - connection, info_cache=kw.get("info_cache") - ) + domains = { + ((d["schema"], d["name"]) if not d["visible"] else (d["name"],)): d + for d in self._load_domains( + connection, schema="*", info_cache=kw.get("info_cache") + ) + } # dictionary with (name, ) if default search path or (schema, name) # as keys @@ -3446,7 +3579,7 @@ class PGDialect(default.DefaultDialect): break elif enum_or_domain_key in domains: domain = domains[enum_or_domain_key] - attype = domain["attype"] + attype = domain["type"] attype, is_array = _handle_array_type(attype) # strip quotes from case sensitive enum or domain names enum_or_domain_key = tuple( @@ -3736,7 +3869,7 @@ class PGDialect(default.DefaultDialect): @util.memoized_property def _fk_regex_pattern(self): - # https://www.postgresql.org/docs/14.0/static/sql-createtable.html + # https://www.postgresql.org/docs/current/static/sql-createtable.html return re.compile( r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)" r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?" @@ -4201,7 +4334,7 @@ class PGDialect(default.DefaultDialect): ( pg_catalog.pg_constraint.c.oid.is_not(None), pg_catalog.pg_get_constraintdef( - pg_catalog.pg_constraint.c.oid + pg_catalog.pg_constraint.c.oid, True ), ), else_=None, @@ -4265,6 +4398,17 @@ class PGDialect(default.DefaultDialect): check_constraints[(schema, table_name)].append(entry) return check_constraints.items() + def _pg_type_filter_schema(self, query, schema): + if schema is None: + query = query.where( + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + elif schema != "*": + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + return query + @lru_cache() def _enum_query(self, schema): lbl_sq = ( @@ -4310,15 +4454,7 @@ class PGDialect(default.DefaultDialect): ) ) - if schema is None: - query = query.where( - pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid), - # ignore pg_catalog schema - pg_catalog.pg_namespace.c.nspname != "pg_catalog", - ) - elif schema != "*": - query = query.where(pg_catalog.pg_namespace.c.nspname == schema) - return query + return self._pg_type_filter_schema(query, schema) @reflection.cache def _load_enums(self, connection, schema=None, **kw): @@ -4339,9 +4475,27 @@ class PGDialect(default.DefaultDialect): ) return enums - @util.memoized_property - def _domain_query(self): - return ( + @lru_cache() + def _domain_query(self, schema): + con_sq = ( + select( + pg_catalog.pg_constraint.c.contypid, + sql.func.array_agg( + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid, True + ) + ).label("condefs"), + sql.func.array_agg(pg_catalog.pg_constraint.c.conname).label( + "connames" + ), + ) + # The domain this constraint is on; zero if not a domain constraint + .where(pg_catalog.pg_constraint.c.contypid != 0) + .group_by(pg_catalog.pg_constraint.c.contypid) + .subquery("domain_constraints") + ) + + query = ( select( pg_catalog.pg_type.c.typname.label("name"), pg_catalog.format_type( @@ -4354,38 +4508,57 @@ class PGDialect(default.DefaultDialect): "visible" ), pg_catalog.pg_namespace.c.nspname.label("schema"), + con_sq.c.condefs, + con_sq.c.connames, ) .join( pg_catalog.pg_namespace, pg_catalog.pg_namespace.c.oid == pg_catalog.pg_type.c.typnamespace, ) + .outerjoin( + con_sq, + pg_catalog.pg_type.c.oid == con_sq.c.contypid, + ) .where(pg_catalog.pg_type.c.typtype == "d") + .order_by( + pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname + ) ) + return self._pg_type_filter_schema(query, schema) @reflection.cache - def _load_domains(self, connection, **kw): + def _load_domains(self, connection, schema=None, **kw): # Load data types for domains: - result = connection.execute(self._domain_query) + result = connection.execute(self._domain_query(schema)) - domains = {} + domains = [] for domain in result.mappings(): - domain = domain # strip (30) from character varying(30) attype = re.search(r"([^\(]+)", domain["attype"]).group(1) - # 'visible' just means whether or not the domain is in a - # schema that's on the search path -- or not overridden by - # a schema with higher precedence. If it's not visible, - # it will be prefixed with the schema-name when it's used. - if domain["visible"]: - key = (domain["name"],) - else: - key = (domain["schema"], domain["name"]) - - domains[key] = { - "attype": attype, + constraints = [] + if domain["connames"]: + # When a domain has multiple CHECK constraints, they will + # be tested in alphabetical order by name. + sorted_constraints = sorted( + zip(domain["connames"], domain["condefs"]), + key=lambda t: t[0], + ) + for name, def_ in sorted_constraints: + # constraint is in the form "CHECK (expression)". + # remove "CHECK (" and the tailing ")". + check = def_[7:-1] + constraints.append({"name": name, "check": check}) + + domain_rec = { + "name": domain["name"], + "schema": domain["schema"], + "visible": domain["visible"], + "type": attype, "nullable": domain["nullable"], "default": domain["default"], + "constraints": constraints, } + domains.append(domain_rec) return domains |