diff options
Diffstat (limited to 'django/db/backends/postgresql/operations.py')
-rw-r--r-- | django/db/backends/postgresql/operations.py | 90 |
1 files changed, 67 insertions, 23 deletions
diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 824e0c3e4b..18cfcb29cb 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -3,9 +3,16 @@ from functools import lru_cache, partial from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations -from django.db.backends.postgresql.psycopg_any import Inet, Jsonb, mogrify +from django.db.backends.postgresql.psycopg_any import ( + Inet, + Jsonb, + errors, + is_psycopg3, + mogrify, +) from django.db.backends.utils import split_tzname_delta from django.db.models.constants import OnConflict +from django.utils.regex_helper import _lazy_re_compile @lru_cache @@ -36,6 +43,18 @@ class DatabaseOperations(BaseDatabaseOperations): "SmallAutoField": "smallint", } + if is_psycopg3: + from psycopg.types import numeric + + integerfield_type_map = { + "SmallIntegerField": numeric.Int2, + "IntegerField": numeric.Int4, + "BigIntegerField": numeric.Int8, + "PositiveSmallIntegerField": numeric.Int2, + "PositiveIntegerField": numeric.Int4, + "PositiveBigIntegerField": numeric.Int8, + } + def unification_cast_sql(self, output_field): internal_type = output_field.get_internal_type() if internal_type in ( @@ -56,19 +75,23 @@ class DatabaseOperations(BaseDatabaseOperations): ) return "%s" + # EXTRACT format cannot be passed in parameters. + _extract_format_re = _lazy_re_compile(r"[A-Z_]+") + def date_extract_sql(self, lookup_type, sql, params): # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT - extract_sql = f"EXTRACT(%s FROM {sql})" - extract_param = lookup_type if lookup_type == "week_day": # For consistency across backends, we return Sunday=1, Saturday=7. - extract_sql = f"EXTRACT(%s FROM {sql}) + 1" - extract_param = "dow" + return f"EXTRACT(DOW FROM {sql}) + 1", params elif lookup_type == "iso_week_day": - extract_param = "isodow" + return f"EXTRACT(ISODOW FROM {sql})", params elif lookup_type == "iso_year": - extract_param = "isoyear" - return extract_sql, (extract_param, *params) + return f"EXTRACT(ISOYEAR FROM {sql})", params + + lookup_type = lookup_type.upper() + if not self._extract_format_re.fullmatch(lookup_type): + raise ValueError(f"Invalid lookup type: {lookup_type!r}") + return f"EXTRACT({lookup_type} FROM {sql})", params def date_trunc_sql(self, lookup_type, sql, params, tzname=None): sql, params = self._convert_sql_to_tz(sql, params, tzname) @@ -100,10 +123,7 @@ class DatabaseOperations(BaseDatabaseOperations): sql, params = self._convert_sql_to_tz(sql, params, tzname) if lookup_type == "second": # Truncate fractional seconds. - return ( - f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))", - ("second", "second", *params), - ) + return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params) return self.date_extract_sql(lookup_type, sql, params) def datetime_trunc_sql(self, lookup_type, sql, params, tzname): @@ -114,10 +134,7 @@ class DatabaseOperations(BaseDatabaseOperations): def time_extract_sql(self, lookup_type, sql, params): if lookup_type == "second": # Truncate fractional seconds. - return ( - f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))", - ("second", "second", *params), - ) + return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params) return self.date_extract_sql(lookup_type, sql, params) def time_trunc_sql(self, lookup_type, sql, params, tzname=None): @@ -137,6 +154,16 @@ class DatabaseOperations(BaseDatabaseOperations): def lookup_cast(self, lookup_type, internal_type=None): lookup = "%s" + if lookup_type == "isnull" and internal_type in ( + "CharField", + "EmailField", + "TextField", + "CICharField", + "CIEmailField", + "CITextField", + ): + return "%s::text" + # Cast text lookups to text to allow things like filter(x__contains=4) if lookup_type in ( "iexact", @@ -178,7 +205,7 @@ class DatabaseOperations(BaseDatabaseOperations): return mogrify(sql, params, self.connection) def set_time_zone_sql(self): - return "SET TIME ZONE %s" + return "SELECT set_config('TimeZone', %s, false)" def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False): if not tables: @@ -278,12 +305,22 @@ class DatabaseOperations(BaseDatabaseOperations): else: return ["DISTINCT"], [] - def last_executed_query(self, cursor, sql, params): - # https://www.psycopg.org/docs/cursor.html#cursor.query - # The query attribute is a Psycopg extension to the DB API 2.0. - if cursor.query is not None: - return cursor.query.decode() - return None + if is_psycopg3: + + def last_executed_query(self, cursor, sql, params): + try: + return self.compose_sql(sql, params) + except errors.DataError: + return None + + else: + + def last_executed_query(self, cursor, sql, params): + # https://www.psycopg.org/docs/cursor.html#cursor.query + # The query attribute is a Psycopg extension to the DB API 2.0. + if cursor.query is not None: + return cursor.query.decode() + return None def return_insert_columns(self, fields): if not fields: @@ -303,6 +340,13 @@ class DatabaseOperations(BaseDatabaseOperations): values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql) return "VALUES " + values_sql + if is_psycopg3: + + def adapt_integerfield_value(self, value, internal_type): + if value is None or hasattr(value, "resolve_expression"): + return value + return self.integerfield_type_map[internal_type](value) + def adapt_datefield_value(self, value): return value |