summaryrefslogtreecommitdiff
path: root/alembic
diff options
context:
space:
mode:
authorCaselIT <cfederico87@gmail.com>2023-04-06 22:16:41 +0200
committerFederico Caselli <cfederico87@gmail.com>2023-04-10 21:58:17 +0200
commit157c521736f1c9cfceb9b3a6ecf17f782d358c46 (patch)
tree8c38a3be951bee504a51625dc0ca01d0fb888f4c /alembic
parent3d9b1128cd6bf03ecb45003587c0eedfb9552b07 (diff)
downloadalembic-157c521736f1c9cfceb9b3a6ecf17f782d358c46.tar.gz
Use column sort in index compare on postgresql
Added support for autogenerate comparison of indexes on PostgreSQL which include SQL sort option, such as ``ASC`` or ``NULLS FIRST``. Fixes: #1213 Change-Id: I3ddcb647928d948e41462b1c889b1cbb515ace4f
Diffstat (limited to 'alembic')
-rw-r--r--alembic/autogenerate/compare.py24
-rw-r--r--alembic/ddl/postgresql.py73
2 files changed, 82 insertions, 15 deletions
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index 4f5126f..85cb426 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -8,6 +8,7 @@ from typing import cast
from typing import Dict
from typing import Iterator
from typing import List
+from typing import Mapping
from typing import Optional
from typing import Set
from typing import Tuple
@@ -19,6 +20,7 @@ from sqlalchemy import inspect
from sqlalchemy import schema as sa_schema
from sqlalchemy import text
from sqlalchemy import types as sqltypes
+from sqlalchemy.sql import expression
from sqlalchemy.util import OrderedSet
from alembic.ddl.base import _fk_spec
@@ -278,15 +280,35 @@ def _compare_tables(
upgrade_ops.ops.append(modify_table_ops)
+_IndexColumnSortingOps: Mapping[str, Any] = util.immutabledict(
+ {
+ "asc": expression.asc,
+ "desc": expression.desc,
+ "nulls_first": expression.nullsfirst,
+ "nulls_last": expression.nullslast,
+ "nullsfirst": expression.nullsfirst, # 1_3 name
+ "nullslast": expression.nullslast, # 1_3 name
+ }
+)
+
+
def _make_index(params: Dict[str, Any], conn_table: Table) -> Optional[Index]:
exprs: list[Union[Column[Any], TextClause]] = []
+ sorting = params.get("column_sorting")
+
for num, col_name in enumerate(params["column_names"]):
item: Union[Column[Any], TextClause]
if col_name is None:
assert "expressions" in params
- item = text(params["expressions"][num])
+ name = params["expressions"][num]
+ item = text(name)
else:
+ name = col_name
item = conn_table.c[col_name]
+ if sorting and name in sorting:
+ for operator in sorting[name]:
+ if operator in _IndexColumnSortingOps:
+ item = _IndexColumnSortingOps[operator](item)
exprs.append(item)
ix = sa_schema.Index(
params["name"], *exprs, unique=params["unique"], _table=conn_table
diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py
index 4ffc2eb..247838b 100644
--- a/alembic/ddl/postgresql.py
+++ b/alembic/ddl/postgresql.py
@@ -21,8 +21,10 @@ from sqlalchemy.dialects.postgresql import BIGINT
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.schema import CreateIndex
+from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import TextClause
+from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.types import NULLTYPE
from .base import alter_column
@@ -53,6 +55,7 @@ if TYPE_CHECKING:
from sqlalchemy.dialects.postgresql.json import JSON
from sqlalchemy.dialects.postgresql.json import JSONB
from sqlalchemy.sql.elements import BinaryExpression
+ from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import Table
@@ -248,11 +251,14 @@ class PostgresqlImpl(DefaultImpl):
if not sqla_compat.sqla_2:
self._skip_functional_indexes(metadata_indexes, conn_indexes)
- def _cleanup_index_expr(self, index: Index, expr: str) -> str:
+ def _cleanup_index_expr(
+ self, index: Index, expr: str, remove_suffix: str
+ ) -> str:
# start = expr
expr = expr.lower()
expr = expr.replace('"', "")
if index.table is not None:
+ # should not be needed, since include_table=False is in compile
expr = expr.replace(f"{index.table.name.lower()}.", "")
while expr and expr[0] == "(" and expr[-1] == ")":
@@ -261,25 +267,64 @@ class PostgresqlImpl(DefaultImpl):
# strip :: cast. types can have spaces in them
expr = re.sub(r"(::[\w ]+\w)", "", expr)
+ if remove_suffix and expr.endswith(remove_suffix):
+ expr = expr[: -len(remove_suffix)]
+
# print(f"START: {start} END: {expr}")
return expr
+ def _default_modifiers(self, exp: ClauseElement) -> str:
+ to_remove = ""
+ while isinstance(exp, UnaryExpression):
+ if exp.modifier is None:
+ exp = exp.element
+ else:
+ op = exp.modifier
+ if isinstance(exp.element, UnaryExpression):
+ inner_op = exp.element.modifier
+ else:
+ inner_op = None
+ if inner_op is None:
+ if op == operators.asc_op:
+ # default is asc
+ to_remove = " asc"
+ elif op == operators.nullslast_op:
+ # default is nulls last
+ to_remove = " nulls last"
+ else:
+ if (
+ inner_op == operators.asc_op
+ and op == operators.nullslast_op
+ ):
+ # default is asc nulls last
+ to_remove = " asc nulls last"
+ elif (
+ inner_op == operators.desc_op
+ and op == operators.nullsfirst_op
+ ):
+ # default for desc is nulls first
+ to_remove = " nulls first"
+ break
+ return to_remove
+
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
- if sqla_compat.is_expression_index(index):
- return tuple(
- self._cleanup_index_expr(
- index,
- e
+ return tuple(
+ self._cleanup_index_expr(
+ index,
+ *(
+ (e, "")
if isinstance(e, str)
- else e.compile(
- dialect=self.dialect,
- compile_kwargs={"literal_binds": True},
- ).string,
- )
- for e in index.expressions
+ else (self._compile_element(e), self._default_modifiers(e))
+ ),
)
- else:
- return super().create_index_sig(index)
+ for e in index.expressions
+ )
+
+ def _compile_element(self, element: ClauseElement) -> str:
+ return element.compile(
+ dialect=self.dialect,
+ compile_kwargs={"literal_binds": True, "include_table": False},
+ ).string
def render_type(
self, type_: TypeEngine, autogen_context: AutogenContext