diff options
Diffstat (limited to 'alembic')
-rw-r--r-- | alembic/autogenerate/compare.py | 28 | ||||
-rw-r--r-- | alembic/ddl/impl.py | 21 | ||||
-rw-r--r-- | alembic/ddl/postgresql.py | 56 | ||||
-rw-r--r-- | alembic/ddl/sqlite.py | 10 | ||||
-rw-r--r-- | alembic/testing/requirements.py | 7 | ||||
-rw-r--r-- | alembic/testing/schemacompare.py | 16 | ||||
-rw-r--r-- | alembic/testing/util.py | 6 | ||||
-rw-r--r-- | alembic/util/__init__.py | 1 | ||||
-rw-r--r-- | alembic/util/sqla_compat.py | 16 |
9 files changed, 120 insertions, 41 deletions
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index c2181b8..f50c41c 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -17,6 +17,7 @@ from typing import Union from sqlalchemy import event from sqlalchemy import inspect from sqlalchemy import schema as sa_schema +from sqlalchemy import text from sqlalchemy import types as sqltypes from sqlalchemy.util import OrderedSet @@ -30,6 +31,7 @@ if TYPE_CHECKING: from sqlalchemy.engine.reflection import Inspector from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import ForeignKeyConstraint from sqlalchemy.sql.schema import Index @@ -37,6 +39,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.schema import UniqueConstraint from alembic.autogenerate.api import AutogenContext + from alembic.ddl.impl import DefaultImpl from alembic.operations.ops import AlterColumnOp from alembic.operations.ops import MigrationScript from alembic.operations.ops import ModifyTableOps @@ -276,14 +279,12 @@ def _compare_tables( def _make_index(params: Dict[str, Any], conn_table: Table) -> Optional[Index]: - exprs = [] - for col_name in params["column_names"]: + exprs: list[Union[Column[Any], TextClause]] = [] + for num, col_name in enumerate(params["column_names"]): + item: Union[Column[Any], TextClause] if col_name is None: - util.warn( - "Skipping reflected expression-based " - f"index {params['name']!r}" - ) - return None + assert "expressions" in params + item = text(params["expressions"][num]) else: item = conn_table.c[col_name] exprs.append(item) @@ -439,10 +440,10 @@ class _uq_constraint_sig(_constraint_sig): class _ix_constraint_sig(_constraint_sig): is_index = True - def __init__(self, const: Index) -> None: + def __init__(self, const: Index, impl: DefaultImpl) -> None: self.const = const self.name = const.name - self.sig = tuple(sorted([col.name for col in const.columns])) + self.sig = impl.create_index_sig(const) self.is_unique = bool(const.unique) def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]: @@ -624,11 +625,14 @@ def _compare_indexes_and_uniques( _uq_constraint_sig(uq) for uq in metadata_unique_constraints } - metadata_indexes_sig = {_ix_constraint_sig(ix) for ix in metadata_indexes} + impl = autogen_context.migration_context.impl + metadata_indexes_sig = { + _ix_constraint_sig(ix, impl) for ix in metadata_indexes + } conn_unique_constraints = {_uq_constraint_sig(uq) for uq in conn_uniques} - conn_indexes_sig = {_ix_constraint_sig(ix) for ix in conn_indexes} + conn_indexes_sig = {_ix_constraint_sig(ix, impl) for ix in conn_indexes} # 5. index things by name, for those objects that have names metadata_names = { @@ -815,7 +819,7 @@ def _compare_indexes_and_uniques( ) if conn_obj.sig != metadata_obj.sig: msg.append( - " columns %r to %r" % (conn_obj.sig, metadata_obj.sig) + " expression %r to %r" % (conn_obj.sig, metadata_obj.sig) ) if msg: diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 728d1da..f11d1ed 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -665,6 +665,27 @@ class DefaultImpl(metaclass=ImplMeta): bool(diff) or bool(metadata_identity) != bool(inspector_identity), ) + def create_index_sig(self, index: Index) -> Tuple[Any, ...]: + # order of col matters in an index + return tuple(col.name for col in index.columns) + + def _skip_functional_indexes(self, metadata_indexes, conn_indexes): + conn_indexes_by_name = {c.name: c for c in conn_indexes} + + for idx in list(metadata_indexes): + if idx.name in conn_indexes_by_name: + continue + iex = sqla_compat.is_expression_index(idx) + if iex: + util.warn( + "autogenerate skipping metadata-specified " + "expression-based index " + f"{idx.name!r}; dialect {self.__dialect__!r} under " + f"SQLAlchemy {sqla_compat.sqlalchemy_version} can't " + "reflect these indexes so they can't be compared" + ) + metadata_indexes.discard(idx) + def _compare_identity_options( attributes, metadata_io, inspector_io, default_io diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index e7c85bd..994d7cb 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING from typing import Union from sqlalchemy import Column +from sqlalchemy import Index from sqlalchemy import literal_column from sqlalchemy import Numeric from sqlalchemy import text @@ -21,7 +22,6 @@ from sqlalchemy.dialects.postgresql import ExcludeConstraint from sqlalchemy.dialects.postgresql import INTEGER from sqlalchemy.schema import CreateIndex from sqlalchemy.sql.elements import ColumnClause -from sqlalchemy.sql.elements import UnaryExpression from sqlalchemy.types import NULLTYPE from .base import alter_column @@ -235,8 +235,6 @@ class PostgresqlImpl(DefaultImpl): metadata_indexes, ): - conn_indexes_by_name = {c.name: c for c in conn_indexes} - doubled_constraints = { index for index in conn_indexes @@ -246,23 +244,41 @@ class PostgresqlImpl(DefaultImpl): for ix in doubled_constraints: conn_indexes.remove(ix) - for idx in list(metadata_indexes): - if idx.name in conn_indexes_by_name: - continue - exprs = idx.expressions - for expr in exprs: - while isinstance(expr, UnaryExpression): - expr = expr.element - if not isinstance(expr, Column): - if sqla_compat.sqla_2: - msg = "" - else: - msg = "; not supported by SQLAlchemy reflection" - util.warn( - "autogenerate skipping functional index " - f"{idx.name!r}{msg}" - ) - metadata_indexes.discard(idx) + if not sqla_compat.sqla_2: + self._skip_functional_indexes(metadata_indexes, conn_indexes) + + def _cleanup_index_expr(self, index: Index, expr: str) -> str: + # start = expr + expr = expr.lower() + expr = expr.replace('"', "") + if index.table is not None: + expr = expr.replace(f"{index.table.name.lower()}.", "") + + while expr and expr[0] == "(" and expr[-1] == ")": + expr = expr[1:-1] + if "::" in expr: + # strip :: cast. types can have spaces in them + expr = re.sub(r"(::[\w ]+\w)", "", expr) + + # print(f"START: {start} END: {expr}") + return expr + + def create_index_sig(self, index: Index) -> Tuple[Any, ...]: + if sqla_compat.is_expression_index(index): + return tuple( + self._cleanup_index_expr( + index, + e + if isinstance(e, str) + else e.compile( + dialect=self.dialect, + compile_kwargs={"literal_binds": True}, + ).string, + ) + for e in index.expressions + ) + else: + return super().create_index_sig(index) def render_type( self, type_: TypeEngine, autogen_context: AutogenContext diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py index 6b8c293..302a877 100644 --- a/alembic/ddl/sqlite.py +++ b/alembic/ddl/sqlite.py @@ -186,6 +186,16 @@ class SQLiteImpl(DefaultImpl): existing_transfer["expr"], new_type ) + def correct_for_autogen_constraints( + self, + conn_unique_constraints, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + ): + + self._skip_functional_indexes(metadata_indexes, conn_indexes) + @compiles(RenameTable, "sqlite") def visit_rename_table( diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py index 6896276..989c7cd 100644 --- a/alembic/testing/requirements.py +++ b/alembic/testing/requirements.py @@ -96,6 +96,13 @@ class SuiteRequirements(Requirements): ) @property + def sqlalchemy_2(self): + return exclusions.skip_if( + lambda config: not util.sqla_2, + "SQLAlchemy 2.x test", + ) + + @property def comments(self): return exclusions.only_if( lambda config: config.db.dialect.supports_comments diff --git a/alembic/testing/schemacompare.py b/alembic/testing/schemacompare.py index 4409421..c063499 100644 --- a/alembic/testing/schemacompare.py +++ b/alembic/testing/schemacompare.py @@ -43,15 +43,19 @@ class CompareColumn: class CompareIndex: - def __init__(self, index): + def __init__(self, index, name_only=False): self.index = index + self.name_only = name_only def __eq__(self, other): - return ( - str(schema.CreateIndex(self.index)) - == str(schema.CreateIndex(other)) - and self.index.dialect_kwargs == other.dialect_kwargs - ) + if self.name_only: + return self.index.name == other.name + else: + return ( + str(schema.CreateIndex(self.index)) + == str(schema.CreateIndex(other)) + and self.index.dialect_kwargs == other.dialect_kwargs + ) def __ne__(self, other): return not self.__eq__(other) diff --git a/alembic/testing/util.py b/alembic/testing/util.py index a82690d..e65597d 100644 --- a/alembic/testing/util.py +++ b/alembic/testing/util.py @@ -10,6 +10,8 @@ import re import types from typing import Union +from sqlalchemy.util import inspect_getfullargspec + def flag_combinations(*combinations): """A facade around @testing.combinations() oriented towards boolean @@ -69,10 +71,12 @@ def resolve_lambda(__fn, **kw): """ + pos_args = inspect_getfullargspec(__fn)[0] + pass_pos_args = {arg: kw.pop(arg) for arg in pos_args} glb = dict(__fn.__globals__) glb.update(kw) new_fn = types.FunctionType(__fn.__code__, glb) - return new_fn() + return new_fn(**pass_pos_args) def metadata_fixture(ddl="function"): diff --git a/alembic/util/__init__.py b/alembic/util/__init__.py index 4374f46..c81d431 100644 --- a/alembic/util/__init__.py +++ b/alembic/util/__init__.py @@ -28,6 +28,7 @@ from .sqla_compat import has_computed from .sqla_compat import sqla_13 from .sqla_compat import sqla_14 from .sqla_compat import sqla_1x +from .sqla_compat import sqla_2 if not sqla_13: diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index 9bdcfc3..738bbcb 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -23,8 +23,10 @@ from sqlalchemy.schema import Column from sqlalchemy.schema import ForeignKeyConstraint from sqlalchemy.sql import visitors from sqlalchemy.sql.elements import BindParameter +from sqlalchemy.sql.elements import ColumnClause from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.elements import TextClause +from sqlalchemy.sql.elements import UnaryExpression from sqlalchemy.sql.visitors import traverse from typing_extensions import TypeGuard @@ -38,7 +40,6 @@ if TYPE_CHECKING: from sqlalchemy.sql.base import ColumnCollection from sqlalchemy.sql.compiler import SQLCompiler from sqlalchemy.sql.dml import Insert - from sqlalchemy.sql.elements import ColumnClause from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.schema import SchemaItem @@ -61,7 +62,8 @@ _vers = tuple( sqla_13 = _vers >= (1, 3) sqla_14 = _vers >= (1, 4) sqla_14_26 = _vers >= (1, 4, 26) -sqla_2 = _vers >= (1, 5) +sqla_2 = _vers >= (2,) +sqlalchemy_version = __version__ if sqla_14: @@ -573,3 +575,13 @@ else: def _select(*columns, **kw) -> Select: # type: ignore[no-redef] return sql.select(list(columns), **kw) # type: ignore[call-overload] + + +def is_expression_index(index: Index) -> bool: + expr: Any + for expr in index.expressions: + while isinstance(expr, UnaryExpression): + expr = expr.element + if not isinstance(expr, ColumnClause) or expr.is_literal: + return True + return False |