summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql/base.py
diff options
context:
space:
mode:
authorDavid Baumgold <david@davidbaumgold.com>2022-02-11 12:30:24 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-06-21 10:17:40 -0400
commit017fd9ae0645eaf2a0fbdd067d10c721505b018c (patch)
tree80adc525448f11b11bb34d0cf3b1a0e708725542 /lib/sqlalchemy/dialects/postgresql/base.py
parent4e2a89c41b0bb423891767d10bdc3cb1b75eaa5e (diff)
downloadsqlalchemy-017fd9ae0645eaf2a0fbdd067d10c721505b018c.tar.gz
Domain type
Added a new Postgresql :class:`_postgresql.DOMAIN` datatype, which follows the same CREATE TYPE / DROP TYPE behaviors as that of PostgreSQL :class:`_postgresql.ENUM`. Much thanks to David Baumgold for the efforts on this. Fixes: #7316 Closes: #7317 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7317 Pull-request-sha: bc9a82f010e6ca2f70a6e8a7620b748e483c26c3 Change-Id: Id8d7e48843a896de17d20cc466b115b3cc065132
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/base.py')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py303
1 files changed, 238 insertions, 65 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 8402341f6..8fc24c933 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1450,6 +1450,9 @@ from __future__ import annotations
from collections import defaultdict
from functools import lru_cache
import re
+from typing import Any
+from typing import List
+from typing import Optional
from . import array as _array
from . import dml
@@ -1457,30 +1460,34 @@ from . import hstore as _hstore
from . import json as _json
from . import pg_catalog
from . import ranges as _ranges
-from .types import _DECIMAL_TYPES # noqa
-from .types import _FLOAT_TYPES # noqa
-from .types import _INT_TYPES # noqa
-from .types import BIT
-from .types import BYTEA
-from .types import CIDR
-from .types import CreateEnumType # noqa
-from .types import DropEnumType # noqa
-from .types import ENUM
-from .types import INET
-from .types import INTERVAL
-from .types import MACADDR
-from .types import MONEY
-from .types import OID
-from .types import PGBit # noqa
-from .types import PGCidr # noqa
-from .types import PGInet # noqa
-from .types import PGInterval # noqa
-from .types import PGMacAddr # noqa
-from .types import PGUuid
-from .types import REGCLASS
-from .types import TIME
-from .types import TIMESTAMP
-from .types import TSVECTOR
+from .named_types import CreateDomainType as CreateDomainType # noqa: F401
+from .named_types import CreateEnumType as CreateEnumType # noqa: F401
+from .named_types import DOMAIN as DOMAIN # noqa: F401
+from .named_types import DropDomainType as DropDomainType # noqa: F401
+from .named_types import DropEnumType as DropEnumType # noqa: F401
+from .named_types import ENUM as ENUM # noqa: F401
+from .named_types import NamedType as NamedType # noqa: F401
+from .types import _DECIMAL_TYPES # noqa: F401
+from .types import _FLOAT_TYPES # noqa: F401
+from .types import _INT_TYPES # noqa: F401
+from .types import BIT as BIT
+from .types import BYTEA as BYTEA
+from .types import CIDR as CIDR
+from .types import INET as INET
+from .types import INTERVAL as INTERVAL
+from .types import MACADDR as MACADDR
+from .types import MONEY as MONEY
+from .types import OID as OID
+from .types import PGBit as PGBit # noqa: F401
+from .types import PGCidr as PGCidr # noqa: F401
+from .types import PGInet as PGInet # noqa: F401
+from .types import PGInterval as PGInterval # noqa: F401
+from .types import PGMacAddr as PGMacAddr # noqa: F401
+from .types import PGUuid as PGUuid
+from .types import REGCLASS as REGCLASS
+from .types import TIME as TIME
+from .types import TIMESTAMP as TIMESTAMP
+from .types import TSVECTOR as TSVECTOR
from ... import exc
from ... import schema
from ... import select
@@ -1515,6 +1522,7 @@ from ...types import SMALLINT
from ...types import TEXT
from ...types import UUID as UUID
from ...types import VARCHAR
+from ...util.typing import TypedDict
IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I)
@@ -2198,6 +2206,38 @@ class PGDDLCompiler(compiler.DDLCompiler):
return "DROP TYPE %s" % (self.preparer.format_type(type_))
+ def visit_create_domain_type(self, create):
+ domain: DOMAIN = create.element
+
+ options = []
+ if domain.collation is not None:
+ options.append(f"COLLATE {self.preparer.quote(domain.collation)}")
+ if domain.default is not None:
+ default = self.render_default_string(domain.default)
+ options.append(f"DEFAULT {default}")
+ if domain.constraint_name is not None:
+ name = self.preparer.truncate_and_render_constraint_name(
+ domain.constraint_name
+ )
+ options.append(f"CONSTRAINT {name}")
+ if domain.not_null:
+ options.append("NOT NULL")
+ if domain.check is not None:
+ check = self.sql_compiler.process(
+ domain.check, include_table=False, literal_binds=True
+ )
+ options.append(f"CHECK ({check})")
+
+ return (
+ f"CREATE DOMAIN {self.preparer.format_type(domain)} AS "
+ f"{self.type_compiler.process(domain.data_type)} "
+ f"{' '.join(options)}"
+ )
+
+ def visit_drop_domain_type(self, drop):
+ domain = drop.element
+ return f"DROP DOMAIN {self.preparer.format_type(domain)}"
+
def visit_create_index(self, create):
preparer = self.preparer
index = create.element
@@ -2470,6 +2510,11 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
identifier_preparer = self.dialect.identifier_preparer
return identifier_preparer.format_type(type_)
+ def visit_DOMAIN(self, type_, identifier_preparer=None, **kw):
+ if identifier_preparer is None:
+ identifier_preparer = self.dialect.identifier_preparer
+ return identifier_preparer.format_type(type_)
+
def visit_TIMESTAMP(self, type_, **kw):
return "TIMESTAMP%s %s" % (
"(%d)" % type_.precision
@@ -2548,7 +2593,9 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
def format_type(self, type_, use_schema=True):
if not type_.name:
- raise exc.CompileError("PostgreSQL ENUM type requires a name.")
+ raise exc.CompileError(
+ f"PostgreSQL {type_.__class__.__name__} type requires a name."
+ )
name = self.quote(type_.name)
effective_schema = self.schema_for_object(type_)
@@ -2558,14 +2605,60 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
and use_schema
and effective_schema is not None
):
- name = self.quote_schema(effective_schema) + "." + name
+ name = f"{self.quote_schema(effective_schema)}.{name}"
return name
+class ReflectedNamedType(TypedDict):
+ """Represents a reflected named type."""
+
+ name: str
+ """Name of the type."""
+ schema: str
+ """The schema of the type."""
+ visible: bool
+ """Indicates if this type is in the current search path."""
+
+
+class ReflectedDomainConstraint(TypedDict):
+ """Represents a reflect check constraint of a domain."""
+
+ name: str
+ """Name of the constraint."""
+ check: str
+ """The check constraint text."""
+
+
+class ReflectedDomain(ReflectedNamedType):
+ """Represents a reflected enum."""
+
+ type: str
+ """The string name of the underlying data type of the domain."""
+ nullable: bool
+ """Indicates if the domain allows null or not."""
+ default: Optional[str]
+ """The string representation of the default value of this domain
+ or ``None`` if none present.
+ """
+ constraints: List[ReflectedDomainConstraint]
+ """The constraints defined in the domain, if any.
+ The constraint are in order of evaluation by postgresql.
+ """
+
+
+class ReflectedEnum(ReflectedNamedType):
+ """Represents a reflected enum."""
+
+ labels: List[str]
+ """The labels that compose the enum."""
+
+
class PGInspector(reflection.Inspector):
dialect: PGDialect
- def get_table_oid(self, table_name, schema=None):
+ def get_table_oid(
+ self, table_name: str, schema: Optional[str] = None
+ ) -> int:
"""Return the OID for the given table name.
:param table_name: string name of the table. For special quoting,
@@ -2582,7 +2675,38 @@ class PGInspector(reflection.Inspector):
conn, table_name, schema, info_cache=self.info_cache
)
- def get_enums(self, schema=None):
+ def get_domains(
+ self, schema: Optional[str] = None
+ ) -> List[ReflectedDomain]:
+ """Return a list of DOMAIN objects.
+
+ Each member is a dictionary containing these fields:
+
+ * name - name of the domain
+ * schema - the schema name for the domain.
+ * visible - boolean, whether or not this domain is visible
+ in the default search path.
+ * type - the type defined by this domain.
+ * nullable - Indicates if this domain can be ``NULL``.
+ * default - The default value of the domain or ``None`` if the
+ domain has no default.
+ * constraints - A list of dict wit the constraint defined by this
+ domain. Each element constaints two keys: ``name`` of the
+ constraint and ``check`` with the constraint text.
+
+ :param schema: schema name. If None, the default schema
+ (typically 'public') is used. May also be set to ``'*'`` to
+ indicate load domains for all schemas.
+
+ .. versionadded:: 2.0
+
+ """
+ with self._operation_context() as conn:
+ return self.dialect._load_domains(
+ conn, schema, info_cache=self.info_cache
+ )
+
+ def get_enums(self, schema: Optional[str] = None) -> List[ReflectedEnum]:
"""Return a list of ENUM objects.
Each member is a dictionary containing these fields:
@@ -2594,7 +2718,7 @@ class PGInspector(reflection.Inspector):
* labels - a list of string labels that apply to the enum.
:param schema: schema name. If None, the default schema
- (typically 'public') is used. May also be set to '*' to
+ (typically 'public') is used. May also be set to ``'*'`` to
indicate load enums for all schemas.
.. versionadded:: 1.0.0
@@ -2605,7 +2729,9 @@ class PGInspector(reflection.Inspector):
conn, schema, info_cache=self.info_cache
)
- def get_foreign_table_names(self, schema=None):
+ def get_foreign_table_names(
+ self, schema: Optional[str] = None
+ ) -> List[str]:
"""Return a list of FOREIGN TABLE names.
Behavior is similar to that of
@@ -2621,13 +2747,15 @@ class PGInspector(reflection.Inspector):
conn, schema, info_cache=self.info_cache
)
- def has_type(self, type_name, schema=None, **kw):
+ def has_type(
+ self, type_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> bool:
"""Return if the database has the specified type in the provided
schema.
:param type_name: the type to check.
:param schema: schema name. If None, the default schema
- (typically 'public') is used. May also be set to '*' to
+ (typically 'public') is used. May also be set to ``'*'`` to
check in all schemas.
.. versionadded:: 2.0
@@ -2941,10 +3069,12 @@ class PGDialect(default.DefaultDialect):
pg_catalog.pg_namespace,
pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace,
)
+
if scope is ObjectScope.DEFAULT:
query = query.where(pg_class_table.c.relpersistence != "t")
elif scope is ObjectScope.TEMPORARY:
query = query.where(pg_class_table.c.relpersistence == "t")
+
if schema is None:
query = query.where(
pg_catalog.pg_table_is_visible(pg_class_table.c.oid),
@@ -3319,9 +3449,12 @@ class PGDialect(default.DefaultDialect):
# dictionary with (name, ) if default search path or (schema, name)
# as keys
- domains = self._load_domains(
- connection, info_cache=kw.get("info_cache")
- )
+ domains = {
+ ((d["schema"], d["name"]) if not d["visible"] else (d["name"],)): d
+ for d in self._load_domains(
+ connection, schema="*", info_cache=kw.get("info_cache")
+ )
+ }
# dictionary with (name, ) if default search path or (schema, name)
# as keys
@@ -3446,7 +3579,7 @@ class PGDialect(default.DefaultDialect):
break
elif enum_or_domain_key in domains:
domain = domains[enum_or_domain_key]
- attype = domain["attype"]
+ attype = domain["type"]
attype, is_array = _handle_array_type(attype)
# strip quotes from case sensitive enum or domain names
enum_or_domain_key = tuple(
@@ -3736,7 +3869,7 @@ class PGDialect(default.DefaultDialect):
@util.memoized_property
def _fk_regex_pattern(self):
- # https://www.postgresql.org/docs/14.0/static/sql-createtable.html
+ # https://www.postgresql.org/docs/current/static/sql-createtable.html
return re.compile(
r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)"
r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?"
@@ -4201,7 +4334,7 @@ class PGDialect(default.DefaultDialect):
(
pg_catalog.pg_constraint.c.oid.is_not(None),
pg_catalog.pg_get_constraintdef(
- pg_catalog.pg_constraint.c.oid
+ pg_catalog.pg_constraint.c.oid, True
),
),
else_=None,
@@ -4265,6 +4398,17 @@ class PGDialect(default.DefaultDialect):
check_constraints[(schema, table_name)].append(entry)
return check_constraints.items()
+ def _pg_type_filter_schema(self, query, schema):
+ if schema is None:
+ query = query.where(
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
+ # ignore pg_catalog schema
+ pg_catalog.pg_namespace.c.nspname != "pg_catalog",
+ )
+ elif schema != "*":
+ query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+ return query
+
@lru_cache()
def _enum_query(self, schema):
lbl_sq = (
@@ -4310,15 +4454,7 @@ class PGDialect(default.DefaultDialect):
)
)
- if schema is None:
- query = query.where(
- pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
- # ignore pg_catalog schema
- pg_catalog.pg_namespace.c.nspname != "pg_catalog",
- )
- elif schema != "*":
- query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
- return query
+ return self._pg_type_filter_schema(query, schema)
@reflection.cache
def _load_enums(self, connection, schema=None, **kw):
@@ -4339,9 +4475,27 @@ class PGDialect(default.DefaultDialect):
)
return enums
- @util.memoized_property
- def _domain_query(self):
- return (
+ @lru_cache()
+ def _domain_query(self, schema):
+ con_sq = (
+ select(
+ pg_catalog.pg_constraint.c.contypid,
+ sql.func.array_agg(
+ pg_catalog.pg_get_constraintdef(
+ pg_catalog.pg_constraint.c.oid, True
+ )
+ ).label("condefs"),
+ sql.func.array_agg(pg_catalog.pg_constraint.c.conname).label(
+ "connames"
+ ),
+ )
+ # The domain this constraint is on; zero if not a domain constraint
+ .where(pg_catalog.pg_constraint.c.contypid != 0)
+ .group_by(pg_catalog.pg_constraint.c.contypid)
+ .subquery("domain_constraints")
+ )
+
+ query = (
select(
pg_catalog.pg_type.c.typname.label("name"),
pg_catalog.format_type(
@@ -4354,38 +4508,57 @@ class PGDialect(default.DefaultDialect):
"visible"
),
pg_catalog.pg_namespace.c.nspname.label("schema"),
+ con_sq.c.condefs,
+ con_sq.c.connames,
)
.join(
pg_catalog.pg_namespace,
pg_catalog.pg_namespace.c.oid
== pg_catalog.pg_type.c.typnamespace,
)
+ .outerjoin(
+ con_sq,
+ pg_catalog.pg_type.c.oid == con_sq.c.contypid,
+ )
.where(pg_catalog.pg_type.c.typtype == "d")
+ .order_by(
+ pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname
+ )
)
+ return self._pg_type_filter_schema(query, schema)
@reflection.cache
- def _load_domains(self, connection, **kw):
+ def _load_domains(self, connection, schema=None, **kw):
# Load data types for domains:
- result = connection.execute(self._domain_query)
+ result = connection.execute(self._domain_query(schema))
- domains = {}
+ domains = []
for domain in result.mappings():
- domain = domain
# strip (30) from character varying(30)
attype = re.search(r"([^\(]+)", domain["attype"]).group(1)
- # 'visible' just means whether or not the domain is in a
- # schema that's on the search path -- or not overridden by
- # a schema with higher precedence. If it's not visible,
- # it will be prefixed with the schema-name when it's used.
- if domain["visible"]:
- key = (domain["name"],)
- else:
- key = (domain["schema"], domain["name"])
-
- domains[key] = {
- "attype": attype,
+ constraints = []
+ if domain["connames"]:
+ # When a domain has multiple CHECK constraints, they will
+ # be tested in alphabetical order by name.
+ sorted_constraints = sorted(
+ zip(domain["connames"], domain["condefs"]),
+ key=lambda t: t[0],
+ )
+ for name, def_ in sorted_constraints:
+ # constraint is in the form "CHECK (expression)".
+ # remove "CHECK (" and the tailing ")".
+ check = def_[7:-1]
+ constraints.append({"name": name, "check": check})
+
+ domain_rec = {
+ "name": domain["name"],
+ "schema": domain["schema"],
+ "visible": domain["visible"],
+ "type": attype,
"nullable": domain["nullable"],
"default": domain["default"],
+ "constraints": constraints,
}
+ domains.append(domain_rec)
return domains