diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 96 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/pg_catalog.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/interfaces.py | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/reflection.py | 77 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/__init__.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 31 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_reflection.py | 58 |
8 files changed, 201 insertions, 82 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 20903b55f..8b89cdee2 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -4060,11 +4060,23 @@ class PGDialect(default.DefaultDialect): select( idx_sq.c.indexrelid, idx_sq.c.indrelid, - pg_catalog.pg_attribute.c.attname, + # NOTE: always using pg_get_indexdef is too slow so just + # invoke when the element is an expression + sql.case( + ( + idx_sq.c.attnum == 0, + pg_catalog.pg_get_indexdef( + idx_sq.c.indexrelid, idx_sq.c.ord + 1, True + ), + ), + else_=pg_catalog.pg_attribute.c.attname, + ).label("element"), + (idx_sq.c.attnum == 0).label("is_expr"), ) - .select_from(pg_catalog.pg_attribute) - .join( - idx_sq, + .select_from(idx_sq) + .outerjoin( + # do not remove rows where idx_sq.c.attnum is 0 + pg_catalog.pg_attribute, sql.and_( pg_catalog.pg_attribute.c.attnum == idx_sq.c.attnum, pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid, @@ -4079,7 +4091,10 @@ class PGDialect(default.DefaultDialect): select( attr_sq.c.indexrelid, attr_sq.c.indrelid, - sql.func.array_agg(attr_sq.c.attname).label("cols"), + sql.func.array_agg(attr_sq.c.element).label("elements"), + sql.func.array_agg(attr_sq.c.is_expr).label( + "elements_is_expr" + ), ) .group_by(attr_sq.c.indexrelid, attr_sq.c.indrelid) .subquery("idx_cols") @@ -4095,19 +4110,27 @@ class PGDialect(default.DefaultDialect): pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname.label("relname_index"), pg_catalog.pg_index.c.indisunique, - pg_catalog.pg_index.c.indexprs, pg_catalog.pg_constraint.c.conrelid.is_not(None).label( "has_constraint" ), pg_catalog.pg_index.c.indoption, pg_class_index.c.reloptions, pg_catalog.pg_am.c.amname, - pg_catalog.pg_get_expr( - pg_catalog.pg_index.c.indpred, - pg_catalog.pg_index.c.indrelid, + sql.case( + # pg_get_expr is very fast so this case has almost no + # performance impact + ( + pg_catalog.pg_index.c.indpred.is_not(None), + pg_catalog.pg_get_expr( + pg_catalog.pg_index.c.indpred, + pg_catalog.pg_index.c.indrelid, + ), + ), + else_=sql.null(), ).label("filter_definition"), indnkeyatts, - cols_sq.c.cols.label("index_cols"), + cols_sq.c.elements, + cols_sq.c.elements_is_expr, ) .select_from(pg_catalog.pg_index) .where( @@ -4178,38 +4201,43 @@ class PGDialect(default.DefaultDialect): table_indexes = indexes[(schema, table_name)] - if row["indexprs"]: - tn = ( - table_name - if schema is None - else f"{schema}.{table_name}" - ) - util.warn( - "Skipped unsupported reflection of " - f"expression-based index {index_name} of " - f"table {tn}" - ) - continue - - all_cols = row["index_cols"] + all_elements = row["elements"] + all_elements_is_expr = row["elements_is_expr"] indnkeyatts = row["indnkeyatts"] # "The number of key columns in the index, not counting any # included columns, which are merely stored and do not # participate in the index semantics" - if indnkeyatts and all_cols[indnkeyatts:]: + if indnkeyatts and len(all_elements) > indnkeyatts: # this is a "covering index" which has INCLUDE columns # as well as regular index columns - inc_cols = all_cols[indnkeyatts:] - idx_cols = all_cols[:indnkeyatts] + inc_cols = all_elements[indnkeyatts:] + idx_elements = all_elements[:indnkeyatts] + idx_elements_is_expr = all_elements_is_expr[ + :indnkeyatts + ] + # postgresql does not support expression on included + # columns as of v14: "ERROR: expressions are not + # supported in included columns". + assert all( + not is_expr + for is_expr in all_elements_is_expr[indnkeyatts:] + ) else: - idx_cols = all_cols + idx_elements = all_elements + idx_elements_is_expr = all_elements_is_expr inc_cols = [] - index = { - "name": index_name, - "unique": row["indisunique"], - "column_names": idx_cols, - } + index = {"name": index_name, "unique": row["indisunique"]} + if any(idx_elements_is_expr): + index["column_names"] = [ + None if is_expr else expr + for expr, is_expr in zip( + idx_elements, idx_elements_is_expr + ) + ] + index["expressions"] = idx_elements + else: + index["column_names"] = idx_elements sorting = {} for col_index, col_flags in enumerate(row["indoption"]): @@ -4224,7 +4252,7 @@ class PGDialect(default.DefaultDialect): if col_flags & 0x02: col_sorting += ("nulls_first",) if col_sorting: - sorting[idx_cols[col_index]] = col_sorting + sorting[idx_elements[col_index]] = col_sorting if sorting: index["column_sorting"] = sorting if row["has_constraint"]: diff --git a/lib/sqlalchemy/dialects/postgresql/pg_catalog.py b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py index a77e7ccf6..ed8926a26 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg_catalog.py +++ b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py @@ -67,6 +67,7 @@ pg_get_serial_sequence = _pg_cat.pg_get_serial_sequence format_type = _pg_cat.format_type pg_get_expr = _pg_cat.pg_get_expr pg_get_constraintdef = _pg_cat.pg_get_constraintdef +pg_get_indexdef = _pg_cat.pg_get_indexdef # constants RELKINDS_TABLE_NO_FOREIGN = ("r", "p") diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 004ce2993..208c4f6b0 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -472,8 +472,17 @@ class ReflectedIndex(TypedDict): name: Optional[str] """index name""" - column_names: List[str] - """column names which the index refers towards""" + column_names: List[Optional[str]] + """column names which the index refers towards. + An element of this list is ``None`` if it's an expression and is + returned in the ``expressions`` list. + """ + + expressions: NotRequired[List[str]] + """Expressions that compose the index. This list, when present, contains + both plain column names (that are also in ``column_names``) and + expressions (that are ``None`` in ``column_names``). + """ unique: bool """whether or not the index has a unique flag""" diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index f78ca84a2..c3c5ff5a8 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -1796,12 +1796,12 @@ class Inspector(inspection.Inspectable["Inspector"]): ) ) - _index_sort_exprs = [ - ("asc", operators.asc_op), - ("desc", operators.desc_op), - ("nulls_first", operators.nulls_first_op), - ("nulls_last", operators.nulls_last_op), - ] + _index_sort_exprs = { + "asc": operators.asc_op, + "desc": operators.desc_op, + "nulls_first": operators.nulls_first_op, + "nulls_last": operators.nulls_last_op, + } def _reflect_indexes( self, @@ -1818,6 +1818,7 @@ class Inspector(inspection.Inspectable["Inspector"]): for index_d in indexes: name = index_d["name"] columns = index_d["column_names"] + expressions = index_d.get("expressions") column_sorting = index_d.get("column_sorting", {}) unique = index_d["unique"] flavor = index_d.get("type", "index") @@ -1830,33 +1831,43 @@ class Inspector(inspection.Inspectable["Inspector"]): continue # look for columns by orig name in cols_by_orig_name, # but support columns that are in-Python only as fallback - idx_col: Any - idx_cols = [] - for c in columns: - try: - idx_col = ( - cols_by_orig_name[c] - if c in cols_by_orig_name - else table.c[c] - ) - except KeyError: - util.warn( - "%s key '%s' was not located in " - "columns for table '%s'" % (flavor, c, table.name) - ) - continue - c_sorting = column_sorting.get(c, ()) - for k, op in self._index_sort_exprs: - if k in c_sorting: - idx_col = op(idx_col) - idx_cols.append(idx_col) - - sa_schema.Index( - name, - *idx_cols, - _table=table, - **dict(list(dialect_options.items()) + [("unique", unique)]), - ) + idx_element: Any + idx_elements = [] + for index, c in enumerate(columns): + if c is None: + if not expressions: + util.warn( + f"Skipping {flavor} {name!r} because key " + f"{index+1} reflected as None but no " + "'expressions' were returned" + ) + break + idx_element = sql.text(expressions[index]) + else: + try: + if c in cols_by_orig_name: + idx_element = cols_by_orig_name[c] + else: + idx_element = table.c[c] + except KeyError: + util.warn( + f"{flavor} key {c!r} was not located in " + f"columns for table {table.name!r}" + ) + continue + for option in column_sorting.get(c, ()): + if option in self._index_sort_exprs: + op = self._index_sort_exprs[option] + idx_element = op(idx_element) + idx_elements.append(idx_element) + else: + sa_schema.Index( + name, + *idx_elements, + _table=table, + unique=unique, + **dialect_options, + ) def _reflect_unique_constraints( self, diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 8c3c4bc27..0c83cb469 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -18,6 +18,7 @@ from .assertions import assert_warns from .assertions import assert_warns_message from .assertions import AssertsCompiledSQL from .assertions import AssertsExecutionResults +from .assertions import ComparesIndexes from .assertions import ComparesTables from .assertions import emits_warning from .assertions import emits_warning_on diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 937706363..44e7e892f 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -838,3 +838,34 @@ class AssertsExecutionResults: def assert_statement_count(self, db, count): return self.assert_execution(db, assertsql.CountStatements(count)) + + +class ComparesIndexes: + def compare_table_index_with_expected( + self, table: schema.Table, expected: list, dialect_name: str + ): + eq_(len(table.indexes), len(expected)) + idx_dict = {idx.name: idx for idx in table.indexes} + for exp in expected: + idx = idx_dict[exp["name"]] + eq_(idx.unique, exp["unique"]) + cols = [c for c in exp["column_names"] if c is not None] + eq_(len(idx.columns), len(cols)) + for c in cols: + is_true(c in idx.columns) + exprs = exp.get("expressions") + if exprs: + eq_(len(idx.expressions), len(exprs)) + for idx_exp, expr, col in zip( + idx.expressions, exprs, exp["column_names"] + ): + if col is None: + eq_(idx_exp.text, expr) + if ( + exp.get("dialect_options") + and f"{dialect_name}_include" in exp["dialect_options"] + ): + eq_( + idx.dialect_options[dialect_name]["include"], + exp["dialect_options"][f"{dialect_name}_include"], + ) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index cb955ff3d..55b10bdd5 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -736,6 +736,12 @@ class SuiteRequirements(Requirements): return exclusions.closed() @property + def reflect_indexes_with_expressions(self): + """target database supports reflection of indexes with + SQL expressions.""" + return exclusions.closed() + + @property def unique_constraint_reflection(self): """target dialect supports reflection of unique constraints""" return exclusions.open() diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 6c71696a0..a3737a91a 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -35,6 +35,7 @@ from ...schema import DDL from ...schema import Index from ...sql.elements import quoted_name from ...sql.schema import BLANK_SCHEMA +from ...testing import ComparesIndexes from ...testing import ComparesTables from ...testing import is_false from ...testing import is_true @@ -2254,7 +2255,7 @@ class TableNoColumnsTest(fixtures.TestBase): eq_(multi, {(None, "empty_v"): []}) -class ComponentReflectionTestExtra(fixtures.TestBase): +class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase): __backend__ = True @@ -2322,9 +2323,10 @@ class ComponentReflectionTestExtra(fixtures.TestBase): metadata, Column("x", String(30)), Column("y", String(30)), + Column("z", String(30)), ) - Index("t_idx", func.lower(t.c.x), func.lower(t.c.y)) + Index("t_idx", func.lower(t.c.x), t.c.z, func.lower(t.c.y)) Index("t_idx_2", t.c.x) @@ -2335,19 +2337,49 @@ class ComponentReflectionTestExtra(fixtures.TestBase): expected = [ {"name": "t_idx_2", "column_names": ["x"], "unique": False} ] - if testing.requires.index_reflects_included_columns.enabled: - expected[0]["include_columns"] = [] - expected[0]["dialect_options"] = { - "%s_include" % connection.engine.name: [] + + def completeIndex(entry): + if testing.requires.index_reflects_included_columns.enabled: + entry["include_columns"] = [] + entry["dialect_options"] = { + f"{connection.engine.name}_include": [] + } + + completeIndex(expected[0]) + + class filtering_str(str): + def __eq__(self, other): + # test that lower and x or y are in the string + return "lower" in other and ("x" in other or "y" in other) + + if testing.requires.reflect_indexes_with_expressions.enabled: + expr_index = { + "name": "t_idx", + "column_names": [None, "z", None], + "expressions": [ + filtering_str("lower(x)"), + "z", + filtering_str("lower(y)"), + ], + "unique": False, } + completeIndex(expr_index) + expected.insert(0, expr_index) + eq_(insp.get_indexes("t"), expected) + m2 = MetaData() + t2 = Table("t", m2, autoload_with=connection) + else: + with expect_warnings( + "Skipped unsupported reflection of expression-based " + "index t_idx" + ): + eq_(insp.get_indexes("t"), expected) + m2 = MetaData() + t2 = Table("t", m2, autoload_with=connection) - with expect_warnings( - "Skipped unsupported reflection of expression-based index t_idx" - ): - eq_( - insp.get_indexes("t"), - expected, - ) + self.compare_table_index_with_expected( + t2, expected, connection.engine.name + ) @testing.requires.index_reflects_included_columns def test_reflect_covering_index(self, metadata, connection): |