Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/base.py')
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/base.py | 2624
1 file changed, 1246 insertions(+), 1378 deletions(-)
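The bulk of this diff replaces the dialect's hand-written reflection SQL with ``select()`` constructs against ``pg_catalog`` tables, built once per set of cache-key arguments via ``@lru_cache()`` and parameterized with ``bindparam()`` at execution time (see ``_has_table_query``, ``_columns_query``, ``_table_oids_query`` below). A minimal sketch of that pattern, using an abbreviated ``pg_class`` stand-in rather than the dialect's actual ``pg_catalog`` module::

    from functools import lru_cache

    from sqlalchemy import Column, MetaData, String, Table
    from sqlalchemy import any_, bindparam, select
    from sqlalchemy.dialects.postgresql import array

    metadata = MetaData()

    # abbreviated stand-in for the pg_catalog.pg_class Table
    pg_class = Table(
        "pg_class",
        metadata,
        Column("relname", String),
        Column("relkind", String),
        schema="pg_catalog",
    )


    @lru_cache()
    def _has_table_query(relkinds):
        # built once per distinct relkinds tuple, then reused; the
        # table name stays a bindparam filled in at execution time.
        # relkind = ANY (ARRAY[...]) is used instead of IN, which
        # PostgreSQL rejects for the "char" column type (per the
        # comment in the diff below).
        return select(pg_class.c.relname).where(
            pg_class.c.relname == bindparam("table_name"),
            pg_class.c.relkind == any_(array(relkinds)),
        )


    def has_table(connection, table_name):
        query = _has_table_query(("r", "p"))
        return bool(connection.scalar(query, {"table_name": table_name}))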
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 83e46151f..36de76e0d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -11,7 +11,7 @@ r""" :name: PostgreSQL :full_support: 9.6, 10, 11, 12, 13, 14 :normal_support: 9.6+ - :best_effort: 8+ + :best_effort: 9+ .. _postgresql_sequences: @@ -1448,23 +1448,52 @@ E.g.:: from __future__ import annotations from collections import defaultdict -import datetime as dt +from functools import lru_cache import re -from typing import Any from . import array as _array from . import dml from . import hstore as _hstore from . import json as _json +from . import pg_catalog from . import ranges as _ranges +from .types import _DECIMAL_TYPES # noqa +from .types import _FLOAT_TYPES # noqa +from .types import _INT_TYPES # noqa +from .types import BIT +from .types import BYTEA +from .types import CIDR +from .types import CreateEnumType # noqa +from .types import DropEnumType # noqa +from .types import ENUM +from .types import INET +from .types import INTERVAL +from .types import MACADDR +from .types import MONEY +from .types import OID +from .types import PGBit # noqa +from .types import PGCidr # noqa +from .types import PGInet # noqa +from .types import PGInterval # noqa +from .types import PGMacAddr # noqa +from .types import PGUuid +from .types import REGCLASS +from .types import TIME +from .types import TIMESTAMP +from .types import TSVECTOR from ... import exc from ... import schema +from ... import select from ... import sql from ... import util from ...engine import characteristics from ...engine import default from ...engine import interfaces +from ...engine import ObjectKind +from ...engine import ObjectScope from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import bindparam from ...sql import coercions from ...sql import compiler from ...sql import elements @@ -1472,7 +1501,7 @@ from ...sql import expression from ...sql import roles from ...sql import sqltypes from ...sql import util as sql_util -from ...sql.ddl import InvokeDDLBase +from ...sql.visitors import InternalTraversal from ...types import BIGINT from ...types import BOOLEAN from ...types import CHAR @@ -1596,469 +1625,6 @@ RESERVED_WORDS = set( ] ) -_DECIMAL_TYPES = (1231, 1700) -_FLOAT_TYPES = (700, 701, 1021, 1022) -_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) - - -class PGUuid(UUID): - render_bind_cast = True - render_literal_cast = True - - -class BYTEA(sqltypes.LargeBinary[bytes]): - __visit_name__ = "BYTEA" - - -class INET(sqltypes.TypeEngine[str]): - __visit_name__ = "INET" - - -PGInet = INET - - -class CIDR(sqltypes.TypeEngine[str]): - __visit_name__ = "CIDR" - - -PGCidr = CIDR - - -class MACADDR(sqltypes.TypeEngine[str]): - __visit_name__ = "MACADDR" - - -PGMacAddr = MACADDR - - -class MONEY(sqltypes.TypeEngine[str]): - - r"""Provide the PostgreSQL MONEY type. - - Depending on driver, result rows using this type may return a - string value which includes currency symbols. 
- - For this reason, it may be preferable to provide conversion to a - numerically-based currency datatype using :class:`_types.TypeDecorator`:: - - import re - import decimal - from sqlalchemy import TypeDecorator - - class NumericMoney(TypeDecorator): - impl = MONEY - - def process_result_value(self, value: Any, dialect: Any) -> None: - if value is not None: - # adjust this for the currency and numeric - m = re.match(r"\$([\d.]+)", value) - if m: - value = decimal.Decimal(m.group(1)) - return value - - Alternatively, the conversion may be applied as a CAST using - the :meth:`_types.TypeDecorator.column_expression` method as follows:: - - import decimal - from sqlalchemy import cast - from sqlalchemy import TypeDecorator - - class NumericMoney(TypeDecorator): - impl = MONEY - - def column_expression(self, column: Any): - return cast(column, Numeric()) - - .. versionadded:: 1.2 - - """ - - __visit_name__ = "MONEY" - - -class OID(sqltypes.TypeEngine[int]): - - """Provide the PostgreSQL OID type. - - .. versionadded:: 0.9.5 - - """ - - __visit_name__ = "OID" - - -class REGCLASS(sqltypes.TypeEngine[str]): - - """Provide the PostgreSQL REGCLASS type. - - .. versionadded:: 1.2.7 - - """ - - __visit_name__ = "REGCLASS" - - -class TIMESTAMP(sqltypes.TIMESTAMP): - def __init__(self, timezone=False, precision=None): - super(TIMESTAMP, self).__init__(timezone=timezone) - self.precision = precision - - -class TIME(sqltypes.TIME): - def __init__(self, timezone=False, precision=None): - super(TIME, self).__init__(timezone=timezone) - self.precision = precision - - -class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): - - """PostgreSQL INTERVAL type.""" - - __visit_name__ = "INTERVAL" - native = True - - def __init__(self, precision=None, fields=None): - """Construct an INTERVAL. - - :param precision: optional integer precision value - :param fields: string fields specifier. allows storage of fields - to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``, - etc. - - .. versionadded:: 1.2 - - """ - self.precision = precision - self.fields = fields - - @classmethod - def adapt_emulated_to_native(cls, interval, **kw): - return INTERVAL(precision=interval.second_precision) - - @property - def _type_affinity(self): - return sqltypes.Interval - - def as_generic(self, allow_nulltype=False): - return sqltypes.Interval(native=True, second_precision=self.precision) - - @property - def python_type(self): - return dt.timedelta - - -PGInterval = INTERVAL - - -class BIT(sqltypes.TypeEngine[int]): - __visit_name__ = "BIT" - - def __init__(self, length=None, varying=False): - if not varying: - # BIT without VARYING defaults to length 1 - self.length = length or 1 - else: - # but BIT VARYING can be unlimited-length, so no default - self.length = length - self.varying = varying - - -PGBit = BIT - - -class TSVECTOR(sqltypes.TypeEngine[Any]): - - """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL - text search type TSVECTOR. - - It can be used to do full text queries on natural language - documents. - - .. versionadded:: 0.9.0 - - .. seealso:: - - :ref:`postgresql_match` - - """ - - __visit_name__ = "TSVECTOR" - - -class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): - - """PostgreSQL ENUM type. - - This is a subclass of :class:`_types.Enum` which includes - support for PG's ``CREATE TYPE`` and ``DROP TYPE``. 
- - When the builtin type :class:`_types.Enum` is used and the - :paramref:`.Enum.native_enum` flag is left at its default of - True, the PostgreSQL backend will use a :class:`_postgresql.ENUM` - type as the implementation, so the special create/drop rules - will be used. - - The create/drop behavior of ENUM is necessarily intricate, due to the - awkward relationship the ENUM type has in relationship to the - parent table, in that it may be "owned" by just a single table, or - may be shared among many tables. - - When using :class:`_types.Enum` or :class:`_postgresql.ENUM` - in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted - corresponding to when the :meth:`_schema.Table.create` and - :meth:`_schema.Table.drop` - methods are called:: - - table = Table('sometable', metadata, - Column('some_enum', ENUM('a', 'b', 'c', name='myenum')) - ) - - table.create(engine) # will emit CREATE ENUM and CREATE TABLE - table.drop(engine) # will emit DROP TABLE and DROP ENUM - - To use a common enumerated type between multiple tables, the best - practice is to declare the :class:`_types.Enum` or - :class:`_postgresql.ENUM` independently, and associate it with the - :class:`_schema.MetaData` object itself:: - - my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata) - - t1 = Table('sometable_one', metadata, - Column('some_enum', myenum) - ) - - t2 = Table('sometable_two', metadata, - Column('some_enum', myenum) - ) - - When this pattern is used, care must still be taken at the level - of individual table creates. Emitting CREATE TABLE without also - specifying ``checkfirst=True`` will still cause issues:: - - t1.create(engine) # will fail: no such type 'myenum' - - If we specify ``checkfirst=True``, the individual table-level create - operation will check for the ``ENUM`` and create if not exists:: - - # will check if enum exists, and emit CREATE TYPE if not - t1.create(engine, checkfirst=True) - - When using a metadata-level ENUM type, the type will always be created - and dropped if either the metadata-wide create/drop is called:: - - metadata.create_all(engine) # will emit CREATE TYPE - metadata.drop_all(engine) # will emit DROP TYPE - - The type can also be created and dropped directly:: - - my_enum.create(engine) - my_enum.drop(engine) - - .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type - now behaves more strictly with regards to CREATE/DROP. A metadata-level - ENUM type will only be created and dropped at the metadata level, - not the table level, with the exception of - ``table.create(checkfirst=True)``. - The ``table.drop()`` call will now emit a DROP TYPE for a table-level - enumerated type. - - """ - - native_enum = True - - def __init__(self, *enums, **kw): - """Construct an :class:`_postgresql.ENUM`. - - Arguments are the same as that of - :class:`_types.Enum`, but also including - the following parameters. - - :param create_type: Defaults to True. - Indicates that ``CREATE TYPE`` should be - emitted, after optionally checking for the - presence of the type, when the parent - table is being created; and additionally - that ``DROP TYPE`` is called when the table - is dropped. When ``False``, no check - will be performed and no ``CREATE TYPE`` - or ``DROP TYPE`` is emitted, unless - :meth:`~.postgresql.ENUM.create` - or :meth:`~.postgresql.ENUM.drop` - are called directly. 
- Setting to ``False`` is helpful - when invoking a creation scheme to a SQL file - without access to the actual database - - the :meth:`~.postgresql.ENUM.create` and - :meth:`~.postgresql.ENUM.drop` methods can - be used to emit SQL to a target bind. - - """ - native_enum = kw.pop("native_enum", None) - if native_enum is False: - util.warn( - "the native_enum flag does not apply to the " - "sqlalchemy.dialects.postgresql.ENUM datatype; this type " - "always refers to ENUM. Use sqlalchemy.types.Enum for " - "non-native enum." - ) - self.create_type = kw.pop("create_type", True) - super(ENUM, self).__init__(*enums, **kw) - - @classmethod - def adapt_emulated_to_native(cls, impl, **kw): - """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain - :class:`.Enum`. - - """ - kw.setdefault("validate_strings", impl.validate_strings) - kw.setdefault("name", impl.name) - kw.setdefault("schema", impl.schema) - kw.setdefault("inherit_schema", impl.inherit_schema) - kw.setdefault("metadata", impl.metadata) - kw.setdefault("_create_events", False) - kw.setdefault("values_callable", impl.values_callable) - kw.setdefault("omit_aliases", impl._omit_aliases) - return cls(**kw) - - def create(self, bind=None, checkfirst=True): - """Emit ``CREATE TYPE`` for this - :class:`_postgresql.ENUM`. - - If the underlying dialect does not support - PostgreSQL CREATE TYPE, no action is taken. - - :param bind: a connectable :class:`_engine.Engine`, - :class:`_engine.Connection`, or similar object to emit - SQL. - :param checkfirst: if ``True``, a query against - the PG catalog will be first performed to see - if the type does not exist already before - creating. - - """ - if not bind.dialect.supports_native_enum: - return - - bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst) - - def drop(self, bind=None, checkfirst=True): - """Emit ``DROP TYPE`` for this - :class:`_postgresql.ENUM`. - - If the underlying dialect does not support - PostgreSQL DROP TYPE, no action is taken. - - :param bind: a connectable :class:`_engine.Engine`, - :class:`_engine.Connection`, or similar object to emit - SQL. - :param checkfirst: if ``True``, a query against - the PG catalog will be first performed to see - if the type actually exists before dropping. 
- - """ - if not bind.dialect.supports_native_enum: - return - - bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst) - - class EnumGenerator(InvokeDDLBase): - def __init__(self, dialect, connection, checkfirst=False, **kwargs): - super(ENUM.EnumGenerator, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - - def _can_create_enum(self, enum): - if not self.checkfirst: - return True - - effective_schema = self.connection.schema_for_object(enum) - - return not self.connection.dialect.has_type( - self.connection, enum.name, schema=effective_schema - ) - - def visit_enum(self, enum): - if not self._can_create_enum(enum): - return - - self.connection.execute(CreateEnumType(enum)) - - class EnumDropper(InvokeDDLBase): - def __init__(self, dialect, connection, checkfirst=False, **kwargs): - super(ENUM.EnumDropper, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - - def _can_drop_enum(self, enum): - if not self.checkfirst: - return True - - effective_schema = self.connection.schema_for_object(enum) - - return self.connection.dialect.has_type( - self.connection, enum.name, schema=effective_schema - ) - - def visit_enum(self, enum): - if not self._can_drop_enum(enum): - return - - self.connection.execute(DropEnumType(enum)) - - def get_dbapi_type(self, dbapi): - """dont return dbapi.STRING for ENUM in PostgreSQL, since that's - a different type""" - - return None - - def _check_for_name_in_memos(self, checkfirst, kw): - """Look in the 'ddl runner' for 'memos', then - note our name in that collection. - - This to ensure a particular named enum is operated - upon only once within any kind of create/drop - sequence without relying upon "checkfirst". - - """ - if not self.create_type: - return True - if "_ddl_runner" in kw: - ddl_runner = kw["_ddl_runner"] - if "_pg_enums" in ddl_runner.memo: - pg_enums = ddl_runner.memo["_pg_enums"] - else: - pg_enums = ddl_runner.memo["_pg_enums"] = set() - present = (self.schema, self.name) in pg_enums - pg_enums.add((self.schema, self.name)) - return present - else: - return False - - def _on_table_create(self, target, bind, checkfirst=False, **kw): - if ( - checkfirst - or ( - not self.metadata - and not kw.get("_is_metadata_operation", False) - ) - ) and not self._check_for_name_in_memos(checkfirst, kw): - self.create(bind=bind, checkfirst=checkfirst) - - def _on_table_drop(self, target, bind, checkfirst=False, **kw): - if ( - not self.metadata - and not kw.get("_is_metadata_operation", False) - and not self._check_for_name_in_memos(checkfirst, kw) - ): - self.drop(bind=bind, checkfirst=checkfirst) - - def _on_metadata_create(self, target, bind, checkfirst=False, **kw): - if not self._check_for_name_in_memos(checkfirst, kw): - self.create(bind=bind, checkfirst=checkfirst) - - def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): - if not self._check_for_name_in_memos(checkfirst, kw): - self.drop(bind=bind, checkfirst=checkfirst) - - colspecs = { sqltypes.ARRAY: _array.ARRAY, sqltypes.Interval: INTERVAL, @@ -2997,8 +2563,19 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): class PGInspector(reflection.Inspector): + dialect: PGDialect + def get_table_oid(self, table_name, schema=None): - """Return the OID for the given table name.""" + """Return the OID for the given table name. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. 
For special quoting, + use :class:`.quoted_name`. + + """ with self._operation_context() as conn: return self.dialect.get_table_oid( @@ -3023,9 +2600,10 @@ class PGInspector(reflection.Inspector): .. versionadded:: 1.0.0 """ - schema = schema or self.default_schema_name with self._operation_context() as conn: - return self.dialect._load_enums(conn, schema) + return self.dialect._load_enums( + conn, schema, info_cache=self.info_cache + ) def get_foreign_table_names(self, schema=None): """Return a list of FOREIGN TABLE names. @@ -3038,38 +2616,29 @@ class PGInspector(reflection.Inspector): .. versionadded:: 1.0.0 """ - schema = schema or self.default_schema_name with self._operation_context() as conn: - return self.dialect._get_foreign_table_names(conn, schema) - - def get_view_names(self, schema=None, include=("plain", "materialized")): - """Return all view names in `schema`. + return self.dialect._get_foreign_table_names( + conn, schema, info_cache=self.info_cache + ) - :param schema: Optional, retrieve names from a non-default schema. - For special quoting, use :class:`.quoted_name`. + def has_type(self, type_name, schema=None, **kw): + """Return if the database has the specified type in the provided + schema. - :param include: specify which types of views to return. Passed - as a string value (for a single type) or a tuple (for any number - of types). Defaults to ``('plain', 'materialized')``. + :param type_name: the type to check. + :param schema: schema name. If None, the default schema + (typically 'public') is used. May also be set to '*' to + check in all schemas. - .. versionadded:: 1.1 + .. versionadded:: 2.0 """ - with self._operation_context() as conn: - return self.dialect.get_view_names( - conn, schema, info_cache=self.info_cache, include=include + return self.dialect.has_type( + conn, type_name, schema, info_cache=self.info_cache ) -class CreateEnumType(schema._CreateDropBase): - __visit_name__ = "create_enum_type" - - -class DropEnumType(schema._CreateDropBase): - __visit_name__ = "drop_enum_type" - - class PGExecutionContext(default.DefaultExecutionContext): def fire_sequence(self, seq, type_): return self._execute_scalar( @@ -3262,35 +2831,14 @@ class PGDialect(default.DefaultDialect): def initialize(self, connection): super(PGDialect, self).initialize(connection) - if self.server_version_info <= (8, 2): - self.delete_returning = ( - self.update_returning - ) = self.insert_returning = False - - self.supports_native_enum = self.server_version_info >= (8, 3) - if not self.supports_native_enum: - self.colspecs = self.colspecs.copy() - # pop base Enum type - self.colspecs.pop(sqltypes.Enum, None) - # psycopg2, others may have placed ENUM here as well - self.colspecs.pop(ENUM, None) - # https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 self.supports_smallserial = self.server_version_info >= (9, 2) - if self.server_version_info < (8, 2): - self._backslash_escapes = False - else: - # ensure this query is not emitted on server version < 8.2 - # as it will fail - std_string = connection.exec_driver_sql( - "show standard_conforming_strings" - ).scalar() - self._backslash_escapes = std_string == "off" - - self._supports_create_index_concurrently = ( - self.server_version_info >= (8, 2) - ) + std_string = connection.exec_driver_sql( + "show standard_conforming_strings" + ).scalar() + self._backslash_escapes = std_string == "off" + self._supports_drop_index_concurrently = self.server_version_info >= ( 9, 2, @@ -3370,122 +2918,100 @@ class 
PGDialect(default.DefaultDialect): self.do_commit(connection.connection) def do_recover_twophase(self, connection): - resultset = connection.execute( + return connection.scalars( sql.text("SELECT gid FROM pg_prepared_xacts") - ) - return [row[0] for row in resultset] + ).all() def _get_default_schema_name(self, connection): return connection.exec_driver_sql("select current_schema()").scalar() - def has_schema(self, connection, schema): - query = ( - "select nspname from pg_namespace " "where lower(nspname)=:schema" - ) - cursor = connection.execute( - sql.text(query).bindparams( - sql.bindparam( - "schema", - str(schema.lower()), - type_=sqltypes.Unicode, - ) - ) + @reflection.cache + def has_schema(self, connection, schema, **kw): + query = select(pg_catalog.pg_namespace.c.nspname).where( + pg_catalog.pg_namespace.c.nspname == schema ) + return bool(connection.scalar(query)) - return bool(cursor.first()) - - def has_table(self, connection, table_name, schema=None): - self._ensure_has_table_connection(connection) - # seems like case gets folded in pg_class... + def _pg_class_filter_scope_schema( + self, query, schema, scope, pg_class_table=None + ): + if pg_class_table is None: + pg_class_table = pg_catalog.pg_class + query = query.join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace, + ) + if scope is ObjectScope.DEFAULT: + query = query.where(pg_class_table.c.relpersistence != "t") + elif scope is ObjectScope.TEMPORARY: + query = query.where(pg_class_table.c.relpersistence == "t") if schema is None: - cursor = connection.execute( - sql.text( - "select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where " - "pg_catalog.pg_table_is_visible(c.oid) " - "and relname=:name" - ).bindparams( - sql.bindparam( - "name", - str(table_name), - type_=sqltypes.Unicode, - ) - ) + query = query.where( + pg_catalog.pg_table_is_visible(pg_class_table.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", ) else: - cursor = connection.execute( - sql.text( - "select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where n.nspname=:schema and " - "relname=:name" - ).bindparams( - sql.bindparam( - "name", - str(table_name), - type_=sqltypes.Unicode, - ), - sql.bindparam( - "schema", - str(schema), - type_=sqltypes.Unicode, - ), - ) - ) - return bool(cursor.first()) - - def has_sequence(self, connection, sequence_name, schema=None): - if schema is None: - schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT relname FROM pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where relkind='S' and " - "n.nspname=:schema and relname=:name" - ).bindparams( - sql.bindparam( - "name", - str(sequence_name), - type_=sqltypes.Unicode, - ), - sql.bindparam( - "schema", - str(schema), - type_=sqltypes.Unicode, - ), - ) + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + return query + + def _pg_class_relkind_condition(self, relkinds, pg_class_table=None): + if pg_class_table is None: + pg_class_table = pg_catalog.pg_class + # uses the any form instead of in otherwise postgresql complaings + # that 'IN could not convert type character to "char"' + return pg_class_table.c.relkind == sql.any_(_array.array(relkinds)) + + @lru_cache() + def _has_table_query(self, schema): + query = select(pg_catalog.pg_class.c.relname).where( + pg_catalog.pg_class.c.relname == bindparam("table_name"), + self._pg_class_relkind_condition( + 
pg_catalog.RELKINDS_ALL_TABLE_LIKE + ), + ) + return self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY ) - return bool(cursor.first()) + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): + self._ensure_has_table_connection(connection) + query = self._has_table_query(schema) + return bool(connection.scalar(query, {"table_name": table_name})) - def has_type(self, connection, type_name, schema=None): - if schema is not None: - query = """ - SELECT EXISTS ( - SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n - WHERE t.typnamespace = n.oid - AND t.typname = :typname - AND n.nspname = :nspname - ) - """ - query = sql.text(query) - else: - query = """ - SELECT EXISTS ( - SELECT * FROM pg_catalog.pg_type t - WHERE t.typname = :typname - AND pg_type_is_visible(t.oid) - ) - """ - query = sql.text(query) - query = query.bindparams( - sql.bindparam("typname", str(type_name), type_=sqltypes.Unicode) + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): + query = select(pg_catalog.pg_class.c.relname).where( + pg_catalog.pg_class.c.relkind == "S", + pg_catalog.pg_class.c.relname == sequence_name, ) - if schema is not None: - query = query.bindparams( - sql.bindparam("nspname", str(schema), type_=sqltypes.Unicode) + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + return bool(connection.scalar(query)) + + @reflection.cache + def has_type(self, connection, type_name, schema=None, **kw): + query = ( + select(pg_catalog.pg_type.c.typname) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, ) - cursor = connection.execute(query) - return bool(cursor.scalar()) + .where(pg_catalog.pg_type.c.typname == type_name) + ) + if schema is None: + query = query.where( + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + elif schema != "*": + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + + return bool(connection.scalar(query)) def _get_server_version_info(self, connection): v = connection.exec_driver_sql("select pg_catalog.version()").scalar() @@ -3502,229 +3028,300 @@ class PGDialect(default.DefaultDialect): @reflection.cache def get_table_oid(self, connection, table_name, schema=None, **kw): - """Fetch the oid for schema.table_name. - - Several reflection methods require the table oid. The idea for using - this method is that it can be fetched one time and cached for - subsequent calls. - - """ - table_oid = None - if schema is not None: - schema_where_clause = "n.nspname = :schema" - else: - schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" - query = ( - """ - SELECT c.oid - FROM pg_catalog.pg_class c - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE (%s) - AND c.relname = :table_name AND c.relkind in - ('r', 'v', 'm', 'f', 'p') - """ - % schema_where_clause + """Fetch the oid for schema.table_name.""" + query = select(pg_catalog.pg_class.c.oid).where( + pg_catalog.pg_class.c.relname == table_name, + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_ALL_TABLE_LIKE + ), ) - # Since we're binding to unicode, table_name and schema_name must be - # unicode. 
- table_name = str(table_name) - if schema is not None: - schema = str(schema) - s = sql.text(query).bindparams(table_name=sqltypes.Unicode) - s = s.columns(oid=sqltypes.Integer) - if schema: - s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode)) - c = connection.execute(s, dict(table_name=table_name, schema=schema)) - table_oid = c.scalar() + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + table_oid = connection.scalar(query) if table_oid is None: - raise exc.NoSuchTableError(table_name) + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) return table_oid @reflection.cache def get_schema_names(self, connection, **kw): - result = connection.execute( - sql.text( - "SELECT nspname FROM pg_namespace " - "WHERE nspname NOT LIKE 'pg_%' " - "ORDER BY nspname" - ).columns(nspname=sqltypes.Unicode) + query = ( + select(pg_catalog.pg_namespace.c.nspname) + .where(pg_catalog.pg_namespace.c.nspname.not_like("pg_%")) + .order_by(pg_catalog.pg_namespace.c.nspname) + ) + return connection.scalars(query).all() + + def _get_relnames_for_relkinds(self, connection, schema, relkinds, scope): + query = select(pg_catalog.pg_class.c.relname).where( + self._pg_class_relkind_condition(relkinds) ) - return [name for name, in result] + query = self._pg_class_filter_scope_schema(query, schema, scope=scope) + return connection.scalars(query).all() @reflection.cache def get_table_names(self, connection, schema=None, **kw): - result = connection.execute( - sql.text( - "SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')" - ).columns(relname=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name - ), + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_TABLE_NO_FOREIGN, + scope=ObjectScope.DEFAULT, + ) + + @reflection.cache + def get_temp_table_names(self, connection, **kw): + return self._get_relnames_for_relkinds( + connection, + schema=None, + relkinds=pg_catalog.RELKINDS_TABLE_NO_FOREIGN, + scope=ObjectScope.TEMPORARY, ) - return [name for name, in result] @reflection.cache def _get_foreign_table_names(self, connection, schema=None, **kw): - result = connection.execute( - sql.text( - "SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind = 'f'" - ).columns(relname=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name - ), + return self._get_relnames_for_relkinds( + connection, schema, relkinds=("f",), scope=ObjectScope.ANY ) - return [name for name, in result] @reflection.cache - def get_view_names( - self, connection, schema=None, include=("plain", "materialized"), **kw - ): + def get_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_VIEW, + scope=ObjectScope.DEFAULT, + ) - include_kind = {"plain": "v", "materialized": "m"} - try: - kinds = [include_kind[i] for i in util.to_list(include)] - except KeyError: - raise ValueError( - "include %r unknown, needs to be a sequence containing " - "one or both of 'plain' and 'materialized'" % (include,) - ) - if not kinds: - raise ValueError( - "empty include, needs to be a sequence containing " - "one or both of 'plain' and 'materialized'" - ) + @reflection.cache + def get_materialized_view_names(self, connection, 
schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_MAT_VIEW, + scope=ObjectScope.DEFAULT, + ) - result = connection.execute( - sql.text( - "SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind IN (%s)" - % (", ".join("'%s'" % elem for elem in kinds)) - ).columns(relname=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name - ), + @reflection.cache + def get_temp_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + # NOTE: do not include temp materialzied views (that do not + # seem to be a thing at least up to version 14) + pg_catalog.RELKINDS_VIEW, + scope=ObjectScope.TEMPORARY, ) - return [name for name, in result] @reflection.cache def get_sequence_names(self, connection, schema=None, **kw): - if not schema: - schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT relname FROM pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where relkind='S' and " - "n.nspname=:schema" - ).bindparams( - sql.bindparam( - "schema", - str(schema), - type_=sqltypes.Unicode, - ), - ) + return self._get_relnames_for_relkinds( + connection, schema, relkinds=("S",), scope=ObjectScope.ANY ) - return [row[0] for row in cursor] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): - view_def = connection.scalar( - sql.text( - "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relname = :view_name " - "AND c.relkind IN ('v', 'm')" - ).columns(view_def=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name, - view_name=view_name, - ), + query = ( + select(pg_catalog.pg_get_viewdef(pg_catalog.pg_class.c.oid)) + .select_from(pg_catalog.pg_class) + .where( + pg_catalog.pg_class.c.relname == view_name, + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_VIEW + pg_catalog.RELKINDS_MAT_VIEW + ), + ) ) - return view_def + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + res = connection.scalar(query) + if res is None: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) + else: + return res + + def _value_or_raise(self, data, table, schema): + try: + return dict(data)[(schema, table)] + except KeyError: + raise exc.NoSuchTableError( + f"{schema}.{table}" if schema else table + ) from None + + def _prepare_filter_names(self, filter_names): + if filter_names: + return True, {"filter_names": filter_names} + else: + return False, {} + + def _kind_to_relkinds(self, kind: ObjectKind) -> tuple[str, ...]: + if kind is ObjectKind.ANY: + return pg_catalog.RELKINDS_ALL_TABLE_LIKE + relkinds = () + if ObjectKind.TABLE in kind: + relkinds += pg_catalog.RELKINDS_TABLE + if ObjectKind.VIEW in kind: + relkinds += pg_catalog.RELKINDS_VIEW + if ObjectKind.MATERIALIZED_VIEW in kind: + relkinds += pg_catalog.RELKINDS_MAT_VIEW + return relkinds @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_columns( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, 
table_name, schema) + @lru_cache() + def _columns_query(self, schema, has_filter_names, scope, kind): + # NOTE: the query with the default and identity options scalar + # subquery is faster than trying to use outer joins for them generated = ( - "a.attgenerated as generated" + pg_catalog.pg_attribute.c.attgenerated.label("generated") if self.server_version_info >= (12,) - else "NULL as generated" + else sql.null().label("generated") ) if self.server_version_info >= (10,): - # a.attidentity != '' is required or it will reflect also - # serial columns as identity. - identity = """\ - (SELECT json_build_object( - 'always', a.attidentity = 'a', - 'start', s.seqstart, - 'increment', s.seqincrement, - 'minvalue', s.seqmin, - 'maxvalue', s.seqmax, - 'cache', s.seqcache, - 'cycle', s.seqcycle) - FROM pg_catalog.pg_sequence s - JOIN pg_catalog.pg_class c on s.seqrelid = c."oid" - WHERE c.relkind = 'S' - AND a.attidentity != '' - AND s.seqrelid = pg_catalog.pg_get_serial_sequence( - a.attrelid::regclass::text, a.attname - )::regclass::oid - ) as identity_options\ - """ + # join lateral performs worse (~2x slower) than a scalar_subquery + identity = ( + select( + sql.func.json_build_object( + "always", + pg_catalog.pg_attribute.c.attidentity == "a", + "start", + pg_catalog.pg_sequence.c.seqstart, + "increment", + pg_catalog.pg_sequence.c.seqincrement, + "minvalue", + pg_catalog.pg_sequence.c.seqmin, + "maxvalue", + pg_catalog.pg_sequence.c.seqmax, + "cache", + pg_catalog.pg_sequence.c.seqcache, + "cycle", + pg_catalog.pg_sequence.c.seqcycle, + ) + ) + .select_from(pg_catalog.pg_sequence) + .where( + # attidentity != '' is required or it will reflect also + # serial columns as identity. + pg_catalog.pg_attribute.c.attidentity != "", + pg_catalog.pg_sequence.c.seqrelid + == sql.cast( + sql.cast( + pg_catalog.pg_get_serial_sequence( + sql.cast( + sql.cast( + pg_catalog.pg_attribute.c.attrelid, + REGCLASS, + ), + TEXT, + ), + pg_catalog.pg_attribute.c.attname, + ), + REGCLASS, + ), + OID, + ), + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery() + .label("identity_options") + ) else: - identity = "NULL as identity_options" - - SQL_COLS = """ - SELECT a.attname, - pg_catalog.format_type(a.atttypid, a.atttypmod), - ( - SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid) - FROM pg_catalog.pg_attrdef d - WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum - AND a.atthasdef - ) AS DEFAULT, - a.attnotnull, - a.attrelid as table_oid, - pgd.description as comment, - %s, - %s - FROM pg_catalog.pg_attribute a - LEFT JOIN pg_catalog.pg_description pgd ON ( - pgd.objoid = a.attrelid AND pgd.objsubid = a.attnum) - WHERE a.attrelid = :table_oid - AND a.attnum > 0 AND NOT a.attisdropped - ORDER BY a.attnum - """ % ( - generated, - identity, + identity = sql.null().label("identity_options") + + # join lateral performs the same as scalar_subquery here + default = ( + select( + pg_catalog.pg_get_expr( + pg_catalog.pg_attrdef.c.adbin, + pg_catalog.pg_attrdef.c.adrelid, + ) + ) + .select_from(pg_catalog.pg_attrdef) + .where( + pg_catalog.pg_attrdef.c.adrelid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_attrdef.c.adnum + == pg_catalog.pg_attribute.c.attnum, + pg_catalog.pg_attribute.c.atthasdef, + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery() + .label("default") ) - s = ( - sql.text(SQL_COLS) - .bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer)) - .columns(attname=sqltypes.Unicode, default=sqltypes.Unicode) + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + 
pg_catalog.pg_attribute.c.attname.label("name"), + pg_catalog.format_type( + pg_catalog.pg_attribute.c.atttypid, + pg_catalog.pg_attribute.c.atttypmod, + ).label("format_type"), + default, + pg_catalog.pg_attribute.c.attnotnull.label("not_null"), + pg_catalog.pg_class.c.relname.label("table_name"), + pg_catalog.pg_description.c.description.label("comment"), + generated, + identity, + ) + .select_from(pg_catalog.pg_class) + # NOTE: postgresql support table with no user column, meaning + # there is no row with pg_attribute.attnum > 0. use a left outer + # join to avoid filtering these tables. + .outerjoin( + pg_catalog.pg_attribute, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_attribute.c.attnum > 0, + ~pg_catalog.pg_attribute.c.attisdropped, + ), + ) + .outerjoin( + pg_catalog.pg_description, + sql.and_( + pg_catalog.pg_description.c.objoid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_description.c.objsubid + == pg_catalog.pg_attribute.c.attnum, + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + .order_by( + pg_catalog.pg_class.c.relname, pg_catalog.pg_attribute.c.attnum + ) ) - c = connection.execute(s, dict(table_oid=table_oid)) - rows = c.fetchall() + query = self._pg_class_filter_scope_schema(query, schema, scope=scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + def get_multi_columns( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._columns_query(schema, has_filter_names, scope, kind) + rows = connection.execute(query, params).mappings() # dictionary with (name, ) if default search path or (schema, name) # as keys - domains = self._load_domains(connection) + domains = self._load_domains( + connection, info_cache=kw.get("info_cache") + ) # dictionary with (name, ) if default search path or (schema, name) # as keys @@ -3732,257 +3329,340 @@ class PGDialect(default.DefaultDialect): ((rec["name"],), rec) if rec["visible"] else ((rec["schema"], rec["name"]), rec) - for rec in self._load_enums(connection, schema="*") + for rec in self._load_enums( + connection, schema="*", info_cache=kw.get("info_cache") + ) ) - # format columns - columns = [] - - for ( - name, - format_type, - default_, - notnull, - table_oid, - comment, - generated, - identity, - ) in rows: - column_info = self._get_column_info( - name, - format_type, - default_, - notnull, - domains, - enums, - schema, - comment, - generated, - identity, - ) - columns.append(column_info) - return columns + columns = self._get_columns_info(rows, domains, enums, schema) + + return columns.items() + + def _get_columns_info(self, rows, domains, enums, schema): + array_type_pattern = re.compile(r"\[\]$") + attype_pattern = re.compile(r"\(.*\)") + charlen_pattern = re.compile(r"\(([\d,]+)\)") + args_pattern = re.compile(r"\((.*)\)") + args_split_pattern = re.compile(r"\s*,\s*") - def _get_column_info( - self, - name, - format_type, - default, - notnull, - domains, - enums, - schema, - comment, - generated, - identity, - ): def _handle_array_type(attype): return ( # strip '[]' from integer[], etc. - re.sub(r"\[\]$", "", attype), + array_type_pattern.sub("", attype), attype.endswith("[]"), ) - # strip (*) from character varying(5), timestamp(5) - # with time zone, geometry(POLYGON), etc. 
- attype = re.sub(r"\(.*\)", "", format_type) + columns = defaultdict(list) + for row_dict in rows: + # ensure that each table has an entry, even if it has no columns + if row_dict["name"] is None: + columns[ + (schema, row_dict["table_name"]) + ] = ReflectionDefaults.columns() + continue + table_cols = columns[(schema, row_dict["table_name"])] - # strip '[]' from integer[], etc. and check if an array - attype, is_array = _handle_array_type(attype) + format_type = row_dict["format_type"] + default = row_dict["default"] + name = row_dict["name"] + generated = row_dict["generated"] + identity = row_dict["identity_options"] - # strip quotes from case sensitive enum or domain names - enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + # strip (*) from character varying(5), timestamp(5) + # with time zone, geometry(POLYGON), etc. + attype = attype_pattern.sub("", format_type) - nullable = not notnull + # strip '[]' from integer[], etc. and check if an array + attype, is_array = _handle_array_type(attype) - charlen = re.search(r"\(([\d,]+)\)", format_type) - if charlen: - charlen = charlen.group(1) - args = re.search(r"\((.*)\)", format_type) - if args and args.group(1): - args = tuple(re.split(r"\s*,\s*", args.group(1))) - else: - args = () - kwargs = {} + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + + nullable = not row_dict["not_null"] - if attype == "numeric": + charlen = charlen_pattern.search(format_type) if charlen: - prec, scale = charlen.split(",") - args = (int(prec), int(scale)) + charlen = charlen.group(1) + args = args_pattern.search(format_type) + if args and args.group(1): + args = tuple(args_split_pattern.split(args.group(1))) else: args = () - elif attype == "double precision": - args = (53,) - elif attype == "integer": - args = () - elif attype in ("timestamp with time zone", "time with time zone"): - kwargs["timezone"] = True - if charlen: - kwargs["precision"] = int(charlen) - args = () - elif attype in ( - "timestamp without time zone", - "time without time zone", - "time", - ): - kwargs["timezone"] = False - if charlen: - kwargs["precision"] = int(charlen) - args = () - elif attype == "bit varying": - kwargs["varying"] = True - if charlen: + kwargs = {} + + if attype == "numeric": + if charlen: + prec, scale = charlen.split(",") + args = (int(prec), int(scale)) + else: + args = () + elif attype == "double precision": + args = (53,) + elif attype == "integer": + args = () + elif attype in ("timestamp with time zone", "time with time zone"): + kwargs["timezone"] = True + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype in ( + "timestamp without time zone", + "time without time zone", + "time", + ): + kwargs["timezone"] = False + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype == "bit varying": + kwargs["varying"] = True + if charlen: + args = (int(charlen),) + else: + args = () + elif attype.startswith("interval"): + field_match = re.match(r"interval (.+)", attype, re.I) + if charlen: + kwargs["precision"] = int(charlen) + if field_match: + kwargs["fields"] = field_match.group(1) + attype = "interval" + args = () + elif charlen: args = (int(charlen),) + + while True: + # looping here to suit nested domains + if attype in self.ischema_names: + coltype = self.ischema_names[attype] + break + elif enum_or_domain_key in enums: + enum = enums[enum_or_domain_key] + coltype = ENUM + kwargs["name"] = enum["name"] + if not enum["visible"]: + 
kwargs["schema"] = enum["schema"] + args = tuple(enum["labels"]) + break + elif enum_or_domain_key in domains: + domain = domains[enum_or_domain_key] + attype = domain["attype"] + attype, is_array = _handle_array_type(attype) + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple( + util.quoted_token_parser(attype) + ) + # A table can't override a not null on the domain, + # but can override nullable + nullable = nullable and domain["nullable"] + if domain["default"] and not default: + # It can, however, override the default + # value, but can't set it to null. + default = domain["default"] + continue + else: + coltype = None + break + + if coltype: + coltype = coltype(*args, **kwargs) + if is_array: + coltype = self.ischema_names["_array"](coltype) else: - args = () - elif attype.startswith("interval"): - field_match = re.match(r"interval (.+)", attype, re.I) - if charlen: - kwargs["precision"] = int(charlen) - if field_match: - kwargs["fields"] = field_match.group(1) - attype = "interval" - args = () - elif charlen: - args = (int(charlen),) - - while True: - # looping here to suit nested domains - if attype in self.ischema_names: - coltype = self.ischema_names[attype] - break - elif enum_or_domain_key in enums: - enum = enums[enum_or_domain_key] - coltype = ENUM - kwargs["name"] = enum["name"] - if not enum["visible"]: - kwargs["schema"] = enum["schema"] - args = tuple(enum["labels"]) - break - elif enum_or_domain_key in domains: - domain = domains[enum_or_domain_key] - attype = domain["attype"] - attype, is_array = _handle_array_type(attype) - # strip quotes from case sensitive enum or domain names - enum_or_domain_key = tuple(util.quoted_token_parser(attype)) - # A table can't override a not null on the domain, - # but can override nullable - nullable = nullable and domain["nullable"] - if domain["default"] and not default: - # It can, however, override the default - # value, but can't set it to null. - default = domain["default"] - continue + util.warn( + "Did not recognize type '%s' of column '%s'" + % (attype, name) + ) + coltype = sqltypes.NULLTYPE + + # If a zero byte or blank string depending on driver (is also + # absent for older PG versions), then not a generated column. + # Otherwise, s = stored. (Other values might be added in the + # future.) + if generated not in (None, "", b"\x00"): + computed = dict( + sqltext=default, persisted=generated in ("s", b"s") + ) + default = None else: - coltype = None - break + computed = None - if coltype: - coltype = coltype(*args, **kwargs) - if is_array: - coltype = self.ischema_names["_array"](coltype) - else: - util.warn( - "Did not recognize type '%s' of column '%s'" % (attype, name) + # adjust the default value + autoincrement = False + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + if issubclass(coltype._type_affinity, sqltypes.Integer): + autoincrement = True + # the default is related to a Sequence + if "." not in match.group(2) and schema is not None: + # unconditionally quote the schema name. this could + # later be enhanced to obey quoting rules / + # "quote schema" + default = ( + match.group(1) + + ('"%s"' % schema) + + "." 
+ + match.group(2) + + match.group(3) + ) + + column_info = { + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": autoincrement or identity is not None, + "comment": row_dict["comment"], + } + if computed is not None: + column_info["computed"] = computed + if identity is not None: + column_info["identity"] = identity + + table_cols.append(column_info) + + return columns + + @lru_cache() + def _table_oids_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + oid_q = select( + pg_catalog.pg_class.c.oid, pg_catalog.pg_class.c.relname + ).where(self._pg_class_relkind_condition(relkinds)) + oid_q = self._pg_class_filter_scope_schema(oid_q, schema, scope=scope) + + if has_filter_names: + oid_q = oid_q.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) ) - coltype = sqltypes.NULLTYPE - - # If a zero byte or blank string depending on driver (is also absent - # for older PG versions), then not a generated column. Otherwise, s = - # stored. (Other values might be added in the future.) - if generated not in (None, "", b"\x00"): - computed = dict( - sqltext=default, persisted=generated in ("s", b"s") + return oid_q + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("filter_names", InternalTraversal.dp_string_list), + ("kind", InternalTraversal.dp_plain_obj), + ("scope", InternalTraversal.dp_plain_obj), + ) + def _get_table_oids( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + oid_q = self._table_oids_query(schema, has_filter_names, scope, kind) + result = connection.execute(oid_q, params) + return result.all() + + @util.memoized_property + def _constraint_query(self): + con_sq = ( + select( + pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.conname, + sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label( + "attnum" + ), + sql.func.generate_subscripts( + pg_catalog.pg_constraint.c.conkey, 1 + ).label("ord"), ) - default = None - else: - computed = None - - # adjust the default value - autoincrement = False - if default is not None: - match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) - if match is not None: - if issubclass(coltype._type_affinity, sqltypes.Integer): - autoincrement = True - # the default is related to a Sequence - sch = schema - if "." not in match.group(2) and sch is not None: - # unconditionally quote the schema name. this could - # later be enhanced to obey quoting rules / - # "quote schema" - default = ( - match.group(1) - + ('"%s"' % sch) - + "." 
- + match.group(2) - + match.group(3) - ) + .where( + pg_catalog.pg_constraint.c.contype == bindparam("contype"), + pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")), + ) + .subquery("con") + ) - column_info = dict( - name=name, - type=coltype, - nullable=nullable, - default=default, - autoincrement=autoincrement or identity is not None, - comment=comment, + attr_sq = ( + select( + con_sq.c.conrelid, + con_sq.c.conname, + pg_catalog.pg_attribute.c.attname, + ) + .select_from(pg_catalog.pg_attribute) + .join( + con_sq, + sql.and_( + pg_catalog.pg_attribute.c.attnum == con_sq.c.attnum, + pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid, + ), + ) + .order_by(con_sq.c.conname, con_sq.c.ord) + .subquery("attr") ) - if computed is not None: - column_info["computed"] = computed - if identity is not None: - column_info["identity"] = identity - return column_info - @reflection.cache - def get_pk_constraint(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + return ( + select( + attr_sq.c.conrelid, + sql.func.array_agg(attr_sq.c.attname).label("cols"), + attr_sq.c.conname, + ) + .group_by(attr_sq.c.conrelid, attr_sq.c.conname) + .order_by(attr_sq.c.conrelid, attr_sq.c.conname) ) - if self.server_version_info < (8, 4): - PK_SQL = """ - SELECT a.attname - FROM - pg_class t - join pg_index ix on t.oid = ix.indrelid - join pg_attribute a - on t.oid=a.attrelid AND %s - WHERE - t.oid = :table_oid and ix.indisprimary = 't' - ORDER BY a.attnum - """ % self._pg_index_any( - "a.attnum", "ix.indkey" + def _reflect_constraint( + self, connection, contype, schema, filter_names, scope, kind, **kw + ): + table_oids = self._get_table_oids( + connection, schema, filter_names, scope, kind, **kw + ) + batches = list(table_oids) + + while batches: + batch = batches[0:3000] + batches[0:3000] = [] + + result = connection.execute( + self._constraint_query, + {"oids": [r[0] for r in batch], "contype": contype}, ) - else: - # unnest() and generate_subscripts() both introduced in - # version 8.4 - PK_SQL = """ - SELECT a.attname - FROM pg_attribute a JOIN ( - SELECT unnest(ix.indkey) attnum, - generate_subscripts(ix.indkey, 1) ord - FROM pg_index ix - WHERE ix.indrelid = :table_oid AND ix.indisprimary - ) k ON a.attnum=k.attnum - WHERE a.attrelid = :table_oid - ORDER BY k.ord - """ - t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode) - c = connection.execute(t, dict(table_oid=table_oid)) - cols = [r[0] for r in c.fetchall()] - - PK_CONS_SQL = """ - SELECT conname - FROM pg_catalog.pg_constraint r - WHERE r.conrelid = :table_oid AND r.contype = 'p' - ORDER BY 1 - """ - t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode) - c = connection.execute(t, dict(table_oid=table_oid)) - name = c.scalar() + result_by_oid = defaultdict(list) + for oid, cols, constraint_name in result: + result_by_oid[oid].append((cols, constraint_name)) + + for oid, tablename in batch: + for_oid = result_by_oid.get(oid, ()) + if for_oid: + for cols, constraint in for_oid: + yield tablename, cols, constraint + else: + yield tablename, None, None + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + data = self.get_multi_pk_constraint( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + def get_multi_pk_constraint( + self, connection, schema, filter_names, scope, 
kind, **kw + ): + result = self._reflect_constraint( + connection, "p", schema, filter_names, scope, kind, **kw + ) - return {"constrained_columns": cols, "name": name} + # only a single pk can be present for each table. Return an entry + # even if a table has no primary key + default = ReflectionDefaults.pk_constraint + return ( + ( + (schema, table_name), + { + "constrained_columns": [] if cols is None else cols, + "name": pk_name, + } + if pk_name is not None + else default(), + ) + for (table_name, cols, pk_name) in result + ) @reflection.cache def get_foreign_keys( @@ -3993,27 +3673,71 @@ class PGDialect(default.DefaultDialect): postgresql_ignore_search_path=False, **kw, ): - preparer = self.identifier_preparer - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_foreign_keys( + connection, + schema=schema, + filter_names=[table_name], + postgresql_ignore_search_path=postgresql_ignore_search_path, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - FK_SQL = """ - SELECT r.conname, - pg_catalog.pg_get_constraintdef(r.oid, true) as condef, - n.nspname as conschema - FROM pg_catalog.pg_constraint r, - pg_namespace n, - pg_class c - - WHERE r.conrelid = :table AND - r.contype = 'f' AND - c.oid = confrelid AND - n.oid = c.relnamespace - ORDER BY 1 - """ - # https://www.postgresql.org/docs/9.0/static/sql-createtable.html - FK_REGEX = re.compile( + @lru_cache() + def _foreing_key_query(self, schema, has_filter_names, scope, kind): + pg_class_ref = pg_catalog.pg_class.alias("cls_ref") + pg_namespace_ref = pg_catalog.pg_namespace.alias("nsp_ref") + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + sql.case( + ( + pg_catalog.pg_constraint.c.oid.is_not(None), + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid, True + ), + ), + else_=None, + ), + pg_namespace_ref.c.nspname, + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.contype == "f", + ), + ) + .outerjoin( + pg_class_ref, + pg_class_ref.c.oid == pg_catalog.pg_constraint.c.confrelid, + ) + .outerjoin( + pg_namespace_ref, + pg_class_ref.c.relnamespace == pg_namespace_ref.c.oid, + ) + .order_by( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + @util.memoized_property + def _fk_regex_pattern(self): + # https://www.postgresql.org/docs/14.0/static/sql-createtable.html + return re.compile( r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)" r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?" r"[\s]?(ON UPDATE " @@ -4024,12 +3748,33 @@ class PGDialect(default.DefaultDialect): r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?" 
) - t = sql.text(FK_SQL).columns( - conname=sqltypes.Unicode, condef=sqltypes.Unicode - ) - c = connection.execute(t, dict(table=table_oid)) - fkeys = [] - for conname, condef, conschema in c.fetchall(): + def get_multi_foreign_keys( + self, + connection, + schema, + filter_names, + scope, + kind, + postgresql_ignore_search_path=False, + **kw, + ): + preparer = self.identifier_preparer + + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._foreing_key_query(schema, has_filter_names, scope, kind) + result = connection.execute(query, params) + + FK_REGEX = self._fk_regex_pattern + + fkeys = defaultdict(list) + default = ReflectionDefaults.foreign_keys + for table_name, conname, condef, conschema in result: + # ensure that each table has an entry, even if it has + # no foreign keys + if conname is None: + fkeys[(schema, table_name)] = default() + continue + table_fks = fkeys[(schema, table_name)] m = re.search(FK_REGEX, condef).groups() ( @@ -4096,317 +3841,406 @@ class PGDialect(default.DefaultDialect): "referred_columns": referred_columns, "options": options, } - fkeys.append(fkey_d) - return fkeys - - def _pg_index_any(self, col, compare_to): - if self.server_version_info < (8, 1): - # https://www.postgresql.org/message-id/10279.1124395722@sss.pgh.pa.us - # "In CVS tip you could replace this with "attnum = ANY (indkey)". - # Unfortunately, most array support doesn't work on int2vector in - # pre-8.1 releases, so I think you're kinda stuck with the above - # for now. - # regards, tom lane" - return "(%s)" % " OR ".join( - "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10) - ) - else: - return "%s = ANY(%s)" % (col, compare_to) + table_fks.append(fkey_d) + return fkeys.items() @reflection.cache - def get_indexes(self, connection, table_name, schema, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + def get_indexes(self, connection, table_name, schema=None, **kw): + data = self.get_multi_indexes( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - # cast indkey as varchar since it's an int2vector, - # returned as a list by some drivers such as pypostgresql - - if self.server_version_info < (8, 5): - IDX_SQL = """ - SELECT - i.relname as relname, - ix.indisunique, ix.indexprs, ix.indpred, - a.attname, a.attnum, NULL, ix.indkey%s, - %s, %s, am.amname, - NULL as indnkeyatts - FROM - pg_class t - join pg_index ix on t.oid = ix.indrelid - join pg_class i on i.oid = ix.indexrelid - left outer join - pg_attribute a - on t.oid = a.attrelid and %s - left outer join - pg_am am - on i.relam = am.oid - WHERE - t.relkind IN ('r', 'v', 'f', 'm') - and t.oid = :table_oid - and ix.indisprimary = 'f' - ORDER BY - t.relname, - i.relname - """ % ( - # version 8.3 here was based on observing the - # cast does not work in PG 8.2.4, does work in 8.3.0. - # nothing in PG changelogs regarding this. 
- "::varchar" if self.server_version_info >= (8, 3) else "", - "ix.indoption::varchar" - if self.server_version_info >= (8, 3) - else "NULL", - "i.reloptions" - if self.server_version_info >= (8, 2) - else "NULL", - self._pg_index_any("a.attnum", "ix.indkey"), + @util.memoized_property + def _index_query(self): + pg_class_index = pg_catalog.pg_class.alias("cls_idx") + # NOTE: repeating oids clause improve query performance + + # subquery to get the columns + idx_sq = ( + select( + pg_catalog.pg_index.c.indexrelid, + pg_catalog.pg_index.c.indrelid, + sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"), + sql.func.generate_subscripts( + pg_catalog.pg_index.c.indkey, 1 + ).label("ord"), ) - else: - IDX_SQL = """ - SELECT - i.relname as relname, - ix.indisunique, ix.indexprs, - a.attname, a.attnum, c.conrelid, ix.indkey::varchar, - ix.indoption::varchar, i.reloptions, am.amname, - pg_get_expr(ix.indpred, ix.indrelid), - %s as indnkeyatts - FROM - pg_class t - join pg_index ix on t.oid = ix.indrelid - join pg_class i on i.oid = ix.indexrelid - left outer join - pg_attribute a - on t.oid = a.attrelid and a.attnum = ANY(ix.indkey) - left outer join - pg_constraint c - on (ix.indrelid = c.conrelid and - ix.indexrelid = c.conindid and - c.contype in ('p', 'u', 'x')) - left outer join - pg_am am - on i.relam = am.oid - WHERE - t.relkind IN ('r', 'v', 'f', 'm', 'p') - and t.oid = :table_oid - and ix.indisprimary = 'f' - ORDER BY - t.relname, - i.relname - """ % ( - "ix.indnkeyatts" - if self.server_version_info >= (11, 0) - else "NULL", + .where( + ~pg_catalog.pg_index.c.indisprimary, + pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")), ) + .subquery("idx") + ) - t = sql.text(IDX_SQL).columns( - relname=sqltypes.Unicode, attname=sqltypes.Unicode + attr_sq = ( + select( + idx_sq.c.indexrelid, + idx_sq.c.indrelid, + pg_catalog.pg_attribute.c.attname, + ) + .select_from(pg_catalog.pg_attribute) + .join( + idx_sq, + sql.and_( + pg_catalog.pg_attribute.c.attnum == idx_sq.c.attnum, + pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid, + ), + ) + .where(idx_sq.c.indrelid.in_(bindparam("oids"))) + .order_by(idx_sq.c.indexrelid, idx_sq.c.ord) + .subquery("idx_attr") ) - c = connection.execute(t, dict(table_oid=table_oid)) - indexes = defaultdict(lambda: defaultdict(dict)) + cols_sq = ( + select( + attr_sq.c.indexrelid, + attr_sq.c.indrelid, + sql.func.array_agg(attr_sq.c.attname).label("cols"), + ) + .group_by(attr_sq.c.indexrelid, attr_sq.c.indrelid) + .subquery("idx_cols") + ) - sv_idx_name = None - for row in c.fetchall(): - ( - idx_name, - unique, - expr, - col, - col_num, - conrelid, - idx_key, - idx_option, - options, - amname, - filter_definition, - indnkeyatts, - ) = row + if self.server_version_info >= (11, 0): + indnkeyatts = pg_catalog.pg_index.c.indnkeyatts + else: + indnkeyatts = sql.null().label("indnkeyatts") - if expr: - if idx_name != sv_idx_name: - util.warn( - "Skipped unsupported reflection of " - "expression-based index %s" % idx_name - ) - sv_idx_name = idx_name - continue + query = ( + select( + pg_catalog.pg_index.c.indrelid, + pg_class_index.c.relname.label("relname_index"), + pg_catalog.pg_index.c.indisunique, + pg_catalog.pg_index.c.indexprs, + pg_catalog.pg_constraint.c.conrelid.is_not(None).label( + "has_constraint" + ), + pg_catalog.pg_index.c.indoption, + pg_class_index.c.reloptions, + pg_catalog.pg_am.c.amname, + pg_catalog.pg_get_expr( + pg_catalog.pg_index.c.indpred, + pg_catalog.pg_index.c.indrelid, + ).label("filter_definition"), + indnkeyatts, + 
cols_sq.c.cols.label("index_cols"), + ) + .select_from(pg_catalog.pg_index) + .where( + pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")), + ~pg_catalog.pg_index.c.indisprimary, + ) + .join( + pg_class_index, + pg_catalog.pg_index.c.indexrelid == pg_class_index.c.oid, + ) + .join( + pg_catalog.pg_am, + pg_class_index.c.relam == pg_catalog.pg_am.c.oid, + ) + .outerjoin( + cols_sq, + pg_catalog.pg_index.c.indexrelid == cols_sq.c.indexrelid, + ) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_index.c.indrelid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_index.c.indexrelid + == pg_catalog.pg_constraint.c.conindid, + pg_catalog.pg_constraint.c.contype + == sql.any_(_array.array(("p", "u", "x"))), + ), + ) + .order_by(pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname) + ) + return query - has_idx = idx_name in indexes - index = indexes[idx_name] - if col is not None: - index["cols"][col_num] = col - if not has_idx: - idx_keys = idx_key.split() - # "The number of key columns in the index, not counting any - # included columns, which are merely stored and do not - # participate in the index semantics" - if indnkeyatts and idx_keys[indnkeyatts:]: - # this is a "covering index" which has INCLUDE columns - # as well as regular index columns - inc_keys = idx_keys[indnkeyatts:] - idx_keys = idx_keys[:indnkeyatts] - else: - inc_keys = [] + def get_multi_indexes( + self, connection, schema, filter_names, scope, kind, **kw + ): - index["key"] = [int(k.strip()) for k in idx_keys] - index["inc"] = [int(k.strip()) for k in inc_keys] + table_oids = self._get_table_oids( + connection, schema, filter_names, scope, kind, **kw + ) - # (new in pg 8.3) - # "pg_index.indoption" is list of ints, one per column/expr. - # int acts as bitmask: 0x01=DESC, 0x02=NULLSFIRST - sorting = {} - for col_idx, col_flags in enumerate( - (idx_option or "").split() - ): - col_flags = int(col_flags.strip()) - col_sorting = () - # try to set flags only if they differ from PG defaults... 
-                    if col_flags & 0x01:
-                        col_sorting += ("desc",)
-                        if not (col_flags & 0x02):
-                            col_sorting += ("nulls_last",)
+        indexes = defaultdict(list)
+        default = ReflectionDefaults.indexes
+
+        batches = list(table_oids)
+
+        while batches:
+            batch = batches[0:3000]
+            batches[0:3000] = []
+
+            result = connection.execute(
+                self._index_query, {"oids": [r[0] for r in batch]}
+            ).mappings()
+
+            result_by_oid = defaultdict(list)
+            for row_dict in result:
+                result_by_oid[row_dict["indrelid"]].append(row_dict)
+
+            for oid, table_name in batch:
+                if oid not in result_by_oid:
+                    # ensure that each table has an entry, even if
+                    # reflection is skipped because it is not supported
+                    indexes[(schema, table_name)] = default()
+                    continue
+
+                for row in result_by_oid[oid]:
+                    index_name = row["relname_index"]
+
+                    table_indexes = indexes[(schema, table_name)]
+
+                    if row["indexprs"]:
+                        tn = (
+                            table_name
+                            if schema is None
+                            else f"{schema}.{table_name}"
+                        )
+                        util.warn(
+                            "Skipped unsupported reflection of "
+                            f"expression-based index {index_name} of "
+                            f"table {tn}"
+                        )
+                        continue
+
+                    all_cols = row["index_cols"]
+                    indnkeyatts = row["indnkeyatts"]
+                    # "The number of key columns in the index, not counting
+                    # any included columns, which are merely stored and do
+                    # not participate in the index semantics"
+                    if indnkeyatts and all_cols[indnkeyatts:]:
+                        # this is a "covering index" which has INCLUDE columns
+                        # as well as regular index columns
+                        inc_cols = all_cols[indnkeyatts:]
+                        idx_cols = all_cols[:indnkeyatts]
                     else:
-                        if col_flags & 0x02:
-                            col_sorting += ("nulls_first",)
-                    if col_sorting:
-                        sorting[col_idx] = col_sorting
-            if sorting:
-                index["sorting"] = sorting
-
-            index["unique"] = unique
-            if conrelid is not None:
-                index["duplicates_constraint"] = idx_name
-            if options:
-                index["options"] = dict(
-                    [option.split("=") for option in options]
-                )
-
-            # it *might* be nice to include that this is 'btree' in the
-            # reflection info. But we don't want an Index object
-            # to have a ``postgresql_using`` in it that is just the
-            # default, so for the moment leaving this out.
-            if amname and amname != "btree":
-                index["amname"] = amname
-
-            if filter_definition:
-                index["postgresql_where"] = filter_definition
+                        idx_cols = all_cols
+                        inc_cols = []
+
+                    index = {
+                        "name": index_name,
+                        "unique": row["indisunique"],
+                        "column_names": idx_cols,
+                    }
+
+                    sorting = {}
+                    for col_index, col_flags in enumerate(row["indoption"]):
+                        col_sorting = ()
+                        # try to set flags only if they differ from PG
+                        # defaults...
+                        if col_flags & 0x01:
+                            col_sorting += ("desc",)
+                            if not (col_flags & 0x02):
+                                col_sorting += ("nulls_last",)
+                        else:
+                            if col_flags & 0x02:
+                                col_sorting += ("nulls_first",)
+                        if col_sorting:
+                            sorting[idx_cols[col_index]] = col_sorting
+                    if sorting:
+                        index["column_sorting"] = sorting
+                    if row["has_constraint"]:
+                        index["duplicates_constraint"] = index_name
+
+                    dialect_options = {}
+                    if row["reloptions"]:
+                        dialect_options["postgresql_with"] = dict(
+                            [option.split("=") for option in row["reloptions"]]
+                        )
+                    # it *might* be nice to include that this is 'btree' in the
+                    # reflection info. But we don't want an Index object
+                    # to have a ``postgresql_using`` in it that is just the
+                    # default, so for the moment leaving this out.
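+                    # amname comes from pg_am.amname; besides the default
+                    # "btree" this may be, for example, "hash", "gin",
+                    # "gist" or "brin"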
+ amname = row["amname"] + if amname != "btree": + dialect_options["postgresql_using"] = row["amname"] + if row["filter_definition"]: + dialect_options["postgresql_where"] = row[ + "filter_definition" + ] + if self.server_version_info >= (11, 0): + # NOTE: this is legacy, this is part of + # dialect_options now as of #7382 + index["include_columns"] = inc_cols + dialect_options["postgresql_include"] = inc_cols + if dialect_options: + index["dialect_options"] = dialect_options - result = [] - for name, idx in indexes.items(): - entry = { - "name": name, - "unique": idx["unique"], - "column_names": [idx["cols"][i] for i in idx["key"]], - } - if self.server_version_info >= (11, 0): - # NOTE: this is legacy, this is part of dialect_options now - # as of #7382 - entry["include_columns"] = [idx["cols"][i] for i in idx["inc"]] - if "duplicates_constraint" in idx: - entry["duplicates_constraint"] = idx["duplicates_constraint"] - if "sorting" in idx: - entry["column_sorting"] = dict( - (idx["cols"][idx["key"][i]], value) - for i, value in idx["sorting"].items() - ) - if "include_columns" in entry: - entry.setdefault("dialect_options", {})[ - "postgresql_include" - ] = entry["include_columns"] - if "options" in idx: - entry.setdefault("dialect_options", {})[ - "postgresql_with" - ] = idx["options"] - if "amname" in idx: - entry.setdefault("dialect_options", {})[ - "postgresql_using" - ] = idx["amname"] - if "postgresql_where" in idx: - entry.setdefault("dialect_options", {})[ - "postgresql_where" - ] = idx["postgresql_where"] - result.append(entry) - return result + table_indexes.append(index) + return indexes.items() @reflection.cache def get_unique_constraints( self, connection, table_name, schema=None, **kw ): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_unique_constraints( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - UNIQUE_SQL = """ - SELECT - cons.conname as name, - cons.conkey as key, - a.attnum as col_num, - a.attname as col_name - FROM - pg_catalog.pg_constraint cons - join pg_attribute a - on cons.conrelid = a.attrelid AND - a.attnum = ANY(cons.conkey) - WHERE - cons.conrelid = :table_oid AND - cons.contype = 'u' - """ - - t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode) - c = connection.execute(t, dict(table_oid=table_oid)) + def get_multi_unique_constraints( + self, + connection, + schema, + filter_names, + scope, + kind, + **kw, + ): + result = self._reflect_constraint( + connection, "u", schema, filter_names, scope, kind, **kw + ) - uniques = defaultdict(lambda: defaultdict(dict)) - for row in c.fetchall(): - uc = uniques[row.name] - uc["key"] = row.key - uc["cols"][row.col_num] = row.col_name + # each table can have multiple unique constraints + uniques = defaultdict(list) + default = ReflectionDefaults.unique_constraints + for (table_name, cols, con_name) in result: + # ensure a list is created for each table. 
+            # the table has no unique constraint
+            if con_name is None:
+                uniques[(schema, table_name)] = default()
+                continue

-        return [
-            {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]}
-            for name, uc in uniques.items()
-        ]
+            uniques[(schema, table_name)].append(
+                {
+                    "column_names": cols,
+                    "name": con_name,
+                }
+            )
+        return uniques.items()

     @reflection.cache
     def get_table_comment(self, connection, table_name, schema=None, **kw):
-        table_oid = self.get_table_oid(
-            connection, table_name, schema, info_cache=kw.get("info_cache")
+        data = self.get_multi_table_comment(
+            connection,
+            schema,
+            [table_name],
+            scope=ObjectScope.ANY,
+            kind=ObjectKind.ANY,
+            **kw,
         )
+        return self._value_or_raise(data, table_name, schema)

-        COMMENT_SQL = """
-            SELECT
-                pgd.description as table_comment
-            FROM
-                pg_catalog.pg_description pgd
-            WHERE
-                pgd.objsubid = 0 AND
-                pgd.objoid = :table_oid
-        """
+    @lru_cache()
+    def _comment_query(self, schema, has_filter_names, scope, kind):
+        relkinds = self._kind_to_relkinds(kind)
+        query = (
+            select(
+                pg_catalog.pg_class.c.relname,
+                pg_catalog.pg_description.c.description,
+            )
+            .select_from(pg_catalog.pg_class)
+            .outerjoin(
+                pg_catalog.pg_description,
+                sql.and_(
+                    pg_catalog.pg_class.c.oid
+                    == pg_catalog.pg_description.c.objoid,
+                    pg_catalog.pg_description.c.objsubid == 0,
+                ),
+            )
+            .where(self._pg_class_relkind_condition(relkinds))
+        )
+        query = self._pg_class_filter_scope_schema(query, schema, scope)
+        if has_filter_names:
+            query = query.where(
+                pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+            )
+        return query

-        c = connection.execute(
-            sql.text(COMMENT_SQL), dict(table_oid=table_oid)
+    def get_multi_table_comment(
+        self, connection, schema, filter_names, scope, kind, **kw
+    ):
+        has_filter_names, params = self._prepare_filter_names(filter_names)
+        query = self._comment_query(schema, has_filter_names, scope, kind)
+        result = connection.execute(query, params)
+
+        default = ReflectionDefaults.table_comment
+        return (
+            (
+                (schema, table),
+                {"text": comment} if comment is not None else default(),
+            )
+            for table, comment in result
         )
-        return {"text": c.scalar()}

     @reflection.cache
     def get_check_constraints(self, connection, table_name, schema=None, **kw):
-        table_oid = self.get_table_oid(
-            connection, table_name, schema, info_cache=kw.get("info_cache")
+        data = self.get_multi_check_constraints(
+            connection,
+            schema,
+            [table_name],
+            scope=ObjectScope.ANY,
+            kind=ObjectKind.ANY,
+            **kw,
        )
+        return self._value_or_raise(data, table_name, schema)

-        CHECK_SQL = """
-            SELECT
-                cons.conname as name,
-                pg_get_constraintdef(cons.oid) as src
-            FROM
-                pg_catalog.pg_constraint cons
-            WHERE
-                cons.conrelid = :table_oid AND
-                cons.contype = 'c'
-        """
-
-        c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid))
+    @lru_cache()
+    def _check_constraint_query(self, schema, has_filter_names, scope, kind):
+        relkinds = self._kind_to_relkinds(kind)
+        query = (
+            select(
+                pg_catalog.pg_class.c.relname,
+                pg_catalog.pg_constraint.c.conname,
+                sql.case(
+                    (
+                        pg_catalog.pg_constraint.c.oid.is_not(None),
+                        pg_catalog.pg_get_constraintdef(
+                            pg_catalog.pg_constraint.c.oid
+                        ),
+                    ),
+                    else_=None,
+                ),
+            )
+            .select_from(pg_catalog.pg_class)
+            .outerjoin(
+                pg_catalog.pg_constraint,
+                sql.and_(
+                    pg_catalog.pg_class.c.oid
+                    == pg_catalog.pg_constraint.c.conrelid,
+                    pg_catalog.pg_constraint.c.contype == "c",
+                ),
+            )
+            .where(self._pg_class_relkind_condition(relkinds))
+        )
+        query = self._pg_class_filter_scope_schema(query, schema, scope)
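+        # same filtering scheme as the other multi-reflection queries
+        # above: narrow by schema and scope, then by the optional name list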
+        if has_filter_names:
+            query = query.where(
+                pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+            )
+        return query

-        ret = []
-        for name, src in c:
+    def get_multi_check_constraints(
+        self, connection, schema, filter_names, scope, kind, **kw
+    ):
+        has_filter_names, params = self._prepare_filter_names(filter_names)
+        query = self._check_constraint_query(
+            schema, has_filter_names, scope, kind
+        )
+        result = connection.execute(query, params)
+
+        check_constraints = defaultdict(list)
+        default = ReflectionDefaults.check_constraints
+        for table_name, check_name, src in result:
+            # only two cases for check_name and src: both null or both defined
+            if check_name is None and src is None:
+                check_constraints[(schema, table_name)] = default()
+                continue
             # samples:
             # "CHECK (((a > 1) AND (a < 5)))"
             # "CHECK (((a = 1) OR ((a > 2) AND (a < 5))))"
@@ -4424,84 +4258,118 @@ class PGDialect(default.DefaultDialect):
             sqltext = re.compile(
                 r"^[\s\n]*\((.+)\)[\s\n]*$", flags=re.DOTALL
             ).sub(r"\1", m.group(1))
-            entry = {"name": name, "sqltext": sqltext}
+            entry = {"name": check_name, "sqltext": sqltext}
             if m and m.group(2):
                 entry["dialect_options"] = {"not_valid": True}
-            ret.append(entry)
-        return ret
-
-    def _load_enums(self, connection, schema=None):
-        schema = schema or self.default_schema_name
-        if not self.supports_native_enum:
-            return {}
-
-        # Load data types for enums:
-        SQL_ENUMS = """
-            SELECT t.typname as "name",
-                -- no enum defaults in 8.4 at least
-                -- t.typdefault as "default",
-                pg_catalog.pg_type_is_visible(t.oid) as "visible",
-                n.nspname as "schema",
-                e.enumlabel as "label"
-            FROM pg_catalog.pg_type t
-                LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
-                LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid
-            WHERE t.typtype = 'e'
-        """
+            check_constraints[(schema, table_name)].append(entry)
+        return check_constraints.items()

-        if schema != "*":
-            SQL_ENUMS += "AND n.nspname = :schema "
+    @lru_cache()
+    def _enum_query(self, schema):
+        lbl_sq = (
+            select(
+                pg_catalog.pg_enum.c.enumtypid, pg_catalog.pg_enum.c.enumlabel
+            )
+            .order_by(
+                pg_catalog.pg_enum.c.enumtypid,
+                pg_catalog.pg_enum.c.enumsortorder,
+            )
+            .subquery("lbl")
+        )

-        # e.oid gives us label order within an enum
-        SQL_ENUMS += 'ORDER BY "schema", "name", e.oid'
+        lbl_agg_sq = (
+            select(
+                lbl_sq.c.enumtypid,
+                sql.func.array_agg(lbl_sq.c.enumlabel).label("labels"),
+            )
+            .group_by(lbl_sq.c.enumtypid)
+            .subquery("lbl_agg")
+        )

-        s = sql.text(SQL_ENUMS).columns(
-            attname=sqltypes.Unicode, label=sqltypes.Unicode
+        query = (
+            select(
+                pg_catalog.pg_type.c.typname.label("name"),
+                pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label(
+                    "visible"
+                ),
+                pg_catalog.pg_namespace.c.nspname.label("schema"),
+                lbl_agg_sq.c.labels.label("labels"),
+            )
+            .join(
+                pg_catalog.pg_namespace,
+                pg_catalog.pg_namespace.c.oid
+                == pg_catalog.pg_type.c.typnamespace,
+            )
+            .outerjoin(
+                lbl_agg_sq, pg_catalog.pg_type.c.oid == lbl_agg_sq.c.enumtypid
+            )
+            .where(pg_catalog.pg_type.c.typtype == "e")
+            .order_by(
+                pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname
+            )
         )
-        if schema != "*":
-            s = s.bindparams(schema=schema)
+        if schema is None:
+            query = query.where(
+                pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
+                # ignore pg_catalog schema
+                pg_catalog.pg_namespace.c.nspname != "pg_catalog",
+            )
+        elif schema != "*":
+            query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+        return query
+
+    @reflection.cache
+    def _load_enums(self, connection, schema=None, **kw):
+        if not self.supports_native_enum:
+            return []

-        c = connection.execute(s)
+        result = connection.execute(self._enum_query(schema))

         enums = []
-        enum_by_name = {}
-        for enum in c.fetchall():
-            key = (enum.schema, enum.name)
-            if key in enum_by_name:
-                enum_by_name[key]["labels"].append(enum.label)
-            else:
-                enum_by_name[key] = enum_rec = {
-                    "name": enum.name,
-                    "schema": enum.schema,
-                    "visible": enum.visible,
-                    "labels": [],
+        for name, visible, schema, labels in result:
+            enums.append(
+                {
+                    "name": name,
+                    "schema": schema,
+                    "visible": visible,
+                    "labels": [] if labels is None else labels,
                 }
-                if enum.label is not None:
-                    enum_rec["labels"].append(enum.label)
-            enums.append(enum_rec)
+            )

         return enums

-    def _load_domains(self, connection):
-        # Load data types for domains:
-        SQL_DOMAINS = """
-            SELECT t.typname as "name",
-                pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype",
-                not t.typnotnull as "nullable",
-                t.typdefault as "default",
-                pg_catalog.pg_type_is_visible(t.oid) as "visible",
-                n.nspname as "schema"
-            FROM pg_catalog.pg_type t
-                LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
-            WHERE t.typtype = 'd'
-        """
+    @util.memoized_property
+    def _domain_query(self):
+        return (
+            select(
+                pg_catalog.pg_type.c.typname.label("name"),
+                pg_catalog.format_type(
+                    pg_catalog.pg_type.c.typbasetype,
+                    pg_catalog.pg_type.c.typtypmod,
+                ).label("attype"),
+                (~pg_catalog.pg_type.c.typnotnull).label("nullable"),
+                pg_catalog.pg_type.c.typdefault.label("default"),
+                pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label(
+                    "visible"
+                ),
+                pg_catalog.pg_namespace.c.nspname.label("schema"),
+            )
+            .join(
+                pg_catalog.pg_namespace,
+                pg_catalog.pg_namespace.c.oid
+                == pg_catalog.pg_type.c.typnamespace,
+            )
+            .where(pg_catalog.pg_type.c.typtype == "d")
+        )

-        s = sql.text(SQL_DOMAINS)
-        c = connection.execution_options(future_result=True).execute(s)
+    @reflection.cache
+    def _load_domains(self, connection, **kw):
+        # Load data types for domains:
+        result = connection.execute(self._domain_query)

         domains = {}
-        for domain in c.mappings():
+        for domain in result.mappings():
             domain = domain
             # strip (30) from character varying(30)
             attype = re.search(r"([^\(]+)", domain["attype"]).group(1)