path: root/lib/sqlalchemy/dialects/postgresql/base.py
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/base.py')
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/base.py  2624
1 file changed, 1246 insertions(+), 1378 deletions(-)
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 83e46151f..36de76e0d 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -11,7 +11,7 @@ r"""
:name: PostgreSQL
:full_support: 9.6, 10, 11, 12, 13, 14
:normal_support: 9.6+
- :best_effort: 8+
+ :best_effort: 9+
.. _postgresql_sequences:
@@ -1448,23 +1448,52 @@ E.g.::
from __future__ import annotations
from collections import defaultdict
-import datetime as dt
+from functools import lru_cache
import re
-from typing import Any
from . import array as _array
from . import dml
from . import hstore as _hstore
from . import json as _json
+from . import pg_catalog
from . import ranges as _ranges
+from .types import _DECIMAL_TYPES # noqa
+from .types import _FLOAT_TYPES # noqa
+from .types import _INT_TYPES # noqa
+from .types import BIT
+from .types import BYTEA
+from .types import CIDR
+from .types import CreateEnumType # noqa
+from .types import DropEnumType # noqa
+from .types import ENUM
+from .types import INET
+from .types import INTERVAL
+from .types import MACADDR
+from .types import MONEY
+from .types import OID
+from .types import PGBit # noqa
+from .types import PGCidr # noqa
+from .types import PGInet # noqa
+from .types import PGInterval # noqa
+from .types import PGMacAddr # noqa
+from .types import PGUuid
+from .types import REGCLASS
+from .types import TIME
+from .types import TIMESTAMP
+from .types import TSVECTOR
from ... import exc
from ... import schema
+from ... import select
from ... import sql
from ... import util
from ...engine import characteristics
from ...engine import default
from ...engine import interfaces
+from ...engine import ObjectKind
+from ...engine import ObjectScope
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
+from ...sql import bindparam
from ...sql import coercions
from ...sql import compiler
from ...sql import elements
@@ -1472,7 +1501,7 @@ from ...sql import expression
from ...sql import roles
from ...sql import sqltypes
from ...sql import util as sql_util
-from ...sql.ddl import InvokeDDLBase
+from ...sql.visitors import InternalTraversal
from ...types import BIGINT
from ...types import BOOLEAN
from ...types import CHAR
@@ -1596,469 +1625,6 @@ RESERVED_WORDS = set(
]
)
-_DECIMAL_TYPES = (1231, 1700)
-_FLOAT_TYPES = (700, 701, 1021, 1022)
-_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
-
-
-class PGUuid(UUID):
- render_bind_cast = True
- render_literal_cast = True
-
-
-class BYTEA(sqltypes.LargeBinary[bytes]):
- __visit_name__ = "BYTEA"
-
-
-class INET(sqltypes.TypeEngine[str]):
- __visit_name__ = "INET"
-
-
-PGInet = INET
-
-
-class CIDR(sqltypes.TypeEngine[str]):
- __visit_name__ = "CIDR"
-
-
-PGCidr = CIDR
-
-
-class MACADDR(sqltypes.TypeEngine[str]):
- __visit_name__ = "MACADDR"
-
-
-PGMacAddr = MACADDR
-
-
-class MONEY(sqltypes.TypeEngine[str]):
-
- r"""Provide the PostgreSQL MONEY type.
-
- Depending on driver, result rows using this type may return a
- string value which includes currency symbols.
-
- For this reason, it may be preferable to provide conversion to a
- numerically-based currency datatype using :class:`_types.TypeDecorator`::
-
- import re
- import decimal
- from sqlalchemy import TypeDecorator
-
- class NumericMoney(TypeDecorator):
- impl = MONEY
-
- def process_result_value(self, value: Any, dialect: Any) -> None:
- if value is not None:
- # adjust this for the currency and numeric
- m = re.match(r"\$([\d.]+)", value)
- if m:
- value = decimal.Decimal(m.group(1))
- return value
-
- Alternatively, the conversion may be applied as a CAST using
- the :meth:`_types.TypeDecorator.column_expression` method as follows::
-
- import decimal
- from sqlalchemy import cast
- from sqlalchemy import TypeDecorator
-
- class NumericMoney(TypeDecorator):
- impl = MONEY
-
- def column_expression(self, column: Any):
- return cast(column, Numeric())
-
- .. versionadded:: 1.2
-
- """
-
- __visit_name__ = "MONEY"
-
-
-class OID(sqltypes.TypeEngine[int]):
-
- """Provide the PostgreSQL OID type.
-
- .. versionadded:: 0.9.5
-
- """
-
- __visit_name__ = "OID"
-
-
-class REGCLASS(sqltypes.TypeEngine[str]):
-
- """Provide the PostgreSQL REGCLASS type.
-
- .. versionadded:: 1.2.7
-
- """
-
- __visit_name__ = "REGCLASS"
-
-
-class TIMESTAMP(sqltypes.TIMESTAMP):
- def __init__(self, timezone=False, precision=None):
- super(TIMESTAMP, self).__init__(timezone=timezone)
- self.precision = precision
-
-
-class TIME(sqltypes.TIME):
- def __init__(self, timezone=False, precision=None):
- super(TIME, self).__init__(timezone=timezone)
- self.precision = precision
-
-
-class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
-
- """PostgreSQL INTERVAL type."""
-
- __visit_name__ = "INTERVAL"
- native = True
-
- def __init__(self, precision=None, fields=None):
- """Construct an INTERVAL.
-
- :param precision: optional integer precision value
- :param fields: string fields specifier. allows storage of fields
- to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``,
- etc.
-
- .. versionadded:: 1.2
-
- """
- self.precision = precision
- self.fields = fields
-
- @classmethod
- def adapt_emulated_to_native(cls, interval, **kw):
- return INTERVAL(precision=interval.second_precision)
-
- @property
- def _type_affinity(self):
- return sqltypes.Interval
-
- def as_generic(self, allow_nulltype=False):
- return sqltypes.Interval(native=True, second_precision=self.precision)
-
- @property
- def python_type(self):
- return dt.timedelta
-
-
-PGInterval = INTERVAL
-
-
-class BIT(sqltypes.TypeEngine[int]):
- __visit_name__ = "BIT"
-
- def __init__(self, length=None, varying=False):
- if not varying:
- # BIT without VARYING defaults to length 1
- self.length = length or 1
- else:
- # but BIT VARYING can be unlimited-length, so no default
- self.length = length
- self.varying = varying
-
-
-PGBit = BIT
-
-
-class TSVECTOR(sqltypes.TypeEngine[Any]):
-
- """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
- text search type TSVECTOR.
-
- It can be used to do full text queries on natural language
- documents.
-
- .. versionadded:: 0.9.0
-
- .. seealso::
-
- :ref:`postgresql_match`
-
- """
-
- __visit_name__ = "TSVECTOR"
-
-
-class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
-
- """PostgreSQL ENUM type.
-
- This is a subclass of :class:`_types.Enum` which includes
- support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
-
- When the builtin type :class:`_types.Enum` is used and the
- :paramref:`.Enum.native_enum` flag is left at its default of
- True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
- type as the implementation, so the special create/drop rules
- will be used.
-
- The create/drop behavior of ENUM is necessarily intricate, due to the
- awkward relationship the ENUM type has in relationship to the
- parent table, in that it may be "owned" by just a single table, or
- may be shared among many tables.
-
- When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
- in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
- corresponding to when the :meth:`_schema.Table.create` and
- :meth:`_schema.Table.drop`
- methods are called::
-
- table = Table('sometable', metadata,
- Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
- )
-
- table.create(engine) # will emit CREATE ENUM and CREATE TABLE
- table.drop(engine) # will emit DROP TABLE and DROP ENUM
-
- To use a common enumerated type between multiple tables, the best
- practice is to declare the :class:`_types.Enum` or
- :class:`_postgresql.ENUM` independently, and associate it with the
- :class:`_schema.MetaData` object itself::
-
- my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
-
- t1 = Table('sometable_one', metadata,
- Column('some_enum', myenum)
- )
-
- t2 = Table('sometable_two', metadata,
- Column('some_enum', myenum)
- )
-
- When this pattern is used, care must still be taken at the level
- of individual table creates. Emitting CREATE TABLE without also
- specifying ``checkfirst=True`` will still cause issues::
-
- t1.create(engine) # will fail: no such type 'myenum'
-
- If we specify ``checkfirst=True``, the individual table-level create
- operation will check for the ``ENUM`` and create if not exists::
-
- # will check if enum exists, and emit CREATE TYPE if not
- t1.create(engine, checkfirst=True)
-
- When using a metadata-level ENUM type, the type will always be created
- and dropped if either the metadata-wide create/drop is called::
-
- metadata.create_all(engine) # will emit CREATE TYPE
- metadata.drop_all(engine) # will emit DROP TYPE
-
- The type can also be created and dropped directly::
-
- my_enum.create(engine)
- my_enum.drop(engine)
-
- .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type
- now behaves more strictly with regards to CREATE/DROP. A metadata-level
- ENUM type will only be created and dropped at the metadata level,
- not the table level, with the exception of
- ``table.create(checkfirst=True)``.
- The ``table.drop()`` call will now emit a DROP TYPE for a table-level
- enumerated type.
-
- """
-
- native_enum = True
-
- def __init__(self, *enums, **kw):
- """Construct an :class:`_postgresql.ENUM`.
-
- Arguments are the same as that of
- :class:`_types.Enum`, but also including
- the following parameters.
-
- :param create_type: Defaults to True.
- Indicates that ``CREATE TYPE`` should be
- emitted, after optionally checking for the
- presence of the type, when the parent
- table is being created; and additionally
- that ``DROP TYPE`` is called when the table
- is dropped. When ``False``, no check
- will be performed and no ``CREATE TYPE``
- or ``DROP TYPE`` is emitted, unless
- :meth:`~.postgresql.ENUM.create`
- or :meth:`~.postgresql.ENUM.drop`
- are called directly.
- Setting to ``False`` is helpful
- when invoking a creation scheme to a SQL file
- without access to the actual database -
- the :meth:`~.postgresql.ENUM.create` and
- :meth:`~.postgresql.ENUM.drop` methods can
- be used to emit SQL to a target bind.
-
- """
- native_enum = kw.pop("native_enum", None)
- if native_enum is False:
- util.warn(
- "the native_enum flag does not apply to the "
- "sqlalchemy.dialects.postgresql.ENUM datatype; this type "
- "always refers to ENUM. Use sqlalchemy.types.Enum for "
- "non-native enum."
- )
- self.create_type = kw.pop("create_type", True)
- super(ENUM, self).__init__(*enums, **kw)
-
- @classmethod
- def adapt_emulated_to_native(cls, impl, **kw):
- """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
- :class:`.Enum`.
-
- """
- kw.setdefault("validate_strings", impl.validate_strings)
- kw.setdefault("name", impl.name)
- kw.setdefault("schema", impl.schema)
- kw.setdefault("inherit_schema", impl.inherit_schema)
- kw.setdefault("metadata", impl.metadata)
- kw.setdefault("_create_events", False)
- kw.setdefault("values_callable", impl.values_callable)
- kw.setdefault("omit_aliases", impl._omit_aliases)
- return cls(**kw)
-
- def create(self, bind=None, checkfirst=True):
- """Emit ``CREATE TYPE`` for this
- :class:`_postgresql.ENUM`.
-
- If the underlying dialect does not support
- PostgreSQL CREATE TYPE, no action is taken.
-
- :param bind: a connectable :class:`_engine.Engine`,
- :class:`_engine.Connection`, or similar object to emit
- SQL.
- :param checkfirst: if ``True``, a query against
- the PG catalog will be first performed to see
- if the type does not exist already before
- creating.
-
- """
- if not bind.dialect.supports_native_enum:
- return
-
- bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
-
- def drop(self, bind=None, checkfirst=True):
- """Emit ``DROP TYPE`` for this
- :class:`_postgresql.ENUM`.
-
- If the underlying dialect does not support
- PostgreSQL DROP TYPE, no action is taken.
-
- :param bind: a connectable :class:`_engine.Engine`,
- :class:`_engine.Connection`, or similar object to emit
- SQL.
- :param checkfirst: if ``True``, a query against
- the PG catalog will be first performed to see
- if the type actually exists before dropping.
-
- """
- if not bind.dialect.supports_native_enum:
- return
-
- bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
-
- class EnumGenerator(InvokeDDLBase):
- def __init__(self, dialect, connection, checkfirst=False, **kwargs):
- super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
-
- def _can_create_enum(self, enum):
- if not self.checkfirst:
- return True
-
- effective_schema = self.connection.schema_for_object(enum)
-
- return not self.connection.dialect.has_type(
- self.connection, enum.name, schema=effective_schema
- )
-
- def visit_enum(self, enum):
- if not self._can_create_enum(enum):
- return
-
- self.connection.execute(CreateEnumType(enum))
-
- class EnumDropper(InvokeDDLBase):
- def __init__(self, dialect, connection, checkfirst=False, **kwargs):
- super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
-
- def _can_drop_enum(self, enum):
- if not self.checkfirst:
- return True
-
- effective_schema = self.connection.schema_for_object(enum)
-
- return self.connection.dialect.has_type(
- self.connection, enum.name, schema=effective_schema
- )
-
- def visit_enum(self, enum):
- if not self._can_drop_enum(enum):
- return
-
- self.connection.execute(DropEnumType(enum))
-
- def get_dbapi_type(self, dbapi):
- """dont return dbapi.STRING for ENUM in PostgreSQL, since that's
- a different type"""
-
- return None
-
- def _check_for_name_in_memos(self, checkfirst, kw):
- """Look in the 'ddl runner' for 'memos', then
- note our name in that collection.
-
- This to ensure a particular named enum is operated
- upon only once within any kind of create/drop
- sequence without relying upon "checkfirst".
-
- """
- if not self.create_type:
- return True
- if "_ddl_runner" in kw:
- ddl_runner = kw["_ddl_runner"]
- if "_pg_enums" in ddl_runner.memo:
- pg_enums = ddl_runner.memo["_pg_enums"]
- else:
- pg_enums = ddl_runner.memo["_pg_enums"] = set()
- present = (self.schema, self.name) in pg_enums
- pg_enums.add((self.schema, self.name))
- return present
- else:
- return False
-
- def _on_table_create(self, target, bind, checkfirst=False, **kw):
- if (
- checkfirst
- or (
- not self.metadata
- and not kw.get("_is_metadata_operation", False)
- )
- ) and not self._check_for_name_in_memos(checkfirst, kw):
- self.create(bind=bind, checkfirst=checkfirst)
-
- def _on_table_drop(self, target, bind, checkfirst=False, **kw):
- if (
- not self.metadata
- and not kw.get("_is_metadata_operation", False)
- and not self._check_for_name_in_memos(checkfirst, kw)
- ):
- self.drop(bind=bind, checkfirst=checkfirst)
-
- def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
- if not self._check_for_name_in_memos(checkfirst, kw):
- self.create(bind=bind, checkfirst=checkfirst)
-
- def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
- if not self._check_for_name_in_memos(checkfirst, kw):
- self.drop(bind=bind, checkfirst=checkfirst)
-
-
colspecs = {
sqltypes.ARRAY: _array.ARRAY,
sqltypes.Interval: INTERVAL,
@@ -2997,8 +2563,19 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
class PGInspector(reflection.Inspector):
+ dialect: PGDialect
+
def get_table_oid(self, table_name, schema=None):
- """Return the OID for the given table name."""
+ """Return the OID for the given table name.
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ """
with self._operation_context() as conn:
return self.dialect.get_table_oid(
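A minimal usage sketch of this method through the public Inspector API (the connection URL and table name are illustrative, not part of this change):

    from sqlalchemy import create_engine, inspect

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    insp = inspect(engine)

    # fetch the OID once; subsequent reflection calls reuse it via info_cache
    oid = insp.get_table_oid("accounts", schema="public")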
@@ -3023,9 +2600,10 @@ class PGInspector(reflection.Inspector):
.. versionadded:: 1.0.0
"""
- schema = schema or self.default_schema_name
with self._operation_context() as conn:
- return self.dialect._load_enums(conn, schema)
+ return self.dialect._load_enums(
+ conn, schema, info_cache=self.info_cache
+ )
def get_foreign_table_names(self, schema=None):
"""Return a list of FOREIGN TABLE names.
@@ -3038,38 +2616,29 @@ class PGInspector(reflection.Inspector):
.. versionadded:: 1.0.0
"""
- schema = schema or self.default_schema_name
with self._operation_context() as conn:
- return self.dialect._get_foreign_table_names(conn, schema)
-
- def get_view_names(self, schema=None, include=("plain", "materialized")):
- """Return all view names in `schema`.
+ return self.dialect._get_foreign_table_names(
+ conn, schema, info_cache=self.info_cache
+ )
- :param schema: Optional, retrieve names from a non-default schema.
- For special quoting, use :class:`.quoted_name`.
+ def has_type(self, type_name, schema=None, **kw):
+ """Return if the database has the specified type in the provided
+ schema.
- :param include: specify which types of views to return. Passed
- as a string value (for a single type) or a tuple (for any number
- of types). Defaults to ``('plain', 'materialized')``.
+ :param type_name: the type to check.
+ :param schema: schema name. If None, the default schema
+ (typically 'public') is used. May also be set to '*' to
+ check in all schemas.
- .. versionadded:: 1.1
+ .. versionadded:: 2.0
"""
-
with self._operation_context() as conn:
- return self.dialect.get_view_names(
- conn, schema, info_cache=self.info_cache, include=include
+ return self.dialect.has_type(
+ conn, type_name, schema, info_cache=self.info_cache
)
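A hedged sketch of the new method in use (URL and type name are illustrative):

    from sqlalchemy import create_engine, inspect

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    insp = inspect(engine)

    insp.has_type("myenum")                   # default schema only
    insp.has_type("myenum", schema="public")  # one named schema
    insp.has_type("myenum", schema="*")       # search all schemas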
-class CreateEnumType(schema._CreateDropBase):
- __visit_name__ = "create_enum_type"
-
-
-class DropEnumType(schema._CreateDropBase):
- __visit_name__ = "drop_enum_type"
-
-
class PGExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq, type_):
return self._execute_scalar(
@@ -3262,35 +2831,14 @@ class PGDialect(default.DefaultDialect):
def initialize(self, connection):
super(PGDialect, self).initialize(connection)
- if self.server_version_info <= (8, 2):
- self.delete_returning = (
- self.update_returning
- ) = self.insert_returning = False
-
- self.supports_native_enum = self.server_version_info >= (8, 3)
- if not self.supports_native_enum:
- self.colspecs = self.colspecs.copy()
- # pop base Enum type
- self.colspecs.pop(sqltypes.Enum, None)
- # psycopg2, others may have placed ENUM here as well
- self.colspecs.pop(ENUM, None)
-
# https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689
self.supports_smallserial = self.server_version_info >= (9, 2)
- if self.server_version_info < (8, 2):
- self._backslash_escapes = False
- else:
- # ensure this query is not emitted on server version < 8.2
- # as it will fail
- std_string = connection.exec_driver_sql(
- "show standard_conforming_strings"
- ).scalar()
- self._backslash_escapes = std_string == "off"
-
- self._supports_create_index_concurrently = (
- self.server_version_info >= (8, 2)
- )
+ std_string = connection.exec_driver_sql(
+ "show standard_conforming_strings"
+ ).scalar()
+ self._backslash_escapes = std_string == "off"
+
self._supports_drop_index_concurrently = self.server_version_info >= (
9,
2,
@@ -3370,122 +2918,100 @@ class PGDialect(default.DefaultDialect):
self.do_commit(connection.connection)
def do_recover_twophase(self, connection):
- resultset = connection.execute(
+ return connection.scalars(
sql.text("SELECT gid FROM pg_prepared_xacts")
- )
- return [row[0] for row in resultset]
+ ).all()
def _get_default_schema_name(self, connection):
return connection.exec_driver_sql("select current_schema()").scalar()
- def has_schema(self, connection, schema):
- query = (
- "select nspname from pg_namespace " "where lower(nspname)=:schema"
- )
- cursor = connection.execute(
- sql.text(query).bindparams(
- sql.bindparam(
- "schema",
- str(schema.lower()),
- type_=sqltypes.Unicode,
- )
- )
+ @reflection.cache
+ def has_schema(self, connection, schema, **kw):
+ query = select(pg_catalog.pg_namespace.c.nspname).where(
+ pg_catalog.pg_namespace.c.nspname == schema
)
+ return bool(connection.scalar(query))
- return bool(cursor.first())
-
- def has_table(self, connection, table_name, schema=None):
- self._ensure_has_table_connection(connection)
- # seems like case gets folded in pg_class...
+ def _pg_class_filter_scope_schema(
+ self, query, schema, scope, pg_class_table=None
+ ):
+ if pg_class_table is None:
+ pg_class_table = pg_catalog.pg_class
+ query = query.join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace,
+ )
+ if scope is ObjectScope.DEFAULT:
+ query = query.where(pg_class_table.c.relpersistence != "t")
+ elif scope is ObjectScope.TEMPORARY:
+ query = query.where(pg_class_table.c.relpersistence == "t")
if schema is None:
- cursor = connection.execute(
- sql.text(
- "select relname from pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where "
- "pg_catalog.pg_table_is_visible(c.oid) "
- "and relname=:name"
- ).bindparams(
- sql.bindparam(
- "name",
- str(table_name),
- type_=sqltypes.Unicode,
- )
- )
+ query = query.where(
+ pg_catalog.pg_table_is_visible(pg_class_table.c.oid),
+ # ignore pg_catalog schema
+ pg_catalog.pg_namespace.c.nspname != "pg_catalog",
)
else:
- cursor = connection.execute(
- sql.text(
- "select relname from pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where n.nspname=:schema and "
- "relname=:name"
- ).bindparams(
- sql.bindparam(
- "name",
- str(table_name),
- type_=sqltypes.Unicode,
- ),
- sql.bindparam(
- "schema",
- str(schema),
- type_=sqltypes.Unicode,
- ),
- )
- )
- return bool(cursor.first())
-
- def has_sequence(self, connection, sequence_name, schema=None):
- if schema is None:
- schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT relname FROM pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where relkind='S' and "
- "n.nspname=:schema and relname=:name"
- ).bindparams(
- sql.bindparam(
- "name",
- str(sequence_name),
- type_=sqltypes.Unicode,
- ),
- sql.bindparam(
- "schema",
- str(schema),
- type_=sqltypes.Unicode,
- ),
- )
+ query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+ return query
+
+ def _pg_class_relkind_condition(self, relkinds, pg_class_table=None):
+ if pg_class_table is None:
+ pg_class_table = pg_catalog.pg_class
+ # uses the ANY form instead of IN; otherwise PostgreSQL complains
+ # that 'IN could not convert type character to "char"'
+ return pg_class_table.c.relkind == sql.any_(_array.array(relkinds))
+
+ @lru_cache()
+ def _has_table_query(self, schema):
+ query = select(pg_catalog.pg_class.c.relname).where(
+ pg_catalog.pg_class.c.relname == bindparam("table_name"),
+ self._pg_class_relkind_condition(
+ pg_catalog.RELKINDS_ALL_TABLE_LIKE
+ ),
+ )
+ return self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
)
- return bool(cursor.first())
+ @reflection.cache
+ def has_table(self, connection, table_name, schema=None, **kw):
+ self._ensure_has_table_connection(connection)
+ query = self._has_table_query(schema)
+ return bool(connection.scalar(query, {"table_name": table_name}))
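The ``@lru_cache()`` / ``bindparam`` split used above can be shown in isolation: the query object is built once per distinct argument combination, while the table name stays a bind parameter supplied at execution time. A standalone sketch (table and names are illustrative):

    from functools import lru_cache

    from sqlalchemy import bindparam, column, select, table

    pg_class = table("pg_class", column("relname"), column("relkind"))

    @lru_cache()
    def cached_has_table_query(relkind):
        # constructed once per distinct relkind, then reused
        return select(pg_class.c.relname).where(
            pg_class.c.relname == bindparam("table_name"),
            pg_class.c.relkind == relkind,
        )

    # prints roughly: SELECT pg_class.relname FROM pg_class
    #   WHERE pg_class.relname = :table_name AND pg_class.relkind = :relkind_1
    print(cached_has_table_query("r"))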
- def has_type(self, connection, type_name, schema=None):
- if schema is not None:
- query = """
- SELECT EXISTS (
- SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n
- WHERE t.typnamespace = n.oid
- AND t.typname = :typname
- AND n.nspname = :nspname
- )
- """
- query = sql.text(query)
- else:
- query = """
- SELECT EXISTS (
- SELECT * FROM pg_catalog.pg_type t
- WHERE t.typname = :typname
- AND pg_type_is_visible(t.oid)
- )
- """
- query = sql.text(query)
- query = query.bindparams(
- sql.bindparam("typname", str(type_name), type_=sqltypes.Unicode)
+ @reflection.cache
+ def has_sequence(self, connection, sequence_name, schema=None, **kw):
+ query = select(pg_catalog.pg_class.c.relname).where(
+ pg_catalog.pg_class.c.relkind == "S",
+ pg_catalog.pg_class.c.relname == sequence_name,
)
- if schema is not None:
- query = query.bindparams(
- sql.bindparam("nspname", str(schema), type_=sqltypes.Unicode)
+ query = self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
+ )
+ return bool(connection.scalar(query))
+
+ @reflection.cache
+ def has_type(self, connection, type_name, schema=None, **kw):
+ query = (
+ select(pg_catalog.pg_type.c.typname)
+ .join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid
+ == pg_catalog.pg_type.c.typnamespace,
)
- cursor = connection.execute(query)
- return bool(cursor.scalar())
+ .where(pg_catalog.pg_type.c.typname == type_name)
+ )
+ if schema is None:
+ query = query.where(
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
+ # ignore pg_catalog schema
+ pg_catalog.pg_namespace.c.nspname != "pg_catalog",
+ )
+ elif schema != "*":
+ query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+
+ return bool(connection.scalar(query))
def _get_server_version_info(self, connection):
v = connection.exec_driver_sql("select pg_catalog.version()").scalar()
@@ -3502,229 +3028,300 @@ class PGDialect(default.DefaultDialect):
@reflection.cache
def get_table_oid(self, connection, table_name, schema=None, **kw):
- """Fetch the oid for schema.table_name.
-
- Several reflection methods require the table oid. The idea for using
- this method is that it can be fetched one time and cached for
- subsequent calls.
-
- """
- table_oid = None
- if schema is not None:
- schema_where_clause = "n.nspname = :schema"
- else:
- schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
- query = (
- """
- SELECT c.oid
- FROM pg_catalog.pg_class c
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
- WHERE (%s)
- AND c.relname = :table_name AND c.relkind in
- ('r', 'v', 'm', 'f', 'p')
- """
- % schema_where_clause
+ """Fetch the oid for schema.table_name."""
+ query = select(pg_catalog.pg_class.c.oid).where(
+ pg_catalog.pg_class.c.relname == table_name,
+ self._pg_class_relkind_condition(
+ pg_catalog.RELKINDS_ALL_TABLE_LIKE
+ ),
)
- # Since we're binding to unicode, table_name and schema_name must be
- # unicode.
- table_name = str(table_name)
- if schema is not None:
- schema = str(schema)
- s = sql.text(query).bindparams(table_name=sqltypes.Unicode)
- s = s.columns(oid=sqltypes.Integer)
- if schema:
- s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode))
- c = connection.execute(s, dict(table_name=table_name, schema=schema))
- table_oid = c.scalar()
+ query = self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
+ )
+ table_oid = connection.scalar(query)
if table_oid is None:
- raise exc.NoSuchTableError(table_name)
+ raise exc.NoSuchTableError(
+ f"{schema}.{table_name}" if schema else table_name
+ )
return table_oid
@reflection.cache
def get_schema_names(self, connection, **kw):
- result = connection.execute(
- sql.text(
- "SELECT nspname FROM pg_namespace "
- "WHERE nspname NOT LIKE 'pg_%' "
- "ORDER BY nspname"
- ).columns(nspname=sqltypes.Unicode)
+ query = (
+ select(pg_catalog.pg_namespace.c.nspname)
+ .where(pg_catalog.pg_namespace.c.nspname.not_like("pg_%"))
+ .order_by(pg_catalog.pg_namespace.c.nspname)
+ )
+ return connection.scalars(query).all()
+
+ def _get_relnames_for_relkinds(self, connection, schema, relkinds, scope):
+ query = select(pg_catalog.pg_class.c.relname).where(
+ self._pg_class_relkind_condition(relkinds)
)
- return [name for name, in result]
+ query = self._pg_class_filter_scope_schema(query, schema, scope=scope)
+ return connection.scalars(query).all()
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
- result = connection.execute(
- sql.text(
- "SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')"
- ).columns(relname=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name
- ),
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ pg_catalog.RELKINDS_TABLE_NO_FOREIGN,
+ scope=ObjectScope.DEFAULT,
+ )
+
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema=None,
+ relkinds=pg_catalog.RELKINDS_TABLE_NO_FOREIGN,
+ scope=ObjectScope.TEMPORARY,
)
- return [name for name, in result]
@reflection.cache
def _get_foreign_table_names(self, connection, schema=None, **kw):
- result = connection.execute(
- sql.text(
- "SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind = 'f'"
- ).columns(relname=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name
- ),
+ return self._get_relnames_for_relkinds(
+ connection, schema, relkinds=("f",), scope=ObjectScope.ANY
)
- return [name for name, in result]
@reflection.cache
- def get_view_names(
- self, connection, schema=None, include=("plain", "materialized"), **kw
- ):
+ def get_view_names(self, connection, schema=None, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ pg_catalog.RELKINDS_VIEW,
+ scope=ObjectScope.DEFAULT,
+ )
- include_kind = {"plain": "v", "materialized": "m"}
- try:
- kinds = [include_kind[i] for i in util.to_list(include)]
- except KeyError:
- raise ValueError(
- "include %r unknown, needs to be a sequence containing "
- "one or both of 'plain' and 'materialized'" % (include,)
- )
- if not kinds:
- raise ValueError(
- "empty include, needs to be a sequence containing "
- "one or both of 'plain' and 'materialized'"
- )
+ @reflection.cache
+ def get_materialized_view_names(self, connection, schema=None, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ pg_catalog.RELKINDS_MAT_VIEW,
+ scope=ObjectScope.DEFAULT,
+ )
- result = connection.execute(
- sql.text(
- "SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind IN (%s)"
- % (", ".join("'%s'" % elem for elem in kinds))
- ).columns(relname=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name
- ),
+ @reflection.cache
+ def get_temp_view_names(self, connection, schema=None, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ # NOTE: do not include temp materialized views (they do not
+ # seem to be a thing, at least up to version 14)
+ pg_catalog.RELKINDS_VIEW,
+ scope=ObjectScope.TEMPORARY,
)
- return [name for name, in result]
@reflection.cache
def get_sequence_names(self, connection, schema=None, **kw):
- if not schema:
- schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT relname FROM pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where relkind='S' and "
- "n.nspname=:schema"
- ).bindparams(
- sql.bindparam(
- "schema",
- str(schema),
- type_=sqltypes.Unicode,
- ),
- )
+ return self._get_relnames_for_relkinds(
+ connection, schema, relkinds=("S",), scope=ObjectScope.ANY
)
- return [row[0] for row in cursor]
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
- view_def = connection.scalar(
- sql.text(
- "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relname = :view_name "
- "AND c.relkind IN ('v', 'm')"
- ).columns(view_def=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name,
- view_name=view_name,
- ),
+ query = (
+ select(pg_catalog.pg_get_viewdef(pg_catalog.pg_class.c.oid))
+ .select_from(pg_catalog.pg_class)
+ .where(
+ pg_catalog.pg_class.c.relname == view_name,
+ self._pg_class_relkind_condition(
+ pg_catalog.RELKINDS_VIEW + pg_catalog.RELKINDS_MAT_VIEW
+ ),
+ )
)
- return view_def
+ query = self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
+ )
+ res = connection.scalar(query)
+ if res is None:
+ raise exc.NoSuchTableError(
+ f"{schema}.{view_name}" if schema else view_name
+ )
+ else:
+ return res
+
+ def _value_or_raise(self, data, table, schema):
+ try:
+ return dict(data)[(schema, table)]
+ except KeyError:
+ raise exc.NoSuchTableError(
+ f"{schema}.{table}" if schema else table
+ ) from None
+
+ def _prepare_filter_names(self, filter_names):
+ if filter_names:
+ return True, {"filter_names": filter_names}
+ else:
+ return False, {}
+
+ def _kind_to_relkinds(self, kind: ObjectKind) -> tuple[str, ...]:
+ if kind is ObjectKind.ANY:
+ return pg_catalog.RELKINDS_ALL_TABLE_LIKE
+ relkinds = ()
+ if ObjectKind.TABLE in kind:
+ relkinds += pg_catalog.RELKINDS_TABLE
+ if ObjectKind.VIEW in kind:
+ relkinds += pg_catalog.RELKINDS_VIEW
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ relkinds += pg_catalog.RELKINDS_MAT_VIEW
+ return relkinds
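``ObjectKind`` is a flag enum, so the membership tests above compose; a small demonstration (assumes the 2.0 ``ObjectKind`` imported at the top of this module):

    from sqlalchemy.engine import ObjectKind

    kind = ObjectKind.TABLE | ObjectKind.MATERIALIZED_VIEW
    print(ObjectKind.TABLE in kind)  # True
    print(ObjectKind.VIEW in kind)   # False
    print(kind is ObjectKind.ANY)    # False: ANY is the all-kinds alias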
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
-
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_columns(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
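The single-table method is now a thin wrapper over the batched form; a hedged usage sketch via the Inspector (URL and names illustrative):

    from sqlalchemy import create_engine, inspect

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    insp = inspect(engine)

    # single-table form: raises NoSuchTableError if "accounts" is absent
    cols = insp.get_columns("accounts")

    # batched form: columns for every table in the schema in one round trip
    multi = insp.get_multi_columns(schema="public")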
+ @lru_cache()
+ def _columns_query(self, schema, has_filter_names, scope, kind):
+ # NOTE: fetching the default and identity options via scalar
+ # subqueries is faster than trying to use outer joins for them
generated = (
- "a.attgenerated as generated"
+ pg_catalog.pg_attribute.c.attgenerated.label("generated")
if self.server_version_info >= (12,)
- else "NULL as generated"
+ else sql.null().label("generated")
)
if self.server_version_info >= (10,):
- # a.attidentity != '' is required or it will reflect also
- # serial columns as identity.
- identity = """\
- (SELECT json_build_object(
- 'always', a.attidentity = 'a',
- 'start', s.seqstart,
- 'increment', s.seqincrement,
- 'minvalue', s.seqmin,
- 'maxvalue', s.seqmax,
- 'cache', s.seqcache,
- 'cycle', s.seqcycle)
- FROM pg_catalog.pg_sequence s
- JOIN pg_catalog.pg_class c on s.seqrelid = c."oid"
- WHERE c.relkind = 'S'
- AND a.attidentity != ''
- AND s.seqrelid = pg_catalog.pg_get_serial_sequence(
- a.attrelid::regclass::text, a.attname
- )::regclass::oid
- ) as identity_options\
- """
+ # join lateral performs worse (~2x slower) than a scalar_subquery
+ identity = (
+ select(
+ sql.func.json_build_object(
+ "always",
+ pg_catalog.pg_attribute.c.attidentity == "a",
+ "start",
+ pg_catalog.pg_sequence.c.seqstart,
+ "increment",
+ pg_catalog.pg_sequence.c.seqincrement,
+ "minvalue",
+ pg_catalog.pg_sequence.c.seqmin,
+ "maxvalue",
+ pg_catalog.pg_sequence.c.seqmax,
+ "cache",
+ pg_catalog.pg_sequence.c.seqcache,
+ "cycle",
+ pg_catalog.pg_sequence.c.seqcycle,
+ )
+ )
+ .select_from(pg_catalog.pg_sequence)
+ .where(
+ # attidentity != '' is required or it will reflect also
+ # serial columns as identity.
+ pg_catalog.pg_attribute.c.attidentity != "",
+ pg_catalog.pg_sequence.c.seqrelid
+ == sql.cast(
+ sql.cast(
+ pg_catalog.pg_get_serial_sequence(
+ sql.cast(
+ sql.cast(
+ pg_catalog.pg_attribute.c.attrelid,
+ REGCLASS,
+ ),
+ TEXT,
+ ),
+ pg_catalog.pg_attribute.c.attname,
+ ),
+ REGCLASS,
+ ),
+ OID,
+ ),
+ )
+ .correlate(pg_catalog.pg_attribute)
+ .scalar_subquery()
+ .label("identity_options")
+ )
else:
- identity = "NULL as identity_options"
-
- SQL_COLS = """
- SELECT a.attname,
- pg_catalog.format_type(a.atttypid, a.atttypmod),
- (
- SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid)
- FROM pg_catalog.pg_attrdef d
- WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum
- AND a.atthasdef
- ) AS DEFAULT,
- a.attnotnull,
- a.attrelid as table_oid,
- pgd.description as comment,
- %s,
- %s
- FROM pg_catalog.pg_attribute a
- LEFT JOIN pg_catalog.pg_description pgd ON (
- pgd.objoid = a.attrelid AND pgd.objsubid = a.attnum)
- WHERE a.attrelid = :table_oid
- AND a.attnum > 0 AND NOT a.attisdropped
- ORDER BY a.attnum
- """ % (
- generated,
- identity,
+ identity = sql.null().label("identity_options")
+
+ # join lateral performs the same as scalar_subquery here
+ default = (
+ select(
+ pg_catalog.pg_get_expr(
+ pg_catalog.pg_attrdef.c.adbin,
+ pg_catalog.pg_attrdef.c.adrelid,
+ )
+ )
+ .select_from(pg_catalog.pg_attrdef)
+ .where(
+ pg_catalog.pg_attrdef.c.adrelid
+ == pg_catalog.pg_attribute.c.attrelid,
+ pg_catalog.pg_attrdef.c.adnum
+ == pg_catalog.pg_attribute.c.attnum,
+ pg_catalog.pg_attribute.c.atthasdef,
+ )
+ .correlate(pg_catalog.pg_attribute)
+ .scalar_subquery()
+ .label("default")
)
- s = (
- sql.text(SQL_COLS)
- .bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer))
- .columns(attname=sqltypes.Unicode, default=sqltypes.Unicode)
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_attribute.c.attname.label("name"),
+ pg_catalog.format_type(
+ pg_catalog.pg_attribute.c.atttypid,
+ pg_catalog.pg_attribute.c.atttypmod,
+ ).label("format_type"),
+ default,
+ pg_catalog.pg_attribute.c.attnotnull.label("not_null"),
+ pg_catalog.pg_class.c.relname.label("table_name"),
+ pg_catalog.pg_description.c.description.label("comment"),
+ generated,
+ identity,
+ )
+ .select_from(pg_catalog.pg_class)
+ # NOTE: PostgreSQL supports tables with no user columns, meaning
+ # there is no row with pg_attribute.attnum > 0. use a left outer
+ # join to avoid filtering out these tables.
+ .outerjoin(
+ pg_catalog.pg_attribute,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_attribute.c.attrelid,
+ pg_catalog.pg_attribute.c.attnum > 0,
+ ~pg_catalog.pg_attribute.c.attisdropped,
+ ),
+ )
+ .outerjoin(
+ pg_catalog.pg_description,
+ sql.and_(
+ pg_catalog.pg_description.c.objoid
+ == pg_catalog.pg_attribute.c.attrelid,
+ pg_catalog.pg_description.c.objsubid
+ == pg_catalog.pg_attribute.c.attnum,
+ ),
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ .order_by(
+ pg_catalog.pg_class.c.relname, pg_catalog.pg_attribute.c.attnum
+ )
)
- c = connection.execute(s, dict(table_oid=table_oid))
- rows = c.fetchall()
+ query = self._pg_class_filter_scope_schema(query, schema, scope=scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
+
+ def get_multi_columns(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._columns_query(schema, has_filter_names, scope, kind)
+ rows = connection.execute(query, params).mappings()
# dictionary with (name, ) if default search path or (schema, name)
# as keys
- domains = self._load_domains(connection)
+ domains = self._load_domains(
+ connection, info_cache=kw.get("info_cache")
+ )
# dictionary with (name, ) if default search path or (schema, name)
# as keys
@@ -3732,257 +3329,340 @@ class PGDialect(default.DefaultDialect):
((rec["name"],), rec)
if rec["visible"]
else ((rec["schema"], rec["name"]), rec)
- for rec in self._load_enums(connection, schema="*")
+ for rec in self._load_enums(
+ connection, schema="*", info_cache=kw.get("info_cache")
+ )
)
- # format columns
- columns = []
-
- for (
- name,
- format_type,
- default_,
- notnull,
- table_oid,
- comment,
- generated,
- identity,
- ) in rows:
- column_info = self._get_column_info(
- name,
- format_type,
- default_,
- notnull,
- domains,
- enums,
- schema,
- comment,
- generated,
- identity,
- )
- columns.append(column_info)
- return columns
+ columns = self._get_columns_info(rows, domains, enums, schema)
+
+ return columns.items()
+
+ def _get_columns_info(self, rows, domains, enums, schema):
+ array_type_pattern = re.compile(r"\[\]$")
+ attype_pattern = re.compile(r"\(.*\)")
+ charlen_pattern = re.compile(r"\(([\d,]+)\)")
+ args_pattern = re.compile(r"\((.*)\)")
+ args_split_pattern = re.compile(r"\s*,\s*")
- def _get_column_info(
- self,
- name,
- format_type,
- default,
- notnull,
- domains,
- enums,
- schema,
- comment,
- generated,
- identity,
- ):
def _handle_array_type(attype):
return (
# strip '[]' from integer[], etc.
- re.sub(r"\[\]$", "", attype),
+ array_type_pattern.sub("", attype),
attype.endswith("[]"),
)
- # strip (*) from character varying(5), timestamp(5)
- # with time zone, geometry(POLYGON), etc.
- attype = re.sub(r"\(.*\)", "", format_type)
+ columns = defaultdict(list)
+ for row_dict in rows:
+ # ensure that each table has an entry, even if it has no columns
+ if row_dict["name"] is None:
+ columns[
+ (schema, row_dict["table_name"])
+ ] = ReflectionDefaults.columns()
+ continue
+ table_cols = columns[(schema, row_dict["table_name"])]
- # strip '[]' from integer[], etc. and check if an array
- attype, is_array = _handle_array_type(attype)
+ format_type = row_dict["format_type"]
+ default = row_dict["default"]
+ name = row_dict["name"]
+ generated = row_dict["generated"]
+ identity = row_dict["identity_options"]
- # strip quotes from case sensitive enum or domain names
- enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+ # strip (*) from character varying(5), timestamp(5)
+ # with time zone, geometry(POLYGON), etc.
+ attype = attype_pattern.sub("", format_type)
- nullable = not notnull
+ # strip '[]' from integer[], etc. and check if an array
+ attype, is_array = _handle_array_type(attype)
- charlen = re.search(r"\(([\d,]+)\)", format_type)
- if charlen:
- charlen = charlen.group(1)
- args = re.search(r"\((.*)\)", format_type)
- if args and args.group(1):
- args = tuple(re.split(r"\s*,\s*", args.group(1)))
- else:
- args = ()
- kwargs = {}
+ # strip quotes from case sensitive enum or domain names
+ enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+
+ nullable = not row_dict["not_null"]
- if attype == "numeric":
+ charlen = charlen_pattern.search(format_type)
if charlen:
- prec, scale = charlen.split(",")
- args = (int(prec), int(scale))
+ charlen = charlen.group(1)
+ args = args_pattern.search(format_type)
+ if args and args.group(1):
+ args = tuple(args_split_pattern.split(args.group(1)))
else:
args = ()
- elif attype == "double precision":
- args = (53,)
- elif attype == "integer":
- args = ()
- elif attype in ("timestamp with time zone", "time with time zone"):
- kwargs["timezone"] = True
- if charlen:
- kwargs["precision"] = int(charlen)
- args = ()
- elif attype in (
- "timestamp without time zone",
- "time without time zone",
- "time",
- ):
- kwargs["timezone"] = False
- if charlen:
- kwargs["precision"] = int(charlen)
- args = ()
- elif attype == "bit varying":
- kwargs["varying"] = True
- if charlen:
+ kwargs = {}
+
+ if attype == "numeric":
+ if charlen:
+ prec, scale = charlen.split(",")
+ args = (int(prec), int(scale))
+ else:
+ args = ()
+ elif attype == "double precision":
+ args = (53,)
+ elif attype == "integer":
+ args = ()
+ elif attype in ("timestamp with time zone", "time with time zone"):
+ kwargs["timezone"] = True
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ args = ()
+ elif attype in (
+ "timestamp without time zone",
+ "time without time zone",
+ "time",
+ ):
+ kwargs["timezone"] = False
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ args = ()
+ elif attype == "bit varying":
+ kwargs["varying"] = True
+ if charlen:
+ args = (int(charlen),)
+ else:
+ args = ()
+ elif attype.startswith("interval"):
+ field_match = re.match(r"interval (.+)", attype, re.I)
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ if field_match:
+ kwargs["fields"] = field_match.group(1)
+ attype = "interval"
+ args = ()
+ elif charlen:
args = (int(charlen),)
+
+ while True:
+ # looping here to suit nested domains
+ if attype in self.ischema_names:
+ coltype = self.ischema_names[attype]
+ break
+ elif enum_or_domain_key in enums:
+ enum = enums[enum_or_domain_key]
+ coltype = ENUM
+ kwargs["name"] = enum["name"]
+ if not enum["visible"]:
+ kwargs["schema"] = enum["schema"]
+ args = tuple(enum["labels"])
+ break
+ elif enum_or_domain_key in domains:
+ domain = domains[enum_or_domain_key]
+ attype = domain["attype"]
+ attype, is_array = _handle_array_type(attype)
+ # strip quotes from case sensitive enum or domain names
+ enum_or_domain_key = tuple(
+ util.quoted_token_parser(attype)
+ )
+ # A table can't override a not null on the domain,
+ # but can override nullable
+ nullable = nullable and domain["nullable"]
+ if domain["default"] and not default:
+ # It can, however, override the default
+ # value, but can't set it to null.
+ default = domain["default"]
+ continue
+ else:
+ coltype = None
+ break
+
+ if coltype:
+ coltype = coltype(*args, **kwargs)
+ if is_array:
+ coltype = self.ischema_names["_array"](coltype)
else:
- args = ()
- elif attype.startswith("interval"):
- field_match = re.match(r"interval (.+)", attype, re.I)
- if charlen:
- kwargs["precision"] = int(charlen)
- if field_match:
- kwargs["fields"] = field_match.group(1)
- attype = "interval"
- args = ()
- elif charlen:
- args = (int(charlen),)
-
- while True:
- # looping here to suit nested domains
- if attype in self.ischema_names:
- coltype = self.ischema_names[attype]
- break
- elif enum_or_domain_key in enums:
- enum = enums[enum_or_domain_key]
- coltype = ENUM
- kwargs["name"] = enum["name"]
- if not enum["visible"]:
- kwargs["schema"] = enum["schema"]
- args = tuple(enum["labels"])
- break
- elif enum_or_domain_key in domains:
- domain = domains[enum_or_domain_key]
- attype = domain["attype"]
- attype, is_array = _handle_array_type(attype)
- # strip quotes from case sensitive enum or domain names
- enum_or_domain_key = tuple(util.quoted_token_parser(attype))
- # A table can't override a not null on the domain,
- # but can override nullable
- nullable = nullable and domain["nullable"]
- if domain["default"] and not default:
- # It can, however, override the default
- # value, but can't set it to null.
- default = domain["default"]
- continue
+ util.warn(
+ "Did not recognize type '%s' of column '%s'"
+ % (attype, name)
+ )
+ coltype = sqltypes.NULLTYPE
+
+ # A zero byte or blank string (depending on driver; the column
+ # is also absent on older PG versions) means not a generated
+ # column. Otherwise "s" = stored. (Other values might be added
+ # in the future.)
+ if generated not in (None, "", b"\x00"):
+ computed = dict(
+ sqltext=default, persisted=generated in ("s", b"s")
+ )
+ default = None
else:
- coltype = None
- break
+ computed = None
- if coltype:
- coltype = coltype(*args, **kwargs)
- if is_array:
- coltype = self.ischema_names["_array"](coltype)
- else:
- util.warn(
- "Did not recognize type '%s' of column '%s'" % (attype, name)
+ # adjust the default value
+ autoincrement = False
+ if default is not None:
+ match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
+ if match is not None:
+ if issubclass(coltype._type_affinity, sqltypes.Integer):
+ autoincrement = True
+ # the default is related to a Sequence
+ if "." not in match.group(2) and schema is not None:
+ # unconditionally quote the schema name. this could
+ # later be enhanced to obey quoting rules /
+ # "quote schema"
+ default = (
+ match.group(1)
+ + ('"%s"' % schema)
+ + "."
+ + match.group(2)
+ + match.group(3)
+ )
+
+ column_info = {
+ "name": name,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": autoincrement or identity is not None,
+ "comment": row_dict["comment"],
+ }
+ if computed is not None:
+ column_info["computed"] = computed
+ if identity is not None:
+ column_info["identity"] = identity
+
+ table_cols.append(column_info)
+
+ return columns
+
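The ``nextval()`` handling above can be exercised without a database; the sequence and schema names here are illustrative:

    import re

    # a typical reflected server default for a SERIAL column
    default = "nextval('accounts_id_seq'::regclass)"

    m = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
    print(m.group(2))  # accounts_id_seq

    # with an unqualified sequence name and an explicit schema, the
    # quoted schema is spliced back into the default
    schema = "myschema"
    print(m.group(1) + '"%s"' % schema + "." + m.group(2) + m.group(3))
    # nextval('"myschema".accounts_id_seq'::regclass)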
+ @lru_cache()
+ def _table_oids_query(self, schema, has_filter_names, scope, kind):
+ relkinds = self._kind_to_relkinds(kind)
+ oid_q = select(
+ pg_catalog.pg_class.c.oid, pg_catalog.pg_class.c.relname
+ ).where(self._pg_class_relkind_condition(relkinds))
+ oid_q = self._pg_class_filter_scope_schema(oid_q, schema, scope=scope)
+
+ if has_filter_names:
+ oid_q = oid_q.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
)
- coltype = sqltypes.NULLTYPE
-
- # If a zero byte or blank string depending on driver (is also absent
- # for older PG versions), then not a generated column. Otherwise, s =
- # stored. (Other values might be added in the future.)
- if generated not in (None, "", b"\x00"):
- computed = dict(
- sqltext=default, persisted=generated in ("s", b"s")
+ return oid_q
+
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("filter_names", InternalTraversal.dp_string_list),
+ ("kind", InternalTraversal.dp_plain_obj),
+ ("scope", InternalTraversal.dp_plain_obj),
+ )
+ def _get_table_oids(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ oid_q = self._table_oids_query(schema, has_filter_names, scope, kind)
+ result = connection.execute(oid_q, params)
+ return result.all()
+
+ @util.memoized_property
+ def _constraint_query(self):
+ con_sq = (
+ select(
+ pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_constraint.c.conname,
+ sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label(
+ "attnum"
+ ),
+ sql.func.generate_subscripts(
+ pg_catalog.pg_constraint.c.conkey, 1
+ ).label("ord"),
)
- default = None
- else:
- computed = None
-
- # adjust the default value
- autoincrement = False
- if default is not None:
- match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
- if match is not None:
- if issubclass(coltype._type_affinity, sqltypes.Integer):
- autoincrement = True
- # the default is related to a Sequence
- sch = schema
- if "." not in match.group(2) and sch is not None:
- # unconditionally quote the schema name. this could
- # later be enhanced to obey quoting rules /
- # "quote schema"
- default = (
- match.group(1)
- + ('"%s"' % sch)
- + "."
- + match.group(2)
- + match.group(3)
- )
+ .where(
+ pg_catalog.pg_constraint.c.contype == bindparam("contype"),
+ pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")),
+ )
+ .subquery("con")
+ )
- column_info = dict(
- name=name,
- type=coltype,
- nullable=nullable,
- default=default,
- autoincrement=autoincrement or identity is not None,
- comment=comment,
+ attr_sq = (
+ select(
+ con_sq.c.conrelid,
+ con_sq.c.conname,
+ pg_catalog.pg_attribute.c.attname,
+ )
+ .select_from(pg_catalog.pg_attribute)
+ .join(
+ con_sq,
+ sql.and_(
+ pg_catalog.pg_attribute.c.attnum == con_sq.c.attnum,
+ pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid,
+ ),
+ )
+ .order_by(con_sq.c.conname, con_sq.c.ord)
+ .subquery("attr")
)
- if computed is not None:
- column_info["computed"] = computed
- if identity is not None:
- column_info["identity"] = identity
- return column_info
- @reflection.cache
- def get_pk_constraint(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ return (
+ select(
+ attr_sq.c.conrelid,
+ sql.func.array_agg(attr_sq.c.attname).label("cols"),
+ attr_sq.c.conname,
+ )
+ .group_by(attr_sq.c.conrelid, attr_sq.c.conname)
+ .order_by(attr_sq.c.conrelid, attr_sq.c.conname)
)
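The ``unnest()`` / ``generate_subscripts()`` pairing above expands the ``conkey`` array into one row per constrained column while preserving column order; the same idiom in isolation (table and column names illustrative):

    from sqlalchemy import column, func, select, table

    pg_constraint = table("pg_constraint", column("conname"), column("conkey"))

    con = select(
        pg_constraint.c.conname,
        func.unnest(pg_constraint.c.conkey).label("attnum"),
        func.generate_subscripts(pg_constraint.c.conkey, 1).label("ord"),
    ).subquery("con")

    # ordering by "ord" keeps columns in constraint-definition order
    print(select(con.c.conname, con.c.attnum).order_by(con.c.conname, con.c.ord))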
- if self.server_version_info < (8, 4):
- PK_SQL = """
- SELECT a.attname
- FROM
- pg_class t
- join pg_index ix on t.oid = ix.indrelid
- join pg_attribute a
- on t.oid=a.attrelid AND %s
- WHERE
- t.oid = :table_oid and ix.indisprimary = 't'
- ORDER BY a.attnum
- """ % self._pg_index_any(
- "a.attnum", "ix.indkey"
+ def _reflect_constraint(
+ self, connection, contype, schema, filter_names, scope, kind, **kw
+ ):
+ table_oids = self._get_table_oids(
+ connection, schema, filter_names, scope, kind, **kw
+ )
+ batches = list(table_oids)
+
+ while batches:
+ batch = batches[0:3000]
+ batches[0:3000] = []
+
+ result = connection.execute(
+ self._constraint_query,
+ {"oids": [r[0] for r in batch], "contype": contype},
)
- else:
- # unnest() and generate_subscripts() both introduced in
- # version 8.4
- PK_SQL = """
- SELECT a.attname
- FROM pg_attribute a JOIN (
- SELECT unnest(ix.indkey) attnum,
- generate_subscripts(ix.indkey, 1) ord
- FROM pg_index ix
- WHERE ix.indrelid = :table_oid AND ix.indisprimary
- ) k ON a.attnum=k.attnum
- WHERE a.attrelid = :table_oid
- ORDER BY k.ord
- """
- t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode)
- c = connection.execute(t, dict(table_oid=table_oid))
- cols = [r[0] for r in c.fetchall()]
-
- PK_CONS_SQL = """
- SELECT conname
- FROM pg_catalog.pg_constraint r
- WHERE r.conrelid = :table_oid AND r.contype = 'p'
- ORDER BY 1
- """
- t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode)
- c = connection.execute(t, dict(table_oid=table_oid))
- name = c.scalar()
+ result_by_oid = defaultdict(list)
+ for oid, cols, constraint_name in result:
+ result_by_oid[oid].append((cols, constraint_name))
+
+ for oid, tablename in batch:
+ for_oid = result_by_oid.get(oid, ())
+ if for_oid:
+ for cols, constraint in for_oid:
+ yield tablename, cols, constraint
+ else:
+ yield tablename, None, None
+
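The slice-and-delete loop above is fixed-size chunking, used so that no single ``oids`` bind list grows unboundedly; a minimal standalone equivalent:

    def batched(items, size=3000):
        # consume the work list in fixed-size slices, mirroring
        # the batches[0:3000] pattern above
        pending = list(items)
        while pending:
            chunk = pending[:size]
            del pending[:size]
            yield chunk

    print(list(batched(range(7), size=3)))  # [[0, 1, 2], [3, 4, 5], [6]]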
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ data = self.get_multi_pk_constraint(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
+ )
+ return self._value_or_raise(data, table_name, schema)
+
+ def get_multi_pk_constraint(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ result = self._reflect_constraint(
+ connection, "p", schema, filter_names, scope, kind, **kw
+ )
- return {"constrained_columns": cols, "name": name}
+ # only a single pk can be present for each table. Return an entry
+ # even if a table has no primary key
+ default = ReflectionDefaults.pk_constraint
+ return (
+ (
+ (schema, table_name),
+ {
+ "constrained_columns": [] if cols is None else cols,
+ "name": pk_name,
+ }
+ if pk_name is not None
+ else default(),
+ )
+ for (table_name, cols, pk_name) in result
+ )
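A hedged sketch of consuming the batched result through the Inspector (URL illustrative; keys are ``(schema, table_name)`` with ``None`` for the default schema):

    from sqlalchemy import create_engine, inspect

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    insp = inspect(engine)

    for (schema, name), pk in insp.get_multi_pk_constraint().items():
        print(schema, name, pk["name"], pk["constrained_columns"])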
@reflection.cache
def get_foreign_keys(
@@ -3993,27 +3673,71 @@ class PGDialect(default.DefaultDialect):
postgresql_ignore_search_path=False,
**kw,
):
- preparer = self.identifier_preparer
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_foreign_keys(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ postgresql_ignore_search_path=postgresql_ignore_search_path,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- FK_SQL = """
- SELECT r.conname,
- pg_catalog.pg_get_constraintdef(r.oid, true) as condef,
- n.nspname as conschema
- FROM pg_catalog.pg_constraint r,
- pg_namespace n,
- pg_class c
-
- WHERE r.conrelid = :table AND
- r.contype = 'f' AND
- c.oid = confrelid AND
- n.oid = c.relnamespace
- ORDER BY 1
- """
- # https://www.postgresql.org/docs/9.0/static/sql-createtable.html
- FK_REGEX = re.compile(
+ @lru_cache()
+ def _foreign_key_query(self, schema, has_filter_names, scope, kind):
+ pg_class_ref = pg_catalog.pg_class.alias("cls_ref")
+ pg_namespace_ref = pg_catalog.pg_namespace.alias("nsp_ref")
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_constraint.c.conname,
+ sql.case(
+ (
+ pg_catalog.pg_constraint.c.oid.is_not(None),
+ pg_catalog.pg_get_constraintdef(
+ pg_catalog.pg_constraint.c.oid, True
+ ),
+ ),
+ else_=None,
+ ),
+ pg_namespace_ref.c.nspname,
+ )
+ .select_from(pg_catalog.pg_class)
+ .outerjoin(
+ pg_catalog.pg_constraint,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_constraint.c.contype == "f",
+ ),
+ )
+ .outerjoin(
+ pg_class_ref,
+ pg_class_ref.c.oid == pg_catalog.pg_constraint.c.confrelid,
+ )
+ .outerjoin(
+ pg_namespace_ref,
+ pg_class_ref.c.relnamespace == pg_namespace_ref.c.oid,
+ )
+ .order_by(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_constraint.c.conname,
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ )
+ query = self._pg_class_filter_scope_schema(query, schema, scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
+
+ @util.memoized_property
+ def _fk_regex_pattern(self):
+ # https://www.postgresql.org/docs/14.0/static/sql-createtable.html
+ return re.compile(
r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)"
r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?"
r"[\s]?(ON UPDATE "
@@ -4024,12 +3748,33 @@ class PGDialect(default.DefaultDialect):
r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?"
)
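The regex can be checked against a representative ``pg_get_constraintdef()`` string (the constraint itself is illustrative):

    import re

    condef = (
        "FOREIGN KEY (user_id) REFERENCES users(id) "
        "ON UPDATE CASCADE ON DELETE SET NULL"
    )

    m = re.search(
        r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)", condef
    )
    print(m.groups())  # ('user_id', None, 'users', 'id')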
- t = sql.text(FK_SQL).columns(
- conname=sqltypes.Unicode, condef=sqltypes.Unicode
- )
- c = connection.execute(t, dict(table=table_oid))
- fkeys = []
- for conname, condef, conschema in c.fetchall():
+ def get_multi_foreign_keys(
+ self,
+ connection,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ postgresql_ignore_search_path=False,
+ **kw,
+ ):
+ preparer = self.identifier_preparer
+
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._foreign_key_query(schema, has_filter_names, scope, kind)
+ result = connection.execute(query, params)
+
+ FK_REGEX = self._fk_regex_pattern
+
+ fkeys = defaultdict(list)
+ default = ReflectionDefaults.foreign_keys
+ for table_name, conname, condef, conschema in result:
+ # ensure that each table has an entry, even if it has
+ # no foreign keys
+ if conname is None:
+ fkeys[(schema, table_name)] = default()
+ continue
+ table_fks = fkeys[(schema, table_name)]
m = re.search(FK_REGEX, condef).groups()
(
@@ -4096,317 +3841,406 @@ class PGDialect(default.DefaultDialect):
"referred_columns": referred_columns,
"options": options,
}
- fkeys.append(fkey_d)
- return fkeys
-
- def _pg_index_any(self, col, compare_to):
- if self.server_version_info < (8, 1):
- # https://www.postgresql.org/message-id/10279.1124395722@sss.pgh.pa.us
- # "In CVS tip you could replace this with "attnum = ANY (indkey)".
- # Unfortunately, most array support doesn't work on int2vector in
- # pre-8.1 releases, so I think you're kinda stuck with the above
- # for now.
- # regards, tom lane"
- return "(%s)" % " OR ".join(
- "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10)
- )
- else:
- return "%s = ANY(%s)" % (col, compare_to)
+ table_fks.append(fkey_d)
+ return fkeys.items()
@reflection.cache
- def get_indexes(self, connection, table_name, schema, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ data = self.get_multi_indexes(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- # cast indkey as varchar since it's an int2vector,
- # returned as a list by some drivers such as pypostgresql
-
- if self.server_version_info < (8, 5):
- IDX_SQL = """
- SELECT
- i.relname as relname,
- ix.indisunique, ix.indexprs, ix.indpred,
- a.attname, a.attnum, NULL, ix.indkey%s,
- %s, %s, am.amname,
- NULL as indnkeyatts
- FROM
- pg_class t
- join pg_index ix on t.oid = ix.indrelid
- join pg_class i on i.oid = ix.indexrelid
- left outer join
- pg_attribute a
- on t.oid = a.attrelid and %s
- left outer join
- pg_am am
- on i.relam = am.oid
- WHERE
- t.relkind IN ('r', 'v', 'f', 'm')
- and t.oid = :table_oid
- and ix.indisprimary = 'f'
- ORDER BY
- t.relname,
- i.relname
- """ % (
- # version 8.3 here was based on observing the
- # cast does not work in PG 8.2.4, does work in 8.3.0.
- # nothing in PG changelogs regarding this.
- "::varchar" if self.server_version_info >= (8, 3) else "",
- "ix.indoption::varchar"
- if self.server_version_info >= (8, 3)
- else "NULL",
- "i.reloptions"
- if self.server_version_info >= (8, 2)
- else "NULL",
- self._pg_index_any("a.attnum", "ix.indkey"),
+ @util.memoized_property
+ def _index_query(self):
+ pg_class_index = pg_catalog.pg_class.alias("cls_idx")
+ # NOTE: repeating the oids clause improves query performance
+
+ # subquery to get the columns of each index: unnest() emits one
+ # row per column, generate_subscripts() records its position
+ idx_sq = (
+ select(
+ pg_catalog.pg_index.c.indexrelid,
+ pg_catalog.pg_index.c.indrelid,
+ sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"),
+ sql.func.generate_subscripts(
+ pg_catalog.pg_index.c.indkey, 1
+ ).label("ord"),
)
- else:
- IDX_SQL = """
- SELECT
- i.relname as relname,
- ix.indisunique, ix.indexprs,
- a.attname, a.attnum, c.conrelid, ix.indkey::varchar,
- ix.indoption::varchar, i.reloptions, am.amname,
- pg_get_expr(ix.indpred, ix.indrelid),
- %s as indnkeyatts
- FROM
- pg_class t
- join pg_index ix on t.oid = ix.indrelid
- join pg_class i on i.oid = ix.indexrelid
- left outer join
- pg_attribute a
- on t.oid = a.attrelid and a.attnum = ANY(ix.indkey)
- left outer join
- pg_constraint c
- on (ix.indrelid = c.conrelid and
- ix.indexrelid = c.conindid and
- c.contype in ('p', 'u', 'x'))
- left outer join
- pg_am am
- on i.relam = am.oid
- WHERE
- t.relkind IN ('r', 'v', 'f', 'm', 'p')
- and t.oid = :table_oid
- and ix.indisprimary = 'f'
- ORDER BY
- t.relname,
- i.relname
- """ % (
- "ix.indnkeyatts"
- if self.server_version_info >= (11, 0)
- else "NULL",
+ .where(
+ ~pg_catalog.pg_index.c.indisprimary,
+ pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")),
)
+ .subquery("idx")
+ )
- t = sql.text(IDX_SQL).columns(
- relname=sqltypes.Unicode, attname=sqltypes.Unicode
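+ # join the unnested attnums back to pg_attribute to recover the
+ # column names, ordering by the generated subscript so that the
+ # original index column order is preserved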
+ attr_sq = (
+ select(
+ idx_sq.c.indexrelid,
+ idx_sq.c.indrelid,
+ pg_catalog.pg_attribute.c.attname,
+ )
+ .select_from(pg_catalog.pg_attribute)
+ .join(
+ idx_sq,
+ sql.and_(
+ pg_catalog.pg_attribute.c.attnum == idx_sq.c.attnum,
+ pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid,
+ ),
+ )
+ .where(idx_sq.c.indrelid.in_(bindparam("oids")))
+ .order_by(idx_sq.c.indexrelid, idx_sq.c.ord)
+ .subquery("idx_attr")
)
- c = connection.execute(t, dict(table_oid=table_oid))
- indexes = defaultdict(lambda: defaultdict(dict))
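+ # aggregate the ordered column names into a single array per index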
+ cols_sq = (
+ select(
+ attr_sq.c.indexrelid,
+ attr_sq.c.indrelid,
+ sql.func.array_agg(attr_sq.c.attname).label("cols"),
+ )
+ .group_by(attr_sq.c.indexrelid, attr_sq.c.indrelid)
+ .subquery("idx_cols")
+ )
- sv_idx_name = None
- for row in c.fetchall():
- (
- idx_name,
- unique,
- expr,
- col,
- col_num,
- conrelid,
- idx_key,
- idx_option,
- options,
- amname,
- filter_definition,
- indnkeyatts,
- ) = row
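+ # pg_index.indnkeyatts is only available as of PG 11, the same
+ # version that introduced covering (INCLUDE) indexes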
+ if self.server_version_info >= (11, 0):
+ indnkeyatts = pg_catalog.pg_index.c.indnkeyatts
+ else:
+ indnkeyatts = sql.null().label("indnkeyatts")
- if expr:
- if idx_name != sv_idx_name:
- util.warn(
- "Skipped unsupported reflection of "
- "expression-based index %s" % idx_name
- )
- sv_idx_name = idx_name
- continue
+ query = (
+ select(
+ pg_catalog.pg_index.c.indrelid,
+ pg_class_index.c.relname.label("relname_index"),
+ pg_catalog.pg_index.c.indisunique,
+ pg_catalog.pg_index.c.indexprs,
+ pg_catalog.pg_constraint.c.conrelid.is_not(None).label(
+ "has_constraint"
+ ),
+ pg_catalog.pg_index.c.indoption,
+ pg_class_index.c.reloptions,
+ pg_catalog.pg_am.c.amname,
+ pg_catalog.pg_get_expr(
+ pg_catalog.pg_index.c.indpred,
+ pg_catalog.pg_index.c.indrelid,
+ ).label("filter_definition"),
+ indnkeyatts,
+ cols_sq.c.cols.label("index_cols"),
+ )
+ .select_from(pg_catalog.pg_index)
+ .where(
+ pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")),
+ ~pg_catalog.pg_index.c.indisprimary,
+ )
+ .join(
+ pg_class_index,
+ pg_catalog.pg_index.c.indexrelid == pg_class_index.c.oid,
+ )
+ .join(
+ pg_catalog.pg_am,
+ pg_class_index.c.relam == pg_catalog.pg_am.c.oid,
+ )
+ .outerjoin(
+ cols_sq,
+ pg_catalog.pg_index.c.indexrelid == cols_sq.c.indexrelid,
+ )
+ .outerjoin(
+ pg_catalog.pg_constraint,
+ sql.and_(
+ pg_catalog.pg_index.c.indrelid
+ == pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_index.c.indexrelid
+ == pg_catalog.pg_constraint.c.conindid,
+ pg_catalog.pg_constraint.c.contype
+ == sql.any_(_array.array(("p", "u", "x"))),
+ ),
+ )
+ .order_by(pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname)
+ )
+ return query
- has_idx = idx_name in indexes
- index = indexes[idx_name]
- if col is not None:
- index["cols"][col_num] = col
- if not has_idx:
- idx_keys = idx_key.split()
- # "The number of key columns in the index, not counting any
- # included columns, which are merely stored and do not
- # participate in the index semantics"
- if indnkeyatts and idx_keys[indnkeyatts:]:
- # this is a "covering index" which has INCLUDE columns
- # as well as regular index columns
- inc_keys = idx_keys[indnkeyatts:]
- idx_keys = idx_keys[:indnkeyatts]
- else:
- inc_keys = []
+ def get_multi_indexes(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
- index["key"] = [int(k.strip()) for k in idx_keys]
- index["inc"] = [int(k.strip()) for k in inc_keys]
+ table_oids = self._get_table_oids(
+ connection, schema, filter_names, scope, kind, **kw
+ )
- # (new in pg 8.3)
- # "pg_index.indoption" is list of ints, one per column/expr.
- # int acts as bitmask: 0x01=DESC, 0x02=NULLSFIRST
- sorting = {}
- for col_idx, col_flags in enumerate(
- (idx_option or "").split()
- ):
- col_flags = int(col_flags.strip())
- col_sorting = ()
- # try to set flags only if they differ from PG defaults...
- if col_flags & 0x01:
- col_sorting += ("desc",)
- if not (col_flags & 0x02):
- col_sorting += ("nulls_last",)
+ indexes = defaultdict(list)
+ default = ReflectionDefaults.indexes
+
+ batches = list(table_oids)
+
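+ # run the index query over batches of up to 3000 table oids at
+ # a time, keeping the size of the rendered IN clause bounded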
+ while batches:
+ batch = batches[0:3000]
+ batches[0:3000] = []
+
+ result = connection.execute(
+ self._index_query, {"oids": [r[0] for r in batch]}
+ ).mappings()
+
+ result_by_oid = defaultdict(list)
+ for row_dict in result:
+ result_by_oid[row_dict["indrelid"]].append(row_dict)
+
+ for oid, table_name in batch:
+ if oid not in result_by_oid:
+ # ensure that each table has an entry, even if it has
+ # no indexes
+ indexes[(schema, table_name)] = default()
+ continue
+
+ for row in result_by_oid[oid]:
+ index_name = row["relname_index"]
+
+ table_indexes = indexes[(schema, table_name)]
+
+ if row["indexprs"]:
+ tn = (
+ table_name
+ if schema is None
+ else f"{schema}.{table_name}"
+ )
+ util.warn(
+ "Skipped unsupported reflection of "
+ f"expression-based index {index_name} of "
+ f"table {tn}"
+ )
+ continue
+
+ all_cols = row["index_cols"]
+ indnkeyatts = row["indnkeyatts"]
+ # "The number of key columns in the index, not counting any
+ # included columns, which are merely stored and do not
+ # participate in the index semantics"
+ if indnkeyatts and all_cols[indnkeyatts:]:
+ # this is a "covering index" which has INCLUDE columns
+ # as well as regular index columns
+ inc_cols = all_cols[indnkeyatts:]
+ idx_cols = all_cols[:indnkeyatts]
else:
- if col_flags & 0x02:
- col_sorting += ("nulls_first",)
- if col_sorting:
- sorting[col_idx] = col_sorting
- if sorting:
- index["sorting"] = sorting
-
- index["unique"] = unique
- if conrelid is not None:
- index["duplicates_constraint"] = idx_name
- if options:
- index["options"] = dict(
- [option.split("=") for option in options]
- )
-
- # it *might* be nice to include that this is 'btree' in the
- # reflection info. But we don't want an Index object
- # to have a ``postgresql_using`` in it that is just the
- # default, so for the moment leaving this out.
- if amname and amname != "btree":
- index["amname"] = amname
-
- if filter_definition:
- index["postgresql_where"] = filter_definition
+ idx_cols = all_cols
+ inc_cols = []
+
+ index = {
+ "name": index_name,
+ "unique": row["indisunique"],
+ "column_names": idx_cols,
+ }
+
+ sorting = {}
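+ # pg_index.indoption is a list of ints, one per column/expr;
+ # each int acts as a bitmask: 0x01=DESC, 0x02=NULLSFIRST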
+ for col_index, col_flags in enumerate(row["indoption"]):
+ col_sorting = ()
+ # try to set flags only if they differ from PG
+ # defaults...
+ if col_flags & 0x01:
+ col_sorting += ("desc",)
+ if not (col_flags & 0x02):
+ col_sorting += ("nulls_last",)
+ else:
+ if col_flags & 0x02:
+ col_sorting += ("nulls_first",)
+ if col_sorting:
+ sorting[idx_cols[col_index]] = col_sorting
+ if sorting:
+ index["column_sorting"] = sorting
+ if row["has_constraint"]:
+ index["duplicates_constraint"] = index_name
+
+ dialect_options = {}
+ if row["reloptions"]:
+ dialect_options["postgresql_with"] = dict(
+ [option.split("=") for option in row["reloptions"]]
+ )
+ # it *might* be nice to include that this is 'btree' in the
+ # reflection info. But we don't want an Index object
+ # to have a ``postgresql_using`` in it that is just the
+ # default, so for the moment leaving this out.
+ amname = row["amname"]
+ if amname != "btree":
+ dialect_options["postgresql_using"] = row["amname"]
+ if row["filter_definition"]:
+ dialect_options["postgresql_where"] = row[
+ "filter_definition"
+ ]
+ if self.server_version_info >= (11, 0):
+ # NOTE: this is legacy, this is part of
+ # dialect_options now as of #7382
+ index["include_columns"] = inc_cols
+ dialect_options["postgresql_include"] = inc_cols
+ if dialect_options:
+ index["dialect_options"] = dialect_options
- result = []
- for name, idx in indexes.items():
- entry = {
- "name": name,
- "unique": idx["unique"],
- "column_names": [idx["cols"][i] for i in idx["key"]],
- }
- if self.server_version_info >= (11, 0):
- # NOTE: this is legacy, this is part of dialect_options now
- # as of #7382
- entry["include_columns"] = [idx["cols"][i] for i in idx["inc"]]
- if "duplicates_constraint" in idx:
- entry["duplicates_constraint"] = idx["duplicates_constraint"]
- if "sorting" in idx:
- entry["column_sorting"] = dict(
- (idx["cols"][idx["key"][i]], value)
- for i, value in idx["sorting"].items()
- )
- if "include_columns" in entry:
- entry.setdefault("dialect_options", {})[
- "postgresql_include"
- ] = entry["include_columns"]
- if "options" in idx:
- entry.setdefault("dialect_options", {})[
- "postgresql_with"
- ] = idx["options"]
- if "amname" in idx:
- entry.setdefault("dialect_options", {})[
- "postgresql_using"
- ] = idx["amname"]
- if "postgresql_where" in idx:
- entry.setdefault("dialect_options", {})[
- "postgresql_where"
- ] = idx["postgresql_where"]
- result.append(entry)
- return result
+ table_indexes.append(index)
+ return indexes.items()
@reflection.cache
def get_unique_constraints(
self, connection, table_name, schema=None, **kw
):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_unique_constraints(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- UNIQUE_SQL = """
- SELECT
- cons.conname as name,
- cons.conkey as key,
- a.attnum as col_num,
- a.attname as col_name
- FROM
- pg_catalog.pg_constraint cons
- join pg_attribute a
- on cons.conrelid = a.attrelid AND
- a.attnum = ANY(cons.conkey)
- WHERE
- cons.conrelid = :table_oid AND
- cons.contype = 'u'
- """
-
- t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode)
- c = connection.execute(t, dict(table_oid=table_oid))
+ def get_multi_unique_constraints(
+ self,
+ connection,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ **kw,
+ ):
+ result = self._reflect_constraint(
+ connection, "u", schema, filter_names, scope, kind, **kw
+ )
- uniques = defaultdict(lambda: defaultdict(dict))
- for row in c.fetchall():
- uc = uniques[row.name]
- uc["key"] = row.key
- uc["cols"][row.col_num] = row.col_name
+ # each table can have multiple unique constraints
+ uniques = defaultdict(list)
+ default = ReflectionDefaults.unique_constraints
+ for (table_name, cols, con_name) in result:
+ # ensure a list is created for each table; leave it empty if
+ # the table has no unique constraint
+ if con_name is None:
+ uniques[(schema, table_name)] = default()
+ continue
- return [
- {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]}
- for name, uc in uniques.items()
- ]
+ uniques[(schema, table_name)].append(
+ {
+ "column_names": cols,
+ "name": con_name,
+ }
+ )
+ return uniques.items()
@reflection.cache
def get_table_comment(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_table_comment(
+ connection,
+ schema,
+ [table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- COMMENT_SQL = """
- SELECT
- pgd.description as table_comment
- FROM
- pg_catalog.pg_description pgd
- WHERE
- pgd.objsubid = 0 AND
- pgd.objoid = :table_oid
- """
+ @lru_cache()
+ def _comment_query(self, schema, has_filter_names, scope, kind):
+ relkinds = self._kind_to_relkinds(kind)
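+ # outer join so that tables without a comment still return a
+ # row, with a NULL description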
+ query = (
+ select(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_description.c.description,
+ )
+ .select_from(pg_catalog.pg_class)
+ .outerjoin(
+ pg_catalog.pg_description,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_description.c.objoid,
+ pg_catalog.pg_description.c.objsubid == 0,
+ ),
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ )
+ query = self._pg_class_filter_scope_schema(query, schema, scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
- c = connection.execute(
- sql.text(COMMENT_SQL), dict(table_oid=table_oid)
+ def get_multi_table_comment(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._comment_query(schema, has_filter_names, scope, kind)
+ result = connection.execute(query, params)
+
+ default = ReflectionDefaults.table_comment
+ return (
+ (
+ (schema, table),
+ {"text": comment} if comment is not None else default(),
+ )
+ for table, comment in result
)
- return {"text": c.scalar()}
@reflection.cache
def get_check_constraints(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_check_constraints(
+ connection,
+ schema,
+ [table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- CHECK_SQL = """
- SELECT
- cons.conname as name,
- pg_get_constraintdef(cons.oid) as src
- FROM
- pg_catalog.pg_constraint cons
- WHERE
- cons.conrelid = :table_oid AND
- cons.contype = 'c'
- """
-
- c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid))
+ @lru_cache()
+ def _check_constraint_query(self, schema, has_filter_names, scope, kind):
+ relkinds = self._kind_to_relkinds(kind)
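+ # the outer join yields NULL conname for tables without check
+ # constraints; the CASE guards pg_get_constraintdef() against
+ # those NULL oids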
+ query = (
+ select(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_constraint.c.conname,
+ sql.case(
+ (
+ pg_catalog.pg_constraint.c.oid.is_not(None),
+ pg_catalog.pg_get_constraintdef(
+ pg_catalog.pg_constraint.c.oid
+ ),
+ ),
+ else_=None,
+ ),
+ )
+ .select_from(pg_catalog.pg_class)
+ .outerjoin(
+ pg_catalog.pg_constraint,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_constraint.c.contype == "c",
+ ),
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ )
+ query = self._pg_class_filter_scope_schema(query, schema, scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
- ret = []
- for name, src in c:
+ def get_multi_check_constraints(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._check_constraint_query(
+ schema, has_filter_names, scope, kind
+ )
+ result = connection.execute(query, params)
+
+ check_constraints = defaultdict(list)
+ default = ReflectionDefaults.check_constraints
+ for table_name, check_name, src in result:
+ # only two cases for check_name and src: both null or both defined
+ if check_name is None and src is None:
+ check_constraints[(schema, table_name)] = default()
+ continue
# samples:
# "CHECK (((a > 1) AND (a < 5)))"
# "CHECK (((a = 1) OR ((a > 2) AND (a < 5))))"
@@ -4424,84 +4258,118 @@ class PGDialect(default.DefaultDialect):
sqltext = re.compile(
r"^[\s\n]*\((.+)\)[\s\n]*$", flags=re.DOTALL
).sub(r"\1", m.group(1))
- entry = {"name": name, "sqltext": sqltext}
+ entry = {"name": check_name, "sqltext": sqltext}
if m and m.group(2):
entry["dialect_options"] = {"not_valid": True}
- ret.append(entry)
- return ret
-
- def _load_enums(self, connection, schema=None):
- schema = schema or self.default_schema_name
- if not self.supports_native_enum:
- return {}
-
- # Load data types for enums:
- SQL_ENUMS = """
- SELECT t.typname as "name",
- -- no enum defaults in 8.4 at least
- -- t.typdefault as "default",
- pg_catalog.pg_type_is_visible(t.oid) as "visible",
- n.nspname as "schema",
- e.enumlabel as "label"
- FROM pg_catalog.pg_type t
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
- LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid
- WHERE t.typtype = 'e'
- """
+ check_constraints[(schema, table_name)].append(entry)
+ return check_constraints.items()
- if schema != "*":
- SQL_ENUMS += "AND n.nspname = :schema "
+ @lru_cache()
+ def _enum_query(self, schema):
+ lbl_sq = (
+ select(
+ pg_catalog.pg_enum.c.enumtypid, pg_catalog.pg_enum.c.enumlabel
+ )
+ .order_by(
+ pg_catalog.pg_enum.c.enumtypid,
+ pg_catalog.pg_enum.c.enumsortorder,
+ )
+ .subquery("lbl")
+ )
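+ # ordering by enumsortorder above means the array_agg below
+ # returns the labels in their declared order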
- # e.oid gives us label order within an enum
- SQL_ENUMS += 'ORDER BY "schema", "name", e.oid'
+ lbl_agg_sq = (
+ select(
+ lbl_sq.c.enumtypid,
+ sql.func.array_agg(lbl_sq.c.enumlabel).label("labels"),
+ )
+ .group_by(lbl_sq.c.enumtypid)
+ .subquery("lbl_agg")
+ )
- s = sql.text(SQL_ENUMS).columns(
- attname=sqltypes.Unicode, label=sqltypes.Unicode
+ query = (
+ select(
+ pg_catalog.pg_type.c.typname.label("name"),
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label(
+ "visible"
+ ),
+ pg_catalog.pg_namespace.c.nspname.label("schema"),
+ lbl_agg_sq.c.labels.label("labels"),
+ )
+ .join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid
+ == pg_catalog.pg_type.c.typnamespace,
+ )
+ .outerjoin(
+ lbl_agg_sq, pg_catalog.pg_type.c.oid == lbl_agg_sq.c.enumtypid
+ )
+ .where(pg_catalog.pg_type.c.typtype == "e")
+ .order_by(
+ pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname
+ )
)
- if schema != "*":
- s = s.bindparams(schema=schema)
+ if schema is None:
+ query = query.where(
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
+ # ignore pg_catalog schema
+ pg_catalog.pg_namespace.c.nspname != "pg_catalog",
+ )
+ elif schema != "*":
+ query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+ return query
+
+ @reflection.cache
+ def _load_enums(self, connection, schema=None, **kw):
+ if not self.supports_native_enum:
+ return []
- c = connection.execute(s)
+ result = connection.execute(self._enum_query(schema))
enums = []
- enum_by_name = {}
- for enum in c.fetchall():
- key = (enum.schema, enum.name)
- if key in enum_by_name:
- enum_by_name[key]["labels"].append(enum.label)
- else:
- enum_by_name[key] = enum_rec = {
- "name": enum.name,
- "schema": enum.schema,
- "visible": enum.visible,
- "labels": [],
+ for name, visible, schema, labels in result:
+ enums.append(
+ {
+ "name": name,
+ "schema": schema,
+ "visible": visible,
+ "labels": [] if labels is None else labels,
}
- if enum.label is not None:
- enum_rec["labels"].append(enum.label)
- enums.append(enum_rec)
+ )
return enums
- def _load_domains(self, connection):
- # Load data types for domains:
- SQL_DOMAINS = """
- SELECT t.typname as "name",
- pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype",
- not t.typnotnull as "nullable",
- t.typdefault as "default",
- pg_catalog.pg_type_is_visible(t.oid) as "visible",
- n.nspname as "schema"
- FROM pg_catalog.pg_type t
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
- WHERE t.typtype = 'd'
- """
+ @util.memoized_property
+ def _domain_query(self):
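+ # format_type() renders the base type with its modifier, e.g.
+ # "character varying(30)"; the modifier is stripped again by
+ # _load_domains() below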
+ return (
+ select(
+ pg_catalog.pg_type.c.typname.label("name"),
+ pg_catalog.format_type(
+ pg_catalog.pg_type.c.typbasetype,
+ pg_catalog.pg_type.c.typtypmod,
+ ).label("attype"),
+ (~pg_catalog.pg_type.c.typnotnull).label("nullable"),
+ pg_catalog.pg_type.c.typdefault.label("default"),
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label(
+ "visible"
+ ),
+ pg_catalog.pg_namespace.c.nspname.label("schema"),
+ )
+ .join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid
+ == pg_catalog.pg_type.c.typnamespace,
+ )
+ .where(pg_catalog.pg_type.c.typtype == "d")
+ )
- s = sql.text(SQL_DOMAINS)
- c = connection.execution_options(future_result=True).execute(s)
+ @reflection.cache
+ def _load_domains(self, connection, **kw):
+ # Load data types for domains:
+ result = connection.execute(self._domain_query)
domains = {}
- for domain in c.mappings():
+ for domain in result.mappings():
domain = domain
# strip (30) from character varying(30)
attype = re.search(r"([^\(]+)", domain["attype"]).group(1)