diff options
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 937706363..44e7e892f 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -838,3 +838,34 @@ class AssertsExecutionResults: def assert_statement_count(self, db, count): return self.assert_execution(db, assertsql.CountStatements(count)) + + +class ComparesIndexes: + def compare_table_index_with_expected( + self, table: schema.Table, expected: list, dialect_name: str + ): + eq_(len(table.indexes), len(expected)) + idx_dict = {idx.name: idx for idx in table.indexes} + for exp in expected: + idx = idx_dict[exp["name"]] + eq_(idx.unique, exp["unique"]) + cols = [c for c in exp["column_names"] if c is not None] + eq_(len(idx.columns), len(cols)) + for c in cols: + is_true(c in idx.columns) + exprs = exp.get("expressions") + if exprs: + eq_(len(idx.expressions), len(exprs)) + for idx_exp, expr, col in zip( + idx.expressions, exprs, exp["column_names"] + ): + if col is None: + eq_(idx_exp.text, expr) + if ( + exp.get("dialect_options") + and f"{dialect_name}_include" in exp["dialect_options"] + ): + eq_( + idx.dialect_options[dialect_name]["include"], + exp["dialect_options"][f"{dialect_name}_include"], + ) |