summaryrefslogtreecommitdiff
path: root/django/db/backends/postgresql/base.py
diff options
context:
space:
mode:
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>2022-12-01 20:23:43 +0100
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-12-15 06:17:57 +0100
commit09ffc5c1212d4ced58b708cbbf3dfbfb77b782ca (patch)
tree15bb8bb049f9339f30d637e78b340473c2038126 /django/db/backends/postgresql/base.py
parentd44ee518c4c110af25bebdbedbbf9fba04d197aa (diff)
downloaddjango-09ffc5c1212d4ced58b708cbbf3dfbfb77b782ca.tar.gz
Fixed #33308 -- Added support for psycopg version 3.
Thanks Simon Charette, Tim Graham, and Adam Johnson for reviews. Co-authored-by: Florian Apolloner <florian@apolloner.eu> Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
Diffstat (limited to 'django/db/backends/postgresql/base.py')
-rw-r--r--django/db/backends/postgresql/base.py162
1 files changed, 114 insertions, 48 deletions
diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py
index 0aee39aa5c..ceea1bebad 100644
--- a/django/db/backends/postgresql/base.py
+++ b/django/db/backends/postgresql/base.py
@@ -1,7 +1,7 @@
"""
PostgreSQL database backend for Django.
-Requires psycopg 2: https://www.psycopg.org/
+Requires psycopg2 >= 2.8.4 or psycopg >= 3.1
"""
import asyncio
@@ -21,48 +21,63 @@ from django.utils.safestring import SafeString
from django.utils.version import get_version_tuple
try:
- import psycopg2 as Database
- import psycopg2.extensions
- import psycopg2.extras
-except ImportError as e:
- raise ImproperlyConfigured("Error loading psycopg2 module: %s" % e)
+ try:
+ import psycopg as Database
+ except ImportError:
+ import psycopg2 as Database
+except ImportError:
+ raise ImproperlyConfigured("Error loading psycopg2 or psycopg module")
-def psycopg2_version():
- version = psycopg2.__version__.split(" ", 1)[0]
+def psycopg_version():
+ version = Database.__version__.split(" ", 1)[0]
return get_version_tuple(version)
-PSYCOPG2_VERSION = psycopg2_version()
-
-if PSYCOPG2_VERSION < (2, 8, 4):
+if psycopg_version() < (2, 8, 4):
+ raise ImproperlyConfigured(
+ f"psycopg2 version 2.8.4 or newer is required; you have {Database.__version__}"
+ )
+if (3,) <= psycopg_version() < (3, 1):
raise ImproperlyConfigured(
- "psycopg2 version 2.8.4 or newer is required; you have %s"
- % psycopg2.__version__
+ f"psycopg version 3.1 or newer is required; you have {Database.__version__}"
)
-# Some of these import psycopg2, so import them after checking if it's installed.
-from .client import DatabaseClient # NOQA
-from .creation import DatabaseCreation # NOQA
-from .features import DatabaseFeatures # NOQA
-from .introspection import DatabaseIntrospection # NOQA
-from .operations import DatabaseOperations # NOQA
-from .psycopg_any import IsolationLevel # NOQA
-from .schema import DatabaseSchemaEditor # NOQA
+from .psycopg_any import IsolationLevel, is_psycopg3 # NOQA isort:skip
+
+if is_psycopg3:
+ from psycopg import adapters, sql
+ from psycopg.pq import Format
-psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
-psycopg2.extras.register_uuid()
+ from .psycopg_any import get_adapters_template, register_tzloader
-# Register support for inet[] manually so we don't have to handle the Inet()
-# object on load all the time.
-INETARRAY_OID = 1041
-INETARRAY = psycopg2.extensions.new_array_type(
- (INETARRAY_OID,),
- "INETARRAY",
- psycopg2.extensions.UNICODE,
-)
-psycopg2.extensions.register_type(INETARRAY)
+ TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
+
+else:
+ import psycopg2.extensions
+ import psycopg2.extras
+
+ psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
+ psycopg2.extras.register_uuid()
+
+ # Register support for inet[] manually so we don't have to handle the Inet()
+ # object on load all the time.
+ INETARRAY_OID = 1041
+ INETARRAY = psycopg2.extensions.new_array_type(
+ (INETARRAY_OID,),
+ "INETARRAY",
+ psycopg2.extensions.UNICODE,
+ )
+ psycopg2.extensions.register_type(INETARRAY)
+
+# Some of these import psycopg, so import them after checking if it's installed.
+from .client import DatabaseClient # NOQA isort:skip
+from .creation import DatabaseCreation # NOQA isort:skip
+from .features import DatabaseFeatures # NOQA isort:skip
+from .introspection import DatabaseIntrospection # NOQA isort:skip
+from .operations import DatabaseOperations # NOQA isort:skip
+from .schema import DatabaseSchemaEditor # NOQA isort:skip
class DatabaseWrapper(BaseDatabaseWrapper):
@@ -209,6 +224,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
conn_params["host"] = settings_dict["HOST"]
if settings_dict["PORT"]:
conn_params["port"] = settings_dict["PORT"]
+ if is_psycopg3:
+ conn_params["context"] = get_adapters_template(
+ settings.USE_TZ, self.timezone
+ )
+ # Disable prepared statements by default to keep connection poolers
+ # working. Can be reenabled via OPTIONS in the settings dict.
+ conn_params["prepare_threshold"] = conn_params.pop(
+ "prepare_threshold", None
+ )
return conn_params
@async_unsafe
@@ -232,17 +256,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
except ValueError:
raise ImproperlyConfigured(
f"Invalid transaction isolation level {isolation_level_value} "
- f"specified. Use one of the IsolationLevel values."
+ f"specified. Use one of the psycopg.IsolationLevel values."
)
- connection = Database.connect(**conn_params)
+ connection = self.Database.connect(**conn_params)
if set_isolation_level:
connection.isolation_level = self.isolation_level
- # Register dummy loads() to avoid a round trip from psycopg2's decode
- # to json.dumps() to json.loads(), when using a custom decoder in
- # JSONField.
- psycopg2.extras.register_default_jsonb(
- conn_or_curs=connection, loads=lambda x: x
- )
+ if not is_psycopg3:
+ # Register dummy loads() to avoid a round trip from psycopg2's
+ # decode to json.dumps() to json.loads(), when using a custom
+ # decoder in JSONField.
+ psycopg2.extras.register_default_jsonb(
+ conn_or_curs=connection, loads=lambda x: x
+ )
+ connection.cursor_factory = Cursor
return connection
def ensure_timezone(self):
@@ -275,7 +301,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
)
else:
cursor = self.connection.cursor()
- cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
+
+ if is_psycopg3:
+ # Register the cursor timezone only if the connection disagrees, to
+ # avoid copying the adapter map.
+ tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
+ if self.timezone != tzloader.timezone:
+ register_tzloader(self.timezone, cursor)
+ else:
+ cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
return cursor
def tzinfo_factory(self, offset):
@@ -379,11 +413,43 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return CursorDebugWrapper(cursor, self)
-class CursorDebugWrapper(BaseCursorDebugWrapper):
- def copy_expert(self, sql, file, *args):
- with self.debug_sql(sql):
- return self.cursor.copy_expert(sql, file, *args)
+if is_psycopg3:
+
+ class Cursor(Database.Cursor):
+ """
+ A subclass of psycopg cursor implementing callproc.
+ """
+
+ def callproc(self, name, args=None):
+ if not isinstance(name, sql.Identifier):
+ name = sql.Identifier(name)
+
+ qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")]
+ if args:
+ for item in args:
+ qparts.append(sql.Literal(item))
+ qparts.append(sql.SQL(","))
+ del qparts[-1]
+
+ qparts.append(sql.SQL(")"))
+ stmt = sql.Composed(qparts)
+ self.execute(stmt)
+ return args
+
+ class CursorDebugWrapper(BaseCursorDebugWrapper):
+ def copy(self, statement):
+ with self.debug_sql(statement):
+ return self.cursor.copy(statement)
+
+else:
+
+ Cursor = psycopg2.extensions.cursor
+
+ class CursorDebugWrapper(BaseCursorDebugWrapper):
+ def copy_expert(self, sql, file, *args):
+ with self.debug_sql(sql):
+ return self.cursor.copy_expert(sql, file, *args)
- def copy_to(self, file, table, *args, **kwargs):
- with self.debug_sql(sql="COPY %s TO STDOUT" % table):
- return self.cursor.copy_to(file, table, *args, **kwargs)
+ def copy_to(self, file, table, *args, **kwargs):
+ with self.debug_sql(sql="COPY %s TO STDOUT" % table):
+ return self.cursor.copy_to(file, table, *args, **kwargs)