diff options
Diffstat (limited to 'django/db/backends/postgresql/operations.py')
-rw-r--r-- | django/db/backends/postgresql/operations.py | 155 |
1 files changed, 94 insertions, 61 deletions
diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 762cd8d23e..68448157ec 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -7,17 +7,22 @@ from django.db.models.constants import OnConflict class DatabaseOperations(BaseDatabaseOperations): - cast_char_field_without_max_length = 'varchar' - explain_prefix = 'EXPLAIN' + cast_char_field_without_max_length = "varchar" + explain_prefix = "EXPLAIN" cast_data_types = { - 'AutoField': 'integer', - 'BigAutoField': 'bigint', - 'SmallAutoField': 'smallint', + "AutoField": "integer", + "BigAutoField": "bigint", + "SmallAutoField": "smallint", } def unification_cast_sql(self, output_field): internal_type = output_field.get_internal_type() - if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"): + if internal_type in ( + "GenericIPAddressField", + "IPAddressField", + "TimeField", + "UUIDField", + ): # PostgreSQL will resolve a union as type 'text' if input types are # 'unknown'. # https://www.postgresql.org/docs/current/typeconv-union-case.html @@ -25,17 +30,19 @@ class DatabaseOperations(BaseDatabaseOperations): # PostgreSQL configuration so we need to explicitly cast them. # We must also remove components of the type within brackets: # varchar(255) -> varchar. - return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0] - return '%s' + return ( + "CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0] + ) + return "%s" def date_extract_sql(self, lookup_type, field_name): # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT - if lookup_type == 'week_day': + if lookup_type == "week_day": # For consistency across backends, we return Sunday=1, Saturday=7. return "EXTRACT('dow' FROM %s) + 1" % field_name - elif lookup_type == 'iso_week_day': + elif lookup_type == "iso_week_day": return "EXTRACT('isodow' FROM %s)" % field_name - elif lookup_type == 'iso_year': + elif lookup_type == "iso_year": return "EXTRACT('isoyear' FROM %s)" % field_name else: return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name) @@ -48,22 +55,25 @@ class DatabaseOperations(BaseDatabaseOperations): def _prepare_tzname_delta(self, tzname): tzname, sign, offset = split_tzname_delta(tzname) if offset: - sign = '-' if sign == '+' else '+' - return f'{tzname}{sign}{offset}' + sign = "-" if sign == "+" else "+" + return f"{tzname}{sign}{offset}" return tzname def _convert_field_to_tz(self, field_name, tzname): if tzname and settings.USE_TZ: - field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname)) + field_name = "%s AT TIME ZONE '%s'" % ( + field_name, + self._prepare_tzname_delta(tzname), + ) return field_name def datetime_cast_date_sql(self, field_name, tzname): field_name = self._convert_field_to_tz(field_name, tzname) - return '(%s)::date' % field_name + return "(%s)::date" % field_name def datetime_cast_time_sql(self, field_name, tzname): field_name = self._convert_field_to_tz(field_name, tzname) - return '(%s)::time' % field_name + return "(%s)::time" % field_name def datetime_extract_sql(self, lookup_type, field_name, tzname): field_name = self._convert_field_to_tz(field_name, tzname) @@ -89,21 +99,30 @@ class DatabaseOperations(BaseDatabaseOperations): return cursor.fetchall() def lookup_cast(self, lookup_type, internal_type=None): - lookup = '%s' + lookup = "%s" # Cast text lookups to text to allow things like filter(x__contains=4) - if lookup_type in ('iexact', 'contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'): - if internal_type in ('IPAddressField', 'GenericIPAddressField'): + if lookup_type in ( + "iexact", + "contains", + "icontains", + "startswith", + "istartswith", + "endswith", + "iendswith", + "regex", + "iregex", + ): + if internal_type in ("IPAddressField", "GenericIPAddressField"): lookup = "HOST(%s)" - elif internal_type in ('CICharField', 'CIEmailField', 'CITextField'): - lookup = '%s::citext' + elif internal_type in ("CICharField", "CIEmailField", "CITextField"): + lookup = "%s::citext" else: lookup = "%s::text" # Use UPPER(x) for case-insensitive lookups; it's faster. - if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): - lookup = 'UPPER(%s)' % lookup + if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"): + lookup = "UPPER(%s)" % lookup return lookup @@ -128,29 +147,32 @@ class DatabaseOperations(BaseDatabaseOperations): # Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us # to truncate tables referenced by a foreign key in any other table. sql_parts = [ - style.SQL_KEYWORD('TRUNCATE'), - ', '.join(style.SQL_FIELD(self.quote_name(table)) for table in tables), + style.SQL_KEYWORD("TRUNCATE"), + ", ".join(style.SQL_FIELD(self.quote_name(table)) for table in tables), ] if reset_sequences: - sql_parts.append(style.SQL_KEYWORD('RESTART IDENTITY')) + sql_parts.append(style.SQL_KEYWORD("RESTART IDENTITY")) if allow_cascade: - sql_parts.append(style.SQL_KEYWORD('CASCADE')) - return ['%s;' % ' '.join(sql_parts)] + sql_parts.append(style.SQL_KEYWORD("CASCADE")) + return ["%s;" % " ".join(sql_parts)] def sequence_reset_by_name_sql(self, style, sequences): # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements # to reset sequence indices sql = [] for sequence_info in sequences: - table_name = sequence_info['table'] + table_name = sequence_info["table"] # 'id' will be the case if it's an m2m using an autogenerated # intermediate table (see BaseDatabaseIntrospection.sequence_list). - column_name = sequence_info['column'] or 'id' - sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % ( - style.SQL_KEYWORD('SELECT'), - style.SQL_TABLE(self.quote_name(table_name)), - style.SQL_FIELD(column_name), - )) + column_name = sequence_info["column"] or "id" + sql.append( + "%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" + % ( + style.SQL_KEYWORD("SELECT"), + style.SQL_TABLE(self.quote_name(table_name)), + style.SQL_FIELD(column_name), + ) + ) return sql def tablespace_sql(self, tablespace, inline=False): @@ -161,6 +183,7 @@ class DatabaseOperations(BaseDatabaseOperations): def sequence_reset_sql(self, style, model_list): from django.db import models + output = [] qn = self.quote_name for model in model_list: @@ -174,14 +197,15 @@ class DatabaseOperations(BaseDatabaseOperations): if isinstance(f, models.AutoField): output.append( "%s setval(pg_get_serial_sequence('%s','%s'), " - "coalesce(max(%s), 1), max(%s) %s null) %s %s;" % ( - style.SQL_KEYWORD('SELECT'), + "coalesce(max(%s), 1), max(%s) %s null) %s %s;" + % ( + style.SQL_KEYWORD("SELECT"), style.SQL_TABLE(qn(model._meta.db_table)), style.SQL_FIELD(f.column), style.SQL_FIELD(qn(f.column)), style.SQL_FIELD(qn(f.column)), - style.SQL_KEYWORD('IS NOT'), - style.SQL_KEYWORD('FROM'), + style.SQL_KEYWORD("IS NOT"), + style.SQL_KEYWORD("FROM"), style.SQL_TABLE(qn(model._meta.db_table)), ) ) @@ -207,9 +231,9 @@ class DatabaseOperations(BaseDatabaseOperations): def distinct_sql(self, fields, params): if fields: params = [param for param_list in params for param in param_list] - return (['DISTINCT ON (%s)' % ', '.join(fields)], params) + return (["DISTINCT ON (%s)" % ", ".join(fields)], params) else: - return ['DISTINCT'], [] + return ["DISTINCT"], [] def last_executed_query(self, cursor, sql, params): # https://www.psycopg.org/docs/cursor.html#cursor.query @@ -220,14 +244,16 @@ class DatabaseOperations(BaseDatabaseOperations): def return_insert_columns(self, fields): if not fields: - return '', () + return "", () columns = [ - '%s.%s' % ( + "%s.%s" + % ( self.quote_name(field.model._meta.db_table), self.quote_name(field.column), - ) for field in fields + ) + for field in fields ] - return 'RETURNING %s' % ', '.join(columns), () + return "RETURNING %s" % ", ".join(columns), () def bulk_insert_sql(self, fields, placeholder_rows): placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) @@ -252,7 +278,7 @@ class DatabaseOperations(BaseDatabaseOperations): return None def subtract_temporals(self, internal_type, lhs, rhs): - if internal_type == 'DateField': + if internal_type == "DateField": lhs_sql, lhs_params = lhs rhs_sql, rhs_params = rhs params = (*lhs_params, *rhs_params) @@ -263,27 +289,34 @@ class DatabaseOperations(BaseDatabaseOperations): prefix = super().explain_query_prefix(format) extra = {} if format: - extra['FORMAT'] = format + extra["FORMAT"] = format if options: - extra.update({ - name.upper(): 'true' if value else 'false' - for name, value in options.items() - }) + extra.update( + { + name.upper(): "true" if value else "false" + for name, value in options.items() + } + ) if extra: - prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items()) + prefix += " (%s)" % ", ".join("%s %s" % i for i in extra.items()) return prefix def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): if on_conflict == OnConflict.IGNORE: - return 'ON CONFLICT DO NOTHING' + return "ON CONFLICT DO NOTHING" if on_conflict == OnConflict.UPDATE: - return 'ON CONFLICT(%s) DO UPDATE SET %s' % ( - ', '.join(map(self.quote_name, unique_fields)), - ', '.join([ - f'{field} = EXCLUDED.{field}' - for field in map(self.quote_name, update_fields) - ]), + return "ON CONFLICT(%s) DO UPDATE SET %s" % ( + ", ".join(map(self.quote_name, unique_fields)), + ", ".join( + [ + f"{field} = EXCLUDED.{field}" + for field in map(self.quote_name, update_fields) + ] + ), ) return super().on_conflict_suffix_sql( - fields, on_conflict, update_fields, unique_fields, + fields, + on_conflict, + update_fields, + unique_fields, ) |