diff options
author | Daniele Varrazzo <daniele.varrazzo@gmail.com> | 2022-12-01 20:23:43 +0100 |
---|---|---|
committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2022-12-15 06:17:57 +0100 |
commit | 09ffc5c1212d4ced58b708cbbf3dfbfb77b782ca (patch) | |
tree | 15bb8bb049f9339f30d637e78b340473c2038126 /django/db/backends/postgresql/base.py | |
parent | d44ee518c4c110af25bebdbedbbf9fba04d197aa (diff) | |
download | django-09ffc5c1212d4ced58b708cbbf3dfbfb77b782ca.tar.gz |
Fixed #33308 -- Added support for psycopg version 3.
Thanks Simon Charette, Tim Graham, and Adam Johnson for reviews.
Co-authored-by: Florian Apolloner <florian@apolloner.eu>
Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
Diffstat (limited to 'django/db/backends/postgresql/base.py')
-rw-r--r-- | django/db/backends/postgresql/base.py | 162 |
1 files changed, 114 insertions, 48 deletions
diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index 0aee39aa5c..ceea1bebad 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -1,7 +1,7 @@ """ PostgreSQL database backend for Django. -Requires psycopg 2: https://www.psycopg.org/ +Requires psycopg2 >= 2.8.4 or psycopg >= 3.1 """ import asyncio @@ -21,48 +21,63 @@ from django.utils.safestring import SafeString from django.utils.version import get_version_tuple try: - import psycopg2 as Database - import psycopg2.extensions - import psycopg2.extras -except ImportError as e: - raise ImproperlyConfigured("Error loading psycopg2 module: %s" % e) + try: + import psycopg as Database + except ImportError: + import psycopg2 as Database +except ImportError: + raise ImproperlyConfigured("Error loading psycopg2 or psycopg module") -def psycopg2_version(): - version = psycopg2.__version__.split(" ", 1)[0] +def psycopg_version(): + version = Database.__version__.split(" ", 1)[0] return get_version_tuple(version) -PSYCOPG2_VERSION = psycopg2_version() - -if PSYCOPG2_VERSION < (2, 8, 4): +if psycopg_version() < (2, 8, 4): + raise ImproperlyConfigured( + f"psycopg2 version 2.8.4 or newer is required; you have {Database.__version__}" + ) +if (3,) <= psycopg_version() < (3, 1): raise ImproperlyConfigured( - "psycopg2 version 2.8.4 or newer is required; you have %s" - % psycopg2.__version__ + f"psycopg version 3.1 or newer is required; you have {Database.__version__}" ) -# Some of these import psycopg2, so import them after checking if it's installed. -from .client import DatabaseClient # NOQA -from .creation import DatabaseCreation # NOQA -from .features import DatabaseFeatures # NOQA -from .introspection import DatabaseIntrospection # NOQA -from .operations import DatabaseOperations # NOQA -from .psycopg_any import IsolationLevel # NOQA -from .schema import DatabaseSchemaEditor # NOQA +from .psycopg_any import IsolationLevel, is_psycopg3 # NOQA isort:skip + +if is_psycopg3: + from psycopg import adapters, sql + from psycopg.pq import Format -psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString) -psycopg2.extras.register_uuid() + from .psycopg_any import get_adapters_template, register_tzloader -# Register support for inet[] manually so we don't have to handle the Inet() -# object on load all the time. -INETARRAY_OID = 1041 -INETARRAY = psycopg2.extensions.new_array_type( - (INETARRAY_OID,), - "INETARRAY", - psycopg2.extensions.UNICODE, -) -psycopg2.extensions.register_type(INETARRAY) + TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid + +else: + import psycopg2.extensions + import psycopg2.extras + + psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString) + psycopg2.extras.register_uuid() + + # Register support for inet[] manually so we don't have to handle the Inet() + # object on load all the time. + INETARRAY_OID = 1041 + INETARRAY = psycopg2.extensions.new_array_type( + (INETARRAY_OID,), + "INETARRAY", + psycopg2.extensions.UNICODE, + ) + psycopg2.extensions.register_type(INETARRAY) + +# Some of these import psycopg, so import them after checking if it's installed. +from .client import DatabaseClient # NOQA isort:skip +from .creation import DatabaseCreation # NOQA isort:skip +from .features import DatabaseFeatures # NOQA isort:skip +from .introspection import DatabaseIntrospection # NOQA isort:skip +from .operations import DatabaseOperations # NOQA isort:skip +from .schema import DatabaseSchemaEditor # NOQA isort:skip class DatabaseWrapper(BaseDatabaseWrapper): @@ -209,6 +224,15 @@ class DatabaseWrapper(BaseDatabaseWrapper): conn_params["host"] = settings_dict["HOST"] if settings_dict["PORT"]: conn_params["port"] = settings_dict["PORT"] + if is_psycopg3: + conn_params["context"] = get_adapters_template( + settings.USE_TZ, self.timezone + ) + # Disable prepared statements by default to keep connection poolers + # working. Can be reenabled via OPTIONS in the settings dict. + conn_params["prepare_threshold"] = conn_params.pop( + "prepare_threshold", None + ) return conn_params @async_unsafe @@ -232,17 +256,19 @@ class DatabaseWrapper(BaseDatabaseWrapper): except ValueError: raise ImproperlyConfigured( f"Invalid transaction isolation level {isolation_level_value} " - f"specified. Use one of the IsolationLevel values." + f"specified. Use one of the psycopg.IsolationLevel values." ) - connection = Database.connect(**conn_params) + connection = self.Database.connect(**conn_params) if set_isolation_level: connection.isolation_level = self.isolation_level - # Register dummy loads() to avoid a round trip from psycopg2's decode - # to json.dumps() to json.loads(), when using a custom decoder in - # JSONField. - psycopg2.extras.register_default_jsonb( - conn_or_curs=connection, loads=lambda x: x - ) + if not is_psycopg3: + # Register dummy loads() to avoid a round trip from psycopg2's + # decode to json.dumps() to json.loads(), when using a custom + # decoder in JSONField. + psycopg2.extras.register_default_jsonb( + conn_or_curs=connection, loads=lambda x: x + ) + connection.cursor_factory = Cursor return connection def ensure_timezone(self): @@ -275,7 +301,15 @@ class DatabaseWrapper(BaseDatabaseWrapper): ) else: cursor = self.connection.cursor() - cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None + + if is_psycopg3: + # Register the cursor timezone only if the connection disagrees, to + # avoid copying the adapter map. + tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT) + if self.timezone != tzloader.timezone: + register_tzloader(self.timezone, cursor) + else: + cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None return cursor def tzinfo_factory(self, offset): @@ -379,11 +413,43 @@ class DatabaseWrapper(BaseDatabaseWrapper): return CursorDebugWrapper(cursor, self) -class CursorDebugWrapper(BaseCursorDebugWrapper): - def copy_expert(self, sql, file, *args): - with self.debug_sql(sql): - return self.cursor.copy_expert(sql, file, *args) +if is_psycopg3: + + class Cursor(Database.Cursor): + """ + A subclass of psycopg cursor implementing callproc. + """ + + def callproc(self, name, args=None): + if not isinstance(name, sql.Identifier): + name = sql.Identifier(name) + + qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")] + if args: + for item in args: + qparts.append(sql.Literal(item)) + qparts.append(sql.SQL(",")) + del qparts[-1] + + qparts.append(sql.SQL(")")) + stmt = sql.Composed(qparts) + self.execute(stmt) + return args + + class CursorDebugWrapper(BaseCursorDebugWrapper): + def copy(self, statement): + with self.debug_sql(statement): + return self.cursor.copy(statement) + +else: + + Cursor = psycopg2.extensions.cursor + + class CursorDebugWrapper(BaseCursorDebugWrapper): + def copy_expert(self, sql, file, *args): + with self.debug_sql(sql): + return self.cursor.copy_expert(sql, file, *args) - def copy_to(self, file, table, *args, **kwargs): - with self.debug_sql(sql="COPY %s TO STDOUT" % table): - return self.cursor.copy_to(file, table, *args, **kwargs) + def copy_to(self, file, table, *args, **kwargs): + with self.debug_sql(sql="COPY %s TO STDOUT" % table): + return self.cursor.copy_to(file, table, *args, **kwargs) |