diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/asyncmy.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/__init__.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/_psycopg_common.py | 188 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/asyncpg.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg.py | 641 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg2.py | 163 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/create.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/cursor.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/interfaces.py | 19 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/url.py | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/engine.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 42 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/config.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/plugin/plugin_base.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_results.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_types.py | 8 |
19 files changed, 948 insertions, 179 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index b59571460..8092e99ff 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -175,7 +175,7 @@ class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor): class AsyncAdapt_asyncmy_connection(AdaptedConnection): await_ = staticmethod(await_only) - __slots__ = ("dbapi", "_connection", "_execute_mutex") + __slots__ = ("dbapi", "_execute_mutex") def __init__(self, dbapi, connection): self.dbapi = dbapi diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 08b05dc74..b0af5395e 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -4,9 +4,12 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from types import ModuleType + from . import asyncpg # noqa from . import base from . import pg8000 # noqa +from . import psycopg # noqa from . import psycopg2 # noqa from . import psycopg2cffi # noqa from .array import All @@ -58,6 +61,11 @@ from .ranges import TSRANGE from .ranges import TSTZRANGE from ...util import compat +# Alias psycopg also as psycopg_asnyc +psycopg_async = type( + "psycopg_asnyc", (ModuleType,), {"dialect": psycopg.dialect_async} +) + base.dialect = dialect = psycopg2.dialect diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py new file mode 100644 index 000000000..d82d5f009 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -0,0 +1,188 @@ +import decimal + +from .array import ARRAY as PGARRAY +from .base import _DECIMAL_TYPES +from .base import _FLOAT_TYPES +from .base import _INT_TYPES +from .base import PGDialect +from .base import PGExecutionContext +from .base import UUID +from .hstore import HSTORE +from ... import exc +from ... import processors +from ... import types as sqltypes +from ... import util + +_server_side_id = util.counter() + + +class _PsycopgNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # psycopg returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + else: + if coltype in _FLOAT_TYPES: + # psycopg returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + + +class _PsycopgHStore(HSTORE): + def bind_processor(self, dialect): + if dialect._has_native_hstore: + return None + else: + return super(_PsycopgHStore, self).bind_processor(dialect) + + def result_processor(self, dialect, coltype): + if dialect._has_native_hstore: + return None + else: + return super(_PsycopgHStore, self).result_processor( + dialect, coltype + ) + + +class _PsycopgUUID(UUID): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if not self.as_uuid and dialect.use_native_uuid: + + def process(value): + if value is not None: + value = str(value) + return value + + return process + + +class _PsycopgARRAY(PGARRAY): + render_bind_cast = True + + +class _PGExecutionContext_common_psycopg(PGExecutionContext): + def create_server_side_cursor(self): + # use server-side cursors: + # psycopg + # https://www.psycopg.org/psycopg3/docs/advanced/cursors.html#server-side-cursors + # psycopg2 + # https://www.psycopg.org/docs/usage.html#server-side-cursors + ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) + return self._dbapi_connection.cursor(ident) + + +class _PGDialect_common_psycopg(PGDialect): + supports_statement_cache = True + supports_server_side_cursors = True + + default_paramstyle = "pyformat" + + _has_native_hstore = True + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric: _PsycopgNumeric, + HSTORE: _PsycopgHStore, + UUID: _PsycopgUUID, + sqltypes.ARRAY: _PsycopgARRAY, + }, + ) + + def __init__( + self, + client_encoding=None, + use_native_hstore=True, + use_native_uuid=True, + **kwargs + ): + PGDialect.__init__(self, **kwargs) + if not use_native_hstore: + self._has_native_hstore = False + self.use_native_hstore = use_native_hstore + self.use_native_uuid = use_native_uuid + self.client_encoding = client_encoding + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user", database="dbname") + + is_multihost = False + if "host" in url.query: + is_multihost = isinstance(url.query["host"], (list, tuple)) + + if opts: + if "port" in opts: + opts["port"] = int(opts["port"]) + opts.update(url.query) + if is_multihost: + opts["host"] = ",".join(url.query["host"]) + # send individual dbname, user, password, host, port + # parameters to psycopg2.connect() + return ([], opts) + elif url.query: + # any other connection arguments, pass directly + opts.update(url.query) + if is_multihost: + opts["host"] = ",".join(url.query["host"]) + return ([], opts) + else: + # no connection arguments whatsoever; psycopg2.connect() + # requires that "dsn" be present as a blank string. + return ([""], opts) + + def get_isolation_level_values(self, dbapi_conn): + return ( + "AUTOCOMMIT", + "READ COMMITTED", + "READ UNCOMMITTED", + "REPEATABLE READ", + "SERIALIZABLE", + ) + + def set_deferrable(self, connection, value): + connection.deferrable = value + + def get_deferrable(self, connection): + return connection.deferrable + + def _do_autocommit(self, connection, value): + connection.autocommit = value + + def do_ping(self, dbapi_connection): + cursor = None + try: + self._do_autocommit(dbapi_connection, True) + cursor = dbapi_connection.cursor() + try: + cursor.execute(self._dialect_specific_select_one) + finally: + cursor.close() + if not dbapi_connection.closed: + self._do_autocommit(dbapi_connection, False) + except self.dbapi.Error as err: + if self.is_disconnect(err, dbapi_connection, cursor): + return False + else: + raise + else: + return True diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 5ef9df800..1fdb46b6f 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -553,7 +553,6 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): class AsyncAdapt_asyncpg_connection(AdaptedConnection): __slots__ = ( "dbapi", - "_connection", "isolation_level", "_isolation_setting", "readonly", diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 008398865..d1d881dc3 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1701,7 +1701,7 @@ class UUID(sqltypes.TypeEngine): or as Python uuid objects. The UUID type is currently known to work within the prominent DBAPI - drivers supported by SQLAlchemy including psycopg2, pg8000 and + drivers supported by SQLAlchemy including psycopg, psycopg2, pg8000 and asyncpg. Support for other DBAPI drivers may be incomplete or non-present. """ @@ -1992,6 +1992,12 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): self.connection.execute(DropEnumType(enum)) + def get_dbapi_type(self, dbapi): + """dont return dbapi.STRING for ENUM in PostgreSQL, since that's + a different type""" + + return None + def _check_for_name_in_memos(self, checkfirst, kw): """Look in the 'ddl runner' for 'memos', then note our name in that collection. diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py new file mode 100644 index 000000000..c2017c975 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -0,0 +1,641 @@ +# postgresql/psycopg2.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +r""" +.. dialect:: postgresql+psycopg + :name: psycopg (a.k.a. psycopg 3) + :dbapi: psycopg + :connectstring: postgresql+psycopg://user:password@host:port/dbname[?key=value&key=value...] + :url: https://pypi.org/project/psycopg/ + +``psycopg`` is the package and module name for version 3 of the ``psycopg`` +database driver, formerly known as ``psycopg2``. This driver is different +enough from its ``psycopg2`` predecessor that SQLAlchemy supports it +via a totally separate dialect; support for ``psycopg2`` is expected to remain +for as long as that package continues to function for modern Python versions, +and also remains the default dialect for the ``postgresql://`` dialect +series. + +The SQLAlchemy ``psycopg`` dialect provides both a sync and an async +implementation under the same dialect name. The proper version is +selected depending on how the engine is created: + +* calling :func:`_sa.create_engine` with ``postgresql+psycopg://...`` will + automatically select the sync version, e.g.:: + + from sqlalchemy import create_engine + sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test") + +* calling :func:`_asyncio.create_async_engine` with + ``postgresql+psycopg://...`` will automatically select the async version, + e.g.:: + + from sqlalchemy.ext.asyncio import create_async_engine + asyncio_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test") + +The asyncio version of the dialect may also be specified explicitly using the +``psycopg_async`` suffix, as:: + + from sqlalchemy.ext.asyncio import create_async_engine + asyncio_engine = create_async_engine("postgresql+psycopg_async://scott:tiger@localhost/test") + +The ``psycopg`` dialect has the same API features as that of ``psycopg2``, +with the exeption of the "fast executemany" helpers. The "fast executemany" +helpers are expected to be generalized and ported to ``psycopg`` before the final +release of SQLAlchemy 2.0, however. + + +.. seealso:: + + :ref:`postgresql_psycopg2` - The SQLAlchemy ``psycopg`` + dialect shares most of its behavior with the ``psycopg2`` dialect. + Further documentation is available there. + +""" # noqa +import logging +import re + +from ._psycopg_common import _PGDialect_common_psycopg +from ._psycopg_common import _PGExecutionContext_common_psycopg +from ._psycopg_common import _PsycopgUUID +from .base import INTERVAL +from .base import PGCompiler +from .base import PGIdentifierPreparer +from .base import UUID +from .json import JSON +from .json import JSONB +from .json import JSONPathType +from ... import pool +from ... import types as sqltypes +from ... import util +from ...engine import AdaptedConnection +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + +logger = logging.getLogger("sqlalchemy.dialects.postgresql") + + +class _PGString(sqltypes.String): + render_bind_cast = True + + +class _PGJSON(JSON): + render_bind_cast = True + + def bind_processor(self, dialect): + return self._make_bind_processor(None, dialect._psycopg_Json) + + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONB(JSONB): + render_bind_cast = True + + def bind_processor(self, dialect): + return self._make_bind_processor(None, dialect._psycopg_Jsonb) + + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): + __visit_name__ = "json_int_index" + + render_bind_cast = True + + +class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): + __visit_name__ = "json_str_index" + + render_bind_cast = True + + +class _PGJSONPathType(JSONPathType): + pass + + +class _PGUUID(_PsycopgUUID): + render_bind_cast = True + + +class _PGInterval(INTERVAL): + render_bind_cast = True + + +class _PGTimeStamp(sqltypes.DateTime): + render_bind_cast = True + + +class _PGDate(sqltypes.Date): + render_bind_cast = True + + +class _PGTime(sqltypes.Time): + render_bind_cast = True + + +class _PGInteger(sqltypes.Integer): + render_bind_cast = True + + +class _PGSmallInteger(sqltypes.SmallInteger): + render_bind_cast = True + + +class _PGNullType(sqltypes.NullType): + render_bind_cast = True + + +class _PGBigInteger(sqltypes.BigInteger): + render_bind_cast = True + + +class _PGBoolean(sqltypes.Boolean): + render_bind_cast = True + + +class PGExecutionContext_psycopg(_PGExecutionContext_common_psycopg): + pass + + +class PGCompiler_psycopg(PGCompiler): + pass + + +class PGIdentifierPreparer_psycopg(PGIdentifierPreparer): + pass + + +def _log_notices(diagnostic): + logger.info("%s: %s", diagnostic.severity, diagnostic.message_primary) + + +class PGDialect_psycopg(_PGDialect_common_psycopg): + driver = "psycopg" + + supports_statement_cache = True + supports_server_side_cursors = True + default_paramstyle = "pyformat" + supports_sane_multi_rowcount = True + + execution_ctx_cls = PGExecutionContext_psycopg + statement_compiler = PGCompiler_psycopg + preparer = PGIdentifierPreparer_psycopg + psycopg_version = (0, 0) + + _has_native_hstore = True + + colspecs = util.update_copy( + _PGDialect_common_psycopg.colspecs, + { + sqltypes.String: _PGString, + JSON: _PGJSON, + sqltypes.JSON: _PGJSON, + JSONB: _PGJSONB, + sqltypes.JSON.JSONPathType: _PGJSONPathType, + sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType, + sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType, + UUID: _PGUUID, + sqltypes.Interval: _PGInterval, + INTERVAL: _PGInterval, + sqltypes.Date: _PGDate, + sqltypes.DateTime: _PGTimeStamp, + sqltypes.Time: _PGTime, + sqltypes.Integer: _PGInteger, + sqltypes.SmallInteger: _PGSmallInteger, + sqltypes.BigInteger: _PGBigInteger, + }, + ) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + if self.dbapi: + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) + if m: + self.psycopg_version = tuple( + int(x) for x in m.group(1, 2, 3) if x is not None + ) + + if self.psycopg_version < (3, 0, 2): + raise ImportError( + "psycopg version 3.0.2 or higher is required." + ) + + from psycopg.adapt import AdaptersMap + + self._psycopg_adapters_map = adapters_map = AdaptersMap( + self.dbapi.adapters + ) + + if self._json_deserializer: + from psycopg.types.json import set_json_loads + + set_json_loads(self._json_deserializer, adapters_map) + + if self._json_serializer: + from psycopg.types.json import set_json_dumps + + set_json_dumps(self._json_serializer, adapters_map) + + def create_connect_args(self, url): + # see https://github.com/psycopg/psycopg/issues/83 + cargs, cparams = super().create_connect_args(url) + + cparams["context"] = self._psycopg_adapters_map + if self.client_encoding is not None: + cparams["client_encoding"] = self.client_encoding + return cargs, cparams + + def _type_info_fetch(self, connection, name): + from psycopg.types import TypeInfo + + return TypeInfo.fetch(connection.connection, name) + + def initialize(self, connection): + super().initialize(connection) + + # PGDialect.initialize() checks server version for <= 8.2 and sets + # this flag to False if so + if not self.full_returning: + self.insert_executemany_returning = False + + # HSTORE can't be registered until we have a connection so that + # we can look up its OID, so we set up this adapter in + # initialize() + if self.use_native_hstore: + info = self._type_info_fetch(connection, "hstore") + self._has_native_hstore = info is not None + if self._has_native_hstore: + from psycopg.types.hstore import register_hstore + + # register the adapter for connections made subsequent to + # this one + register_hstore(info, self._psycopg_adapters_map) + + # register the adapter for this connection + register_hstore(info, connection.connection) + + @classmethod + def dbapi(cls): + import psycopg + + return psycopg + + @classmethod + def get_async_dialect_cls(cls, url): + return PGDialectAsync_psycopg + + @util.memoized_property + def _isolation_lookup(self): + return { + "READ COMMITTED": self.dbapi.IsolationLevel.READ_COMMITTED, + "READ UNCOMMITTED": self.dbapi.IsolationLevel.READ_UNCOMMITTED, + "REPEATABLE READ": self.dbapi.IsolationLevel.REPEATABLE_READ, + "SERIALIZABLE": self.dbapi.IsolationLevel.SERIALIZABLE, + } + + @util.memoized_property + def _psycopg_Json(self): + from psycopg.types import json + + return json.Json + + @util.memoized_property + def _psycopg_Jsonb(self): + from psycopg.types import json + + return json.Jsonb + + @util.memoized_property + def _psycopg_TransactionStatus(self): + from psycopg.pq import TransactionStatus + + return TransactionStatus + + def _do_isolation_level(self, connection, autocommit, isolation_level): + connection.autocommit = autocommit + connection.isolation_level = isolation_level + + def get_isolation_level(self, dbapi_connection): + if hasattr(dbapi_connection, "dbapi_connection"): + dbapi_connection = dbapi_connection.dbapi_connection + + status_before = dbapi_connection.info.transaction_status + value = super().get_isolation_level(dbapi_connection) + + # don't rely on psycopg providing enum symbols, compare with + # eq/ne + if status_before == self._psycopg_TransactionStatus.IDLE: + dbapi_connection.rollback() + return value + + def set_isolation_level(self, connection, level): + connection = getattr(connection, "dbapi_connection", connection) + if level == "AUTOCOMMIT": + self._do_isolation_level( + connection, autocommit=True, isolation_level=None + ) + else: + self._do_isolation_level( + connection, + autocommit=False, + isolation_level=self._isolation_lookup[level], + ) + + def set_readonly(self, connection, value): + connection.read_only = value + + def get_readonly(self, connection): + return connection.read_only + + def on_connect(self): + def notices(conn): + conn.add_notice_handler(_log_notices) + + fns = [notices] + + if self.isolation_level is not None: + + def on_connect(conn): + self.set_isolation_level(conn, self.isolation_level) + + fns.append(on_connect) + + # fns always has the notices function + def on_connect(conn): + for fn in fns: + fn(conn) + + return on_connect + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error) and connection is not None: + if connection.closed or connection.broken: + return True + return False + + def _do_prepared_twophase(self, connection, command, recover=False): + dbapi_conn = connection.connection.dbapi_connection + if ( + recover + # don't rely on psycopg providing enum symbols, compare with + # eq/ne + or dbapi_conn.info.transaction_status + != self._psycopg_TransactionStatus.IDLE + ): + dbapi_conn.rollback() + before = dbapi_conn.autocommit + try: + self._do_autocommit(dbapi_conn, True) + dbapi_conn.execute(command) + finally: + self._do_autocommit(dbapi_conn, before) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + self._do_prepared_twophase( + connection, f"ROLLBACK PREPARED '{xid}'", recover=recover + ) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + self._do_prepared_twophase( + connection, f"COMMIT PREPARED '{xid}'", recover=recover + ) + else: + self.do_commit(connection.connection) + + +class AsyncAdapt_psycopg_cursor: + __slots__ = ("_cursor", "await_", "_rows") + + _psycopg_ExecStatus = None + + def __init__(self, cursor, await_) -> None: + self._cursor = cursor + self.await_ = await_ + self._rows = [] + + def __getattr__(self, name): + return getattr(self._cursor, name) + + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + def close(self): + self._rows.clear() + # Normal cursor just call _close() in a non-sync way. + self._cursor._close() + + def execute(self, query, params=None, **kw): + result = self.await_(self._cursor.execute(query, params, **kw)) + # sqlalchemy result is not async, so need to pull all rows here + res = self._cursor.pgresult + + # don't rely on psycopg providing enum symbols, compare with + # eq/ne + if res and res.status == self._psycopg_ExecStatus.TUPLES_OK: + rows = self.await_(self._cursor.fetchall()) + if not isinstance(rows, list): + self._rows = list(rows) + else: + self._rows = rows + return result + + def executemany(self, query, params_seq): + return self.await_(self._cursor.executemany(query, params_seq)) + + def __iter__(self): + # TODO: try to avoid pop(0) on a list + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + # TODO: try to avoid pop(0) on a list + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self._cursor.arraysize + + retval = self._rows[0:size] + self._rows = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows + self._rows = [] + return retval + + +class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor): + def execute(self, query, params=None, **kw): + self.await_(self._cursor.execute(query, params, **kw)) + return self + + def close(self): + self.await_(self._cursor.close()) + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=0): + return self.await_(self._cursor.fetchmany(size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + def __iter__(self): + iterator = self._cursor.__aiter__() + while True: + try: + yield self.await_(iterator.__anext__()) + except StopAsyncIteration: + break + + +class AsyncAdapt_psycopg_connection(AdaptedConnection): + __slots__ = () + await_ = staticmethod(await_only) + + def __init__(self, connection) -> None: + self._connection = connection + + def __getattr__(self, name): + return getattr(self._connection, name) + + def execute(self, query, params=None, **kw): + cursor = self.await_(self._connection.execute(query, params, **kw)) + return AsyncAdapt_psycopg_cursor(cursor, self.await_) + + def cursor(self, *args, **kw): + cursor = self._connection.cursor(*args, **kw) + if hasattr(cursor, "name"): + return AsyncAdapt_psycopg_ss_cursor(cursor, self.await_) + else: + return AsyncAdapt_psycopg_cursor(cursor, self.await_) + + def commit(self): + self.await_(self._connection.commit()) + + def rollback(self): + self.await_(self._connection.rollback()) + + def close(self): + self.await_(self._connection.close()) + + @property + def autocommit(self): + return self._connection.autocommit + + @autocommit.setter + def autocommit(self, value): + self.set_autocommit(value) + + def set_autocommit(self, value): + self.await_(self._connection.set_autocommit(value)) + + def set_isolation_level(self, value): + self.await_(self._connection.set_isolation_level(value)) + + def set_read_only(self, value): + self.await_(self._connection.set_read_only(value)) + + def set_deferrable(self, value): + self.await_(self._connection.set_deferrable(value)) + + +class AsyncAdaptFallback_psycopg_connection(AsyncAdapt_psycopg_connection): + __slots__ = () + await_ = staticmethod(await_fallback) + + +class PsycopgAdaptDBAPI: + def __init__(self, psycopg) -> None: + self.psycopg = psycopg + + for k, v in self.psycopg.__dict__.items(): + if k != "connect": + self.__dict__[k] = v + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + if util.asbool(async_fallback): + return AsyncAdaptFallback_psycopg_connection( + await_fallback( + self.psycopg.AsyncConnection.connect(*arg, **kw) + ) + ) + else: + return AsyncAdapt_psycopg_connection( + await_only(self.psycopg.AsyncConnection.connect(*arg, **kw)) + ) + + +class PGDialectAsync_psycopg(PGDialect_psycopg): + is_async = True + supports_statement_cache = True + + @classmethod + def dbapi(cls): + import psycopg + from psycopg.pq import ExecStatus + + AsyncAdapt_psycopg_cursor._psycopg_ExecStatus = ExecStatus + + return PsycopgAdaptDBAPI(psycopg) + + @classmethod + def get_pool_class(cls, url): + + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def _type_info_fetch(self, connection, name): + from psycopg.types import TypeInfo + + adapted = connection.connection + return adapted.await_(TypeInfo.fetch(adapted._connection, name)) + + def _do_isolation_level(self, connection, autocommit, isolation_level): + connection.set_autocommit(autocommit) + connection.set_isolation_level(isolation_level) + + def _do_autocommit(self, connection, value): + connection.set_autocommit(value) + + def set_readonly(self, connection, value): + connection.set_read_only(value) + + def set_deferrable(self, connection, value): + connection.set_deferrable(value) + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = PGDialect_psycopg +dialect_async = PGDialectAsync_psycopg diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 39b6e0ed1..3d9f90a29 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -11,6 +11,8 @@ r""" :connectstring: postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...] :url: https://pypi.org/project/psycopg2/ +.. _psycopg2_toplevel: + psycopg2 Connect Arguments -------------------------- @@ -442,25 +444,15 @@ which may be more performant. """ # noqa import collections.abc as collections_abc -import decimal import logging import re -from uuid import UUID as _python_UUID -from .array import ARRAY as PGARRAY -from .base import _DECIMAL_TYPES -from .base import _FLOAT_TYPES -from .base import _INT_TYPES +from ._psycopg_common import _PGDialect_common_psycopg +from ._psycopg_common import _PGExecutionContext_common_psycopg from .base import PGCompiler -from .base import PGDialect -from .base import PGExecutionContext from .base import PGIdentifierPreparer -from .base import UUID -from .hstore import HSTORE from .json import JSON from .json import JSONB -from ... import exc -from ... import processors from ... import types as sqltypes from ... import util from ...engine import cursor as _cursor @@ -469,53 +461,6 @@ from ...engine import cursor as _cursor logger = logging.getLogger("sqlalchemy.dialects.postgresql") -class _PGNumeric(sqltypes.Numeric): - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect, coltype): - if self.asdecimal: - if coltype in _FLOAT_TYPES: - return processors.to_decimal_processor_factory( - decimal.Decimal, self._effective_decimal_return_scale - ) - elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: - # pg8000 returns Decimal natively for 1700 - return None - else: - raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype - ) - else: - if coltype in _FLOAT_TYPES: - # pg8000 returns float natively for 701 - return None - elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: - return processors.to_float - else: - raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype - ) - - -class _PGHStore(HSTORE): - def bind_processor(self, dialect): - if dialect._has_native_hstore: - return None - else: - return super(_PGHStore, self).bind_processor(dialect) - - def result_processor(self, dialect, coltype): - if dialect._has_native_hstore: - return None - else: - return super(_PGHStore, self).result_processor(dialect, coltype) - - -class _PGARRAY(PGARRAY): - render_bind_cast = True - - class _PGJSON(JSON): def result_processor(self, dialect, coltype): return None @@ -526,40 +471,9 @@ class _PGJSONB(JSONB): return None -class _PGUUID(UUID): - def bind_processor(self, dialect): - if not self.as_uuid and dialect.use_native_uuid: - - def process(value): - if value is not None: - value = _python_UUID(value) - return value - - return process - - def result_processor(self, dialect, coltype): - if not self.as_uuid and dialect.use_native_uuid: - - def process(value): - if value is not None: - value = str(value) - return value - - return process - - -_server_side_id = util.counter() - - -class PGExecutionContext_psycopg2(PGExecutionContext): +class PGExecutionContext_psycopg2(_PGExecutionContext_common_psycopg): _psycopg2_fetched_rows = None - def create_server_side_cursor(self): - # use server-side cursors: - # https://lists.initd.org/pipermail/psycopg/2007-January/005251.html - ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) - return self._dbapi_connection.cursor(ident) - def post_exec(self): if ( self._psycopg2_fetched_rows @@ -614,7 +528,7 @@ EXECUTEMANY_VALUES_PLUS_BATCH = util.symbol( ) -class PGDialect_psycopg2(PGDialect): +class PGDialect_psycopg2(_PGDialect_common_psycopg): driver = "psycopg2" supports_statement_cache = True @@ -631,34 +545,22 @@ class PGDialect_psycopg2(PGDialect): _has_native_hstore = True colspecs = util.update_copy( - PGDialect.colspecs, + _PGDialect_common_psycopg.colspecs, { - sqltypes.Numeric: _PGNumeric, - HSTORE: _PGHStore, JSON: _PGJSON, sqltypes.JSON: _PGJSON, JSONB: _PGJSONB, - UUID: _PGUUID, - sqltypes.ARRAY: _PGARRAY, }, ) def __init__( self, - client_encoding=None, - use_native_hstore=True, - use_native_uuid=True, executemany_mode="values_only", executemany_batch_page_size=100, executemany_values_page_size=1000, **kwargs ): - PGDialect.__init__(self, **kwargs) - if not use_native_hstore: - self._has_native_hstore = False - self.use_native_hstore = use_native_hstore - self.use_native_uuid = use_native_uuid - self.client_encoding = client_encoding + _PGDialect_common_psycopg.__init__(self, **kwargs) # Parse executemany_mode argument, allowing it to be only one of the # symbol names @@ -737,9 +639,6 @@ class PGDialect_psycopg2(PGDialect): "SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE, } - def get_isolation_level_values(self, dbapi_conn): - return list(self._isolation_lookup) - def set_isolation_level(self, connection, level): connection.set_isolation_level(self._isolation_lookup[level]) @@ -755,25 +654,6 @@ class PGDialect_psycopg2(PGDialect): def get_deferrable(self, connection): return connection.deferrable - def do_ping(self, dbapi_connection): - cursor = None - try: - dbapi_connection.autocommit = True - cursor = dbapi_connection.cursor() - try: - cursor.execute(self._dialect_specific_select_one) - finally: - cursor.close() - if not dbapi_connection.closed: - dbapi_connection.autocommit = False - except self.dbapi.Error as err: - if self.is_disconnect(err, dbapi_connection, cursor): - return False - else: - raise - else: - return True - def on_connect(self): extras = self._psycopg2_extras @@ -911,33 +791,6 @@ class PGDialect_psycopg2(PGDialect): else: return None - def create_connect_args(self, url): - opts = url.translate_connect_args(username="user") - - is_multihost = False - if "host" in url.query: - is_multihost = isinstance(url.query["host"], (list, tuple)) - - if opts: - if "port" in opts: - opts["port"] = int(opts["port"]) - opts.update(url.query) - if is_multihost: - opts["host"] = ",".join(url.query["host"]) - # send individual dbname, user, password, host, port - # parameters to psycopg2.connect() - return ([], opts) - elif url.query: - # any other connection arguments, pass directly - opts.update(url.query) - if is_multihost: - opts["host"] = ",".join(url.query["host"]) - return ([], opts) - else: - # no connection arguments whatsoever; psycopg2.connect() - # requires that "dsn" be present as a blank string. - return ([""], opts) - def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.Error): # check the "closed" flag. this might not be diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index f9a65a0f8..8fcba7503 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -458,7 +458,11 @@ def create_engine(url, **kwargs): u, plugins, kwargs = u._instantiate_plugins(kwargs) entrypoint = u._get_entrypoint() - dialect_cls = entrypoint.get_dialect_cls(u) + _is_async = kwargs.pop("_is_async", False) + if _is_async: + dialect_cls = entrypoint.get_async_dialect_cls(u) + else: + dialect_cls = entrypoint.get_dialect_cls(u) if kwargs.pop("_coerce_config", False): diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 1f1a2fcf1..7f2b8b412 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1223,6 +1223,7 @@ class BaseCursorResult: """ + if (not hard and self._soft_closed) or (hard and self.closed): return diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index cb04eb525..64500b41b 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1284,6 +1284,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): result.out_parameters = out_parameters def _setup_dml_or_text_result(self): + if self.isinsert: if self.compiled.postfetch_lastrowid: self.inserted_primary_key_rows = ( @@ -1332,8 +1333,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): # assert not result.returns_rows elif self.isupdate and self._is_implicit_returning: + # get rowcount + # (which requires open cursor on some drivers) + # we were not doing this in 1.4, however + # test_rowcount -> test_update_rowcount_return_defaults + # is testing this, and psycopg will no longer return + # rowcount after cursor is closed. + result.rowcount + row = result.fetchone() self.returned_default_rows = [row] + result._soft_close() # test that it has a cursor metadata that is accurate. @@ -1410,6 +1420,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): dialect = self.dialect + # all of the rest of this... cython? + if dialect._has_events: inputsizes = dict(inputsizes) dialect.dispatch.do_setinputsizes( diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 251d01c5e..faaf073ab 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -1114,6 +1114,25 @@ class Dialect: return cls @classmethod + def get_async_dialect_cls(cls, url): + """Given a URL, return the :class:`.Dialect` that will be used by + an async engine. + + By default this is an alias of :meth:`.Dialect.get_dialect_cls` and + just returns the cls. It may be used if a dialect provides + both a sync and async version under the same name, like the + ``psycopg`` driver. + + .. versionadded:: 2 + + .. seealso:: + + :meth:`.Dialect.get_dialect_cls` + + """ + return cls.get_dialect_cls(url) + + @classmethod def load_provisioning(cls): """set up the provision.py module for this dialect. diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index c83753bdc..7cdf25c21 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -655,13 +655,16 @@ class URL( else: return cls - def get_dialect(self): + def get_dialect(self, _is_async=False): """Return the SQLAlchemy :class:`_engine.Dialect` class corresponding to this URL's driver name. """ entrypoint = self._get_entrypoint() - dialect_cls = entrypoint.get_dialect_cls(self) + if _is_async: + dialect_cls = entrypoint.get_async_dialect_cls(self) + else: + dialect_cls = entrypoint.get_dialect_cls(self) return dialect_cls def translate_connect_args(self, names=None, **kw): diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 221d82f08..67d944b9c 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -37,6 +37,7 @@ def create_async_engine(*arg, **kw): "streaming result set" ) kw["future"] = True + kw["_is_async"] = True sync_engine = _create_engine(*arg, **kw) return AsyncEngine(sync_engine) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 8874f0a83..427251a88 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2338,26 +2338,40 @@ class JSON(Indexable, TypeEngine): def _str_impl(self): return String() - def bind_processor(self, dialect): - string_process = self._str_impl.bind_processor(dialect) + def _make_bind_processor(self, string_process, json_serializer): + if string_process: - json_serializer = dialect._json_serializer or json.dumps + def process(value): + if value is self.NULL: + value = None + elif isinstance(value, elements.Null) or ( + value is None and self.none_as_null + ): + return None - def process(value): - if value is self.NULL: - value = None - elif isinstance(value, elements.Null) or ( - value is None and self.none_as_null - ): - return None + serialized = json_serializer(value) + return string_process(serialized) - serialized = json_serializer(value) - if string_process: - serialized = string_process(serialized) - return serialized + else: + + def process(value): + if value is self.NULL: + value = None + elif isinstance(value, elements.Null) or ( + value is None and self.none_as_null + ): + return None + + return json_serializer(value) return process + def bind_processor(self, dialect): + string_process = self._str_impl.bind_processor(dialect) + json_serializer = dialect._json_serializer or json.dumps + + return self._make_bind_processor(string_process, json_serializer) + def result_processor(self, dialect, coltype): string_process = self._str_impl.result_processor(dialect, coltype) json_deserializer = dialect._json_deserializer or json.loads diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 8faeea634..22d9c523a 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -124,11 +124,12 @@ class Config: _configs = set() def _set_name(self, db): + suffix = "_async" if db.dialect.is_async else "" if db.dialect.server_version_info: svi = ".".join(str(tok) for tok in db.dialect.server_version_info) - self.name = "%s+%s_[%s]" % (db.name, db.driver, svi) + self.name = "%s+%s%s_[%s]" % (db.name, db.driver, suffix, svi) else: - self.name = "%s+%s" % (db.name, db.driver) + self.name = "%s+%s%s" % (db.name, db.driver, suffix) @classmethod def register(cls, db, db_opts, options, file_config): diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 32ed2c315..d79931b91 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -706,7 +706,8 @@ def _do_skips(cls): ) if not all_configs: - msg = "'%s' unsupported on any DB implementation %s%s" % ( + msg = "'%s.%s' unsupported on any DB implementation %s%s" % ( + cls.__module__, cls.__name__, ", ".join( "'%s(%s)+%s'" diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 8cb72d163..f811be657 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1493,3 +1493,13 @@ class SuiteRequirements(Requirements): sequence. This should be false only for oracle. """ return exclusions.open() + + @property + def generic_classes(self): + "If X[Y] can be implemented with ``__class_getitem__``. py3.7+" + return exclusions.open() + + @property + def json_deserializer_binary(self): + "indicates if the json_deserializer function is called with bytes" + return exclusions.closed() diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index 5ad68034b..a8900ece1 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -244,6 +244,8 @@ class ServerSideCursorsTest( return cursor.server_side elif self.engine.dialect.driver == "pg8000": return getattr(cursor, "server_side", False) + elif self.engine.dialect.driver == "psycopg": + return bool(getattr(cursor, "name", False)) else: return False diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 82e6fa238..e7131ec6e 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -1107,7 +1107,13 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): eq_(row, (data_element,)) eq_(js.mock_calls, [mock.call(data_element)]) - eq_(jd.mock_calls, [mock.call(json.dumps(data_element))]) + if testing.requires.json_deserializer_binary.enabled: + eq_( + jd.mock_calls, + [mock.call(json.dumps(data_element).encode())], + ) + else: + eq_(jd.mock_calls, [mock.call(json.dumps(data_element))]) @testing.combinations( ("parameters",), |