summaryrefslogtreecommitdiff
path: root/alembic
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2023-03-01 21:28:01 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2023-03-01 21:28:01 +0000
commit5df0ec12369db7f0bd1c5b273be21a6a79a5a895 (patch)
tree65e73e3a31c90a7ab11f647c476aa142a15103e5 /alembic
parent192e4de9964e11f7f4f04472af7a1e09b0356a1c (diff)
parent5c71fd120af824d289cf72dd4b5679a4f839a1eb (diff)
downloadalembic-5df0ec12369db7f0bd1c5b273be21a6a79a5a895.tar.gz
Merge "Improved support for expression indexes" into main
Diffstat (limited to 'alembic')
-rw-r--r--alembic/autogenerate/compare.py28
-rw-r--r--alembic/ddl/impl.py21
-rw-r--r--alembic/ddl/postgresql.py56
-rw-r--r--alembic/ddl/sqlite.py10
-rw-r--r--alembic/testing/requirements.py7
-rw-r--r--alembic/testing/schemacompare.py16
-rw-r--r--alembic/testing/util.py6
-rw-r--r--alembic/util/__init__.py1
-rw-r--r--alembic/util/sqla_compat.py16
9 files changed, 120 insertions, 41 deletions
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index c2181b8..f50c41c 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -17,6 +17,7 @@ from typing import Union
from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy import schema as sa_schema
+from sqlalchemy import text
from sqlalchemy import types as sqltypes
from sqlalchemy.util import OrderedSet
@@ -30,6 +31,7 @@ if TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
@@ -37,6 +39,7 @@ if TYPE_CHECKING:
from sqlalchemy.sql.schema import UniqueConstraint
from alembic.autogenerate.api import AutogenContext
+ from alembic.ddl.impl import DefaultImpl
from alembic.operations.ops import AlterColumnOp
from alembic.operations.ops import MigrationScript
from alembic.operations.ops import ModifyTableOps
@@ -276,14 +279,12 @@ def _compare_tables(
def _make_index(params: Dict[str, Any], conn_table: Table) -> Optional[Index]:
- exprs = []
- for col_name in params["column_names"]:
+ exprs: list[Union[Column[Any], TextClause]] = []
+ for num, col_name in enumerate(params["column_names"]):
+ item: Union[Column[Any], TextClause]
if col_name is None:
- util.warn(
- "Skipping reflected expression-based "
- f"index {params['name']!r}"
- )
- return None
+ assert "expressions" in params
+ item = text(params["expressions"][num])
else:
item = conn_table.c[col_name]
exprs.append(item)
@@ -439,10 +440,10 @@ class _uq_constraint_sig(_constraint_sig):
class _ix_constraint_sig(_constraint_sig):
is_index = True
- def __init__(self, const: Index) -> None:
+ def __init__(self, const: Index, impl: DefaultImpl) -> None:
self.const = const
self.name = const.name
- self.sig = tuple(sorted([col.name for col in const.columns]))
+ self.sig = impl.create_index_sig(const)
self.is_unique = bool(const.unique)
def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
@@ -624,11 +625,14 @@ def _compare_indexes_and_uniques(
_uq_constraint_sig(uq) for uq in metadata_unique_constraints
}
- metadata_indexes_sig = {_ix_constraint_sig(ix) for ix in metadata_indexes}
+ impl = autogen_context.migration_context.impl
+ metadata_indexes_sig = {
+ _ix_constraint_sig(ix, impl) for ix in metadata_indexes
+ }
conn_unique_constraints = {_uq_constraint_sig(uq) for uq in conn_uniques}
- conn_indexes_sig = {_ix_constraint_sig(ix) for ix in conn_indexes}
+ conn_indexes_sig = {_ix_constraint_sig(ix, impl) for ix in conn_indexes}
# 5. index things by name, for those objects that have names
metadata_names = {
@@ -815,7 +819,7 @@ def _compare_indexes_and_uniques(
)
if conn_obj.sig != metadata_obj.sig:
msg.append(
- " columns %r to %r" % (conn_obj.sig, metadata_obj.sig)
+ " expression %r to %r" % (conn_obj.sig, metadata_obj.sig)
)
if msg:
diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py
index 728d1da..f11d1ed 100644
--- a/alembic/ddl/impl.py
+++ b/alembic/ddl/impl.py
@@ -665,6 +665,27 @@ class DefaultImpl(metaclass=ImplMeta):
bool(diff) or bool(metadata_identity) != bool(inspector_identity),
)
+ def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
+ # order of col matters in an index
+ return tuple(col.name for col in index.columns)
+
+ def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
+ conn_indexes_by_name = {c.name: c for c in conn_indexes}
+
+ for idx in list(metadata_indexes):
+ if idx.name in conn_indexes_by_name:
+ continue
+ iex = sqla_compat.is_expression_index(idx)
+ if iex:
+ util.warn(
+ "autogenerate skipping metadata-specified "
+ "expression-based index "
+ f"{idx.name!r}; dialect {self.__dialect__!r} under "
+ f"SQLAlchemy {sqla_compat.sqlalchemy_version} can't "
+ "reflect these indexes so they can't be compared"
+ )
+ metadata_indexes.discard(idx)
+
def _compare_identity_options(
attributes, metadata_io, inspector_io, default_io
diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py
index e7c85bd..994d7cb 100644
--- a/alembic/ddl/postgresql.py
+++ b/alembic/ddl/postgresql.py
@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import Column
+from sqlalchemy import Index
from sqlalchemy import literal_column
from sqlalchemy import Numeric
from sqlalchemy import text
@@ -21,7 +22,6 @@ from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql.elements import ColumnClause
-from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.types import NULLTYPE
from .base import alter_column
@@ -235,8 +235,6 @@ class PostgresqlImpl(DefaultImpl):
metadata_indexes,
):
- conn_indexes_by_name = {c.name: c for c in conn_indexes}
-
doubled_constraints = {
index
for index in conn_indexes
@@ -246,23 +244,41 @@ class PostgresqlImpl(DefaultImpl):
for ix in doubled_constraints:
conn_indexes.remove(ix)
- for idx in list(metadata_indexes):
- if idx.name in conn_indexes_by_name:
- continue
- exprs = idx.expressions
- for expr in exprs:
- while isinstance(expr, UnaryExpression):
- expr = expr.element
- if not isinstance(expr, Column):
- if sqla_compat.sqla_2:
- msg = ""
- else:
- msg = "; not supported by SQLAlchemy reflection"
- util.warn(
- "autogenerate skipping functional index "
- f"{idx.name!r}{msg}"
- )
- metadata_indexes.discard(idx)
+ if not sqla_compat.sqla_2:
+ self._skip_functional_indexes(metadata_indexes, conn_indexes)
+
+ def _cleanup_index_expr(self, index: Index, expr: str) -> str:
+ # start = expr
+ expr = expr.lower()
+ expr = expr.replace('"', "")
+ if index.table is not None:
+ expr = expr.replace(f"{index.table.name.lower()}.", "")
+
+ while expr and expr[0] == "(" and expr[-1] == ")":
+ expr = expr[1:-1]
+ if "::" in expr:
+ # strip :: cast. types can have spaces in them
+ expr = re.sub(r"(::[\w ]+\w)", "", expr)
+
+ # print(f"START: {start} END: {expr}")
+ return expr
+
+ def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
+ if sqla_compat.is_expression_index(index):
+ return tuple(
+ self._cleanup_index_expr(
+ index,
+ e
+ if isinstance(e, str)
+ else e.compile(
+ dialect=self.dialect,
+ compile_kwargs={"literal_binds": True},
+ ).string,
+ )
+ for e in index.expressions
+ )
+ else:
+ return super().create_index_sig(index)
def render_type(
self, type_: TypeEngine, autogen_context: AutogenContext
diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py
index 6b8c293..302a877 100644
--- a/alembic/ddl/sqlite.py
+++ b/alembic/ddl/sqlite.py
@@ -186,6 +186,16 @@ class SQLiteImpl(DefaultImpl):
existing_transfer["expr"], new_type
)
+ def correct_for_autogen_constraints(
+ self,
+ conn_unique_constraints,
+ conn_indexes,
+ metadata_unique_constraints,
+ metadata_indexes,
+ ):
+
+ self._skip_functional_indexes(metadata_indexes, conn_indexes)
+
@compiles(RenameTable, "sqlite")
def visit_rename_table(
diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py
index 6896276..989c7cd 100644
--- a/alembic/testing/requirements.py
+++ b/alembic/testing/requirements.py
@@ -96,6 +96,13 @@ class SuiteRequirements(Requirements):
)
@property
+ def sqlalchemy_2(self):
+ return exclusions.skip_if(
+ lambda config: not util.sqla_2,
+ "SQLAlchemy 2.x test",
+ )
+
+ @property
def comments(self):
return exclusions.only_if(
lambda config: config.db.dialect.supports_comments
diff --git a/alembic/testing/schemacompare.py b/alembic/testing/schemacompare.py
index 4409421..c063499 100644
--- a/alembic/testing/schemacompare.py
+++ b/alembic/testing/schemacompare.py
@@ -43,15 +43,19 @@ class CompareColumn:
class CompareIndex:
- def __init__(self, index):
+ def __init__(self, index, name_only=False):
self.index = index
+ self.name_only = name_only
def __eq__(self, other):
- return (
- str(schema.CreateIndex(self.index))
- == str(schema.CreateIndex(other))
- and self.index.dialect_kwargs == other.dialect_kwargs
- )
+ if self.name_only:
+ return self.index.name == other.name
+ else:
+ return (
+ str(schema.CreateIndex(self.index))
+ == str(schema.CreateIndex(other))
+ and self.index.dialect_kwargs == other.dialect_kwargs
+ )
def __ne__(self, other):
return not self.__eq__(other)
diff --git a/alembic/testing/util.py b/alembic/testing/util.py
index a82690d..e65597d 100644
--- a/alembic/testing/util.py
+++ b/alembic/testing/util.py
@@ -10,6 +10,8 @@ import re
import types
from typing import Union
+from sqlalchemy.util import inspect_getfullargspec
+
def flag_combinations(*combinations):
"""A facade around @testing.combinations() oriented towards boolean
@@ -69,10 +71,12 @@ def resolve_lambda(__fn, **kw):
"""
+ pos_args = inspect_getfullargspec(__fn)[0]
+ pass_pos_args = {arg: kw.pop(arg) for arg in pos_args}
glb = dict(__fn.__globals__)
glb.update(kw)
new_fn = types.FunctionType(__fn.__code__, glb)
- return new_fn()
+ return new_fn(**pass_pos_args)
def metadata_fixture(ddl="function"):
diff --git a/alembic/util/__init__.py b/alembic/util/__init__.py
index 4374f46..c81d431 100644
--- a/alembic/util/__init__.py
+++ b/alembic/util/__init__.py
@@ -28,6 +28,7 @@ from .sqla_compat import has_computed
from .sqla_compat import sqla_13
from .sqla_compat import sqla_14
from .sqla_compat import sqla_1x
+from .sqla_compat import sqla_2
if not sqla_13:
diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py
index 9bdcfc3..738bbcb 100644
--- a/alembic/util/sqla_compat.py
+++ b/alembic/util/sqla_compat.py
@@ -23,8 +23,10 @@ from sqlalchemy.schema import Column
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.sql import visitors
from sqlalchemy.sql.elements import BindParameter
+from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import TextClause
+from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.sql.visitors import traverse
from typing_extensions import TypeGuard
@@ -38,7 +40,6 @@ if TYPE_CHECKING:
from sqlalchemy.sql.base import ColumnCollection
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.dml import Insert
- from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import SchemaItem
@@ -61,7 +62,8 @@ _vers = tuple(
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)
sqla_14_26 = _vers >= (1, 4, 26)
-sqla_2 = _vers >= (1, 5)
+sqla_2 = _vers >= (2,)
+sqlalchemy_version = __version__
if sqla_14:
@@ -573,3 +575,13 @@ else:
def _select(*columns, **kw) -> Select: # type: ignore[no-redef]
return sql.select(list(columns), **kw) # type: ignore[call-overload]
+
+
+def is_expression_index(index: Index) -> bool:
+ expr: Any
+ for expr in index.expressions:
+ while isinstance(expr, UnaryExpression):
+ expr = expr.element
+ if not isinstance(expr, ColumnClause) or expr.is_literal:
+ return True
+ return False