summaryrefslogtreecommitdiff
path: root/django/db/backends/postgresql/schema.py
diff options
context:
space:
mode:
authorFlorian Apolloner <florian@apolloner.eu>2022-03-24 16:46:19 +0100
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-04-13 21:51:51 +0200
commit2eea361eff58dd98c409c5227064b901f41bd0d6 (patch)
treeb551c3b45c3d0f133d88ef346ec1657f937d2892 /django/db/backends/postgresql/schema.py
parent62ffc9883afdc0a9f9674702661062508230d7bf (diff)
downloaddjango-2eea361eff58dd98c409c5227064b901f41bd0d6.tar.gz
Fixed #30511 -- Used identity columns instead of serials on PostgreSQL.
Diffstat (limited to 'django/db/backends/postgresql/schema.py')
-rw-r--r--django/db/backends/postgresql/schema.py86
1 files changed, 30 insertions, 56 deletions
diff --git a/django/db/backends/postgresql/schema.py b/django/db/backends/postgresql/schema.py
index 73e2749020..3053c8d370 100644
--- a/django/db/backends/postgresql/schema.py
+++ b/django/db/backends/postgresql/schema.py
@@ -7,12 +7,7 @@ from django.db.backends.utils import strip_quotes
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
- sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
- sql_set_sequence_max = (
- "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
- )
- sql_set_sequence_owner = "ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s"
sql_create_index = (
"CREATE INDEX %(name)s ON %(table)s%(using)s "
@@ -39,6 +34,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
)
sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"
+ sql_add_identity = (
+ "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
+ "GENERATED BY DEFAULT AS IDENTITY"
+ )
+ sql_drop_indentity = (
+ "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
+ )
+
def quote_value(self, value):
if isinstance(value, str):
value = value.replace("%", "%%")
@@ -116,78 +119,47 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
self.sql_alter_column_type += using_sql
elif self._field_data_type(old_field) != self._field_data_type(new_field):
self.sql_alter_column_type += using_sql
- # Make ALTER TYPE with SERIAL make sense.
+ # Make ALTER TYPE with IDENTITY make sense.
table = strip_quotes(model._meta.db_table)
- serial_fields_map = {
- "bigserial": "bigint",
- "serial": "integer",
- "smallserial": "smallint",
+ auto_field_types = {
+ "AutoField",
+ "BigAutoField",
+ "SmallAutoField",
}
- if new_type.lower() in serial_fields_map:
+ old_is_auto = old_internal_type in auto_field_types
+ new_is_auto = new_internal_type in auto_field_types
+ if new_is_auto and not old_is_auto:
column = strip_quotes(new_field.column)
- sequence_name = "%s_%s_seq" % (table, column)
return (
(
self.sql_alter_column_type
% {
"column": self.quote_name(column),
- "type": serial_fields_map[new_type.lower()],
+ "type": new_type,
},
[],
),
[
(
- self.sql_delete_sequence
- % {
- "sequence": self.quote_name(sequence_name),
- },
- [],
- ),
- (
- self.sql_create_sequence
- % {
- "sequence": self.quote_name(sequence_name),
- },
- [],
- ),
- (
- self.sql_alter_column
- % {
- "table": self.quote_name(table),
- "changes": self.sql_alter_column_default
- % {
- "column": self.quote_name(column),
- "default": "nextval('%s')"
- % self.quote_name(sequence_name),
- },
- },
- [],
- ),
- (
- self.sql_set_sequence_max
+ self.sql_add_identity
% {
"table": self.quote_name(table),
"column": self.quote_name(column),
- "sequence": self.quote_name(sequence_name),
- },
- [],
- ),
- (
- self.sql_set_sequence_owner
- % {
- "table": self.quote_name(table),
- "column": self.quote_name(column),
- "sequence": self.quote_name(sequence_name),
},
[],
),
],
)
- elif (
- old_field.db_parameters(connection=self.connection)["type"]
- in serial_fields_map
- ):
- # Drop the sequence if migrating away from AutoField.
+ elif old_is_auto and not new_is_auto:
+ # Drop IDENTITY if exists (pre-Django 4.1 serial columns don't have
+ # it).
+ self.execute(
+ self.sql_drop_indentity
+ % {
+ "table": self.quote_name(table),
+ "column": self.quote_name(strip_quotes(old_field.column)),
+ }
+ )
column = strip_quotes(new_field.column)
sequence_name = "%s_%s_seq" % (table, column)
fragment, _ = super()._alter_column_type_sql(
@@ -195,6 +167,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
)
return fragment, [
(
+ # Drop the sequence if exists (Django 4.1+ identity columns
+ # don't have it).
self.sql_delete_sequence
% {
"sequence": self.quote_name(sequence_name),