summaryrefslogtreecommitdiff
path: root/alembic/autogenerate/compare.py
diff options
context:
space:
mode:
Diffstat (limited to 'alembic/autogenerate/compare.py')
-rw-r--r--alembic/autogenerate/compare.py28
1 files changed, 16 insertions, 12 deletions
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index 828a4cd..d97eb9c 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -17,6 +17,7 @@ from typing import Union
from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy import schema as sa_schema
+from sqlalchemy import text
from sqlalchemy import types as sqltypes
from sqlalchemy.util import OrderedSet
@@ -30,6 +31,7 @@ if TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
@@ -37,6 +39,7 @@ if TYPE_CHECKING:
from sqlalchemy.sql.schema import UniqueConstraint
from alembic.autogenerate.api import AutogenContext
+ from alembic.ddl.impl import DefaultImpl
from alembic.operations.ops import AlterColumnOp
from alembic.operations.ops import MigrationScript
from alembic.operations.ops import ModifyTableOps
@@ -276,14 +279,12 @@ def _compare_tables(
def _make_index(params: Dict[str, Any], conn_table: Table) -> Optional[Index]:
- exprs = []
- for col_name in params["column_names"]:
+ exprs: list[Union[Column[Any], TextClause]] = []
+ for num, col_name in enumerate(params["column_names"]):
+ item: Union[Column[Any], TextClause]
if col_name is None:
- util.warn(
- "Skipping reflected expression-based "
- f"index {params['name']!r}"
- )
- return None
+ assert "expressions" in params
+ item = text(params["expressions"][num])
else:
item = conn_table.c[col_name]
exprs.append(item)
@@ -439,10 +440,10 @@ class _uq_constraint_sig(_constraint_sig):
class _ix_constraint_sig(_constraint_sig):
is_index = True
- def __init__(self, const: Index) -> None:
+ def __init__(self, const: Index, impl: DefaultImpl) -> None:
self.const = const
self.name = const.name
- self.sig = tuple(sorted([col.name for col in const.columns]))
+ self.sig = impl.create_index_sig(const)
self.is_unique = bool(const.unique)
def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
@@ -624,11 +625,14 @@ def _compare_indexes_and_uniques(
_uq_constraint_sig(uq) for uq in metadata_unique_constraints
}
- metadata_indexes_sig = {_ix_constraint_sig(ix) for ix in metadata_indexes}
+ impl = autogen_context.migration_context.impl
+ metadata_indexes_sig = {
+ _ix_constraint_sig(ix, impl) for ix in metadata_indexes
+ }
conn_unique_constraints = {_uq_constraint_sig(uq) for uq in conn_uniques}
- conn_indexes_sig = {_ix_constraint_sig(ix) for ix in conn_indexes}
+ conn_indexes_sig = {_ix_constraint_sig(ix, impl) for ix in conn_indexes}
# 5. index things by name, for those objects that have names
metadata_names = {
@@ -816,7 +820,7 @@ def _compare_indexes_and_uniques(
)
if conn_obj.sig != metadata_obj.sig:
msg.append(
- " columns %r to %r" % (conn_obj.sig, metadata_obj.sig)
+ " expression %r to %r" % (conn_obj.sig, metadata_obj.sig)
)
if msg: