summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--alembic/operations/batch.py22
-rw-r--r--alembic/util/compat.py4
-rw-r--r--alembic/util/sqla_compat.py6
-rw-r--r--docs/build/unreleased/1034.rst6
-rw-r--r--tests/test_batch.py32
5 files changed, 60 insertions, 10 deletions
diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py
index 308bc2e..7c7de9f 100644
--- a/alembic/operations/batch.py
+++ b/alembic/operations/batch.py
@@ -24,8 +24,10 @@ from sqlalchemy.util import topological
from ..util import exc
from ..util.sqla_compat import _columns_for_constraint
from ..util.sqla_compat import _copy
+from ..util.sqla_compat import _copy_expression
from ..util.sqla_compat import _ensure_scope_for_ddl
from ..util.sqla_compat import _fk_is_self_referential
+from ..util.sqla_compat import _idx_table_bound_expressions
from ..util.sqla_compat import _insert_inline
from ..util.sqla_compat import _is_type_bound
from ..util.sqla_compat import _remove_column_from_collection
@@ -354,7 +356,25 @@ class ApplyBatchImpl:
def _gather_indexes_from_both_tables(self) -> List["Index"]:
assert self.new_table is not None
idx: List[Index] = []
- idx.extend(self.indexes.values())
+
+ for idx_existing in self.indexes.values():
+ # this is a lift-and-move from Table.to_metadata
+
+ if idx_existing._column_flag: # type: ignore
+ continue
+
+ idx_copy = Index(
+ idx_existing.name,
+ unique=idx_existing.unique,
+ *[
+ _copy_expression(expr, self.new_table)
+ for expr in _idx_table_bound_expressions(idx_existing)
+ ],
+ _table=self.new_table,
+ **idx_existing.kwargs,
+ )
+ idx.append(idx_copy)
+
for index in self.new_indexes.values():
idx.append(
Index(
diff --git a/alembic/util/compat.py b/alembic/util/compat.py
index cabff6e..289aaa2 100644
--- a/alembic/util/compat.py
+++ b/alembic/util/compat.py
@@ -35,9 +35,9 @@ else:
def importlib_metadata_get(group: str) -> Sequence[EntryPoint]:
ep = importlib_metadata.entry_points()
if hasattr(ep, "select"):
- return ep.select(group=group) # type:ignore[attr-defined]
+ return ep.select(group=group) # type: ignore
else:
- return ep.get(group, ())
+ return ep.get(group, ()) # type: ignore
def formatannotation_fwdref(annotation, base_module=None):
diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py
index 21a9f7f..9c06e85 100644
--- a/alembic/util/sqla_compat.py
+++ b/alembic/util/sqla_compat.py
@@ -2,6 +2,8 @@ from __future__ import annotations
import contextlib
import re
+from typing import Any
+from typing import Iterable
from typing import Iterator
from typing import Mapping
from typing import Optional
@@ -158,6 +160,10 @@ def _get_connection_in_transaction(connection: Optional["Connection"]) -> bool:
return in_transaction()
+def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]:
+ return idx.expressions # type: ignore
+
+
def _copy(schema_item: _CE, **kw) -> _CE:
if hasattr(schema_item, "_copy"):
return schema_item._copy(**kw) # type: ignore[union-attr]
diff --git a/docs/build/unreleased/1034.rst b/docs/build/unreleased/1034.rst
new file mode 100644
index 0000000..558c1ef
--- /dev/null
+++ b/docs/build/unreleased/1034.rst
@@ -0,0 +1,6 @@
+.. change::
+ :tags: bug, batch
+ :tickets: 1034
+
+ Fixed issue in batch mode where CREATE INDEX would not use a new column
+ name in the case of a column rename.
diff --git a/tests/test_batch.py b/tests/test_batch.py
index f0bbd75..b19fa98 100644
--- a/tests/test_batch.py
+++ b/tests/test_batch.py
@@ -281,16 +281,20 @@ class BatchApplyTest(TestBase):
create_stmt = re.sub(r"[\n\t]", "", create_stmt)
idx_stmt = ""
- for idx in impl.indexes.values():
- idx_stmt += str(CreateIndex(idx).compile(dialect=context.dialect))
- for idx in impl.new_indexes.values():
- impl.new_table.name = impl.table.name
+
+ # create indexes; these should be created in terms of the
+ # final table name
+ impl.new_table.name = impl.table.name
+
+ for idx in impl._gather_indexes_from_both_tables():
idx_stmt += str(CreateIndex(idx).compile(dialect=context.dialect))
- impl.new_table.name = ApplyBatchImpl._calc_temp_name(
- impl.table.name
- )
+
idx_stmt = re.sub(r"[\n\t]", "", idx_stmt)
+ # revert new table name to the temp name, assertions below
+ # are looking for the temp name
+ impl.new_table.name = ApplyBatchImpl._calc_temp_name(impl.table.name)
+
if ddl_contains:
assert ddl_contains in create_stmt + idx_stmt
if ddl_not_contains:
@@ -357,6 +361,20 @@ class BatchApplyTest(TestBase):
new_table = self._assert_impl(impl)
eq_(new_table.c.x.name, "q")
+ def test_rename_col_w_index(self):
+ impl = self._ix_fixture()
+ impl.alter_column("tname", "y", name="y2")
+ new_table = self._assert_impl(
+ impl, ddl_contains="CREATE INDEX ix1 ON tname (y2)"
+ )
+ eq_(new_table.c.y.name, "y2")
+
+ def test_rename_col_w_uq(self):
+ impl = self._uq_fixture()
+ impl.alter_column("tname", "y", name="y2")
+ new_table = self._assert_impl(impl, ddl_contains="UNIQUE (y2)")
+ eq_(new_table.c.y.name, "y2")
+
def test_alter_column_comment(self):
impl = self._simple_fixture()
impl.alter_column("tname", "x", comment="some comment")