diff options
Diffstat (limited to 'test/dialect/postgresql/test_reflection.py')
-rw-r--r-- | test/dialect/postgresql/test_reflection.py | 441 |
1 files changed, 343 insertions, 98 deletions
diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index cbb1809e4..00e5dc5b9 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -27,17 +27,21 @@ from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.dialects.postgresql import ExcludeConstraint from sqlalchemy.dialects.postgresql import INTEGER from sqlalchemy.dialects.postgresql import INTERVAL +from sqlalchemy.dialects.postgresql import pg_catalog from sqlalchemy.dialects.postgresql import TSRANGE +from sqlalchemy.engine import ObjectKind +from sqlalchemy.engine import ObjectScope from sqlalchemy.schema import CreateIndex from sqlalchemy.sql.schema import CheckConstraint from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock -from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_warns from sqlalchemy.testing.assertions import AssertsExecutionResults from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.assertions import expect_raises from sqlalchemy.testing.assertions import is_ +from sqlalchemy.testing.assertions import is_false from sqlalchemy.testing.assertions import is_true @@ -231,17 +235,36 @@ class MaterializedViewReflectionTest( connection.execute(target.insert(), {"id": 89, "data": "d1"}) materialized_view = sa.DDL( - "CREATE MATERIALIZED VIEW test_mview AS " "SELECT * FROM testtable" + "CREATE MATERIALIZED VIEW test_mview AS SELECT * FROM testtable" ) plain_view = sa.DDL( - "CREATE VIEW test_regview AS " "SELECT * FROM testtable" + "CREATE VIEW test_regview AS SELECT data FROM testtable" ) sa.event.listen(testtable, "after_create", plain_view) sa.event.listen(testtable, "after_create", materialized_view) sa.event.listen( testtable, + "after_create", + sa.DDL("COMMENT ON VIEW test_regview IS 'regular view comment'"), + ) + sa.event.listen( + testtable, + "after_create", + sa.DDL( + "COMMENT ON MATERIALIZED VIEW test_mview " + "IS 'materialized view comment'" + ), + ) + sa.event.listen( + testtable, + "after_create", + sa.DDL("CREATE INDEX mat_index ON test_mview(data DESC)"), + ) + + sa.event.listen( + testtable, "before_drop", sa.DDL("DROP MATERIALIZED VIEW test_mview"), ) @@ -249,6 +272,12 @@ class MaterializedViewReflectionTest( testtable, "before_drop", sa.DDL("DROP VIEW test_regview") ) + def test_has_type(self, connection): + insp = inspect(connection) + is_true(insp.has_type("test_mview")) + is_true(insp.has_type("test_regview")) + is_true(insp.has_type("testtable")) + def test_mview_is_reflected(self, connection): metadata = MetaData() table = Table("test_mview", metadata, autoload_with=connection) @@ -265,49 +294,99 @@ class MaterializedViewReflectionTest( def test_get_view_names(self, inspect_fixture): insp, conn = inspect_fixture - eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"])) + eq_(set(insp.get_view_names()), set(["test_regview"])) - def test_get_view_names_plain(self, connection): + def test_get_materialized_view_names(self, inspect_fixture): + insp, conn = inspect_fixture + eq_(set(insp.get_materialized_view_names()), set(["test_mview"])) + + def test_get_view_names_reflection_cache_ok(self, connection): insp = inspect(connection) + eq_(set(insp.get_view_names()), set(["test_regview"])) eq_( - set(insp.get_view_names(include=("plain",))), set(["test_regview"]) + set(insp.get_materialized_view_names()), + set(["test_mview"]), + ) + eq_( + set(insp.get_view_names()).union( + insp.get_materialized_view_names() + ), + set(["test_regview", "test_mview"]), ) - def test_get_view_names_plain_string(self, connection): + def test_get_view_definition(self, connection): insp = inspect(connection) - eq_(set(insp.get_view_names(include="plain")), set(["test_regview"])) - def test_get_view_names_materialized(self, connection): - insp = inspect(connection) + def normalize(definition): + return re.sub(r"[\n\t ]+", " ", definition.strip()) + eq_( - set(insp.get_view_names(include=("materialized",))), - set(["test_mview"]), + normalize(insp.get_view_definition("test_mview")), + "SELECT testtable.id, testtable.data FROM testtable;", + ) + eq_( + normalize(insp.get_view_definition("test_regview")), + "SELECT testtable.data FROM testtable;", ) - def test_get_view_names_reflection_cache_ok(self, connection): + def test_get_view_comment(self, connection): insp = inspect(connection) eq_( - set(insp.get_view_names(include=("plain",))), set(["test_regview"]) + insp.get_table_comment("test_regview"), + {"text": "regular view comment"}, ) eq_( - set(insp.get_view_names(include=("materialized",))), - set(["test_mview"]), + insp.get_table_comment("test_mview"), + {"text": "materialized view comment"}, ) - eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"])) - def test_get_view_names_empty(self, connection): + def test_get_multi_view_comment(self, connection): insp = inspect(connection) - assert_raises(ValueError, insp.get_view_names, include=()) + eq_( + insp.get_multi_table_comment(), + {(None, "testtable"): {"text": None}}, + ) + plain = {(None, "test_regview"): {"text": "regular view comment"}} + mat = {(None, "test_mview"): {"text": "materialized view comment"}} + eq_(insp.get_multi_table_comment(kind=ObjectKind.VIEW), plain) + eq_( + insp.get_multi_table_comment(kind=ObjectKind.MATERIALIZED_VIEW), + mat, + ) + eq_( + insp.get_multi_table_comment(kind=ObjectKind.ANY_VIEW), + {**plain, **mat}, + ) + eq_( + insp.get_multi_table_comment( + kind=ObjectKind.ANY_VIEW, scope=ObjectScope.TEMPORARY + ), + {}, + ) - def test_get_view_definition(self, connection): + def test_get_multi_view_indexes(self, connection): insp = inspect(connection) + eq_(insp.get_multi_indexes(), {(None, "testtable"): []}) + + exp = { + "name": "mat_index", + "unique": False, + "column_names": ["data"], + "column_sorting": {"data": ("desc",)}, + } + if connection.dialect.server_version_info >= (11, 0): + exp["include_columns"] = [] + exp["dialect_options"] = {"postgresql_include": []} + plain = {(None, "test_regview"): []} + mat = {(None, "test_mview"): [exp]} + eq_(insp.get_multi_indexes(kind=ObjectKind.VIEW), plain) + eq_(insp.get_multi_indexes(kind=ObjectKind.MATERIALIZED_VIEW), mat) + eq_(insp.get_multi_indexes(kind=ObjectKind.ANY_VIEW), {**plain, **mat}) eq_( - re.sub( - r"[\n\t ]+", - " ", - insp.get_view_definition("test_mview").strip(), + insp.get_multi_indexes( + kind=ObjectKind.ANY_VIEW, scope=ObjectScope.TEMPORARY ), - "SELECT testtable.id, testtable.data FROM testtable;", + {}, ) @@ -993,9 +1072,9 @@ class ReflectionTest( go, [ "Skipped unsupported reflection of " - "expression-based index idx1", + "expression-based index idx1 of table party", "Skipped unsupported reflection of " - "expression-based index idx3", + "expression-based index idx3 of table party", ], ) @@ -1016,7 +1095,7 @@ class ReflectionTest( metadata.create_all(connection) - ind = connection.dialect.get_indexes(connection, t1, None) + ind = connection.dialect.get_indexes(connection, t1.name, None) partial_definitions = [] for ix in ind: @@ -1337,6 +1416,9 @@ class ReflectionTest( } ], ) + is_true(inspector.has_type("mood", "test_schema")) + is_true(inspector.has_type("mood", "*")) + is_false(inspector.has_type("mood")) def test_inspect_enums(self, metadata, inspect_fixture): @@ -1345,30 +1427,49 @@ class ReflectionTest( enum_type = postgresql.ENUM( "cat", "dog", "rat", name="pet", metadata=metadata ) + enum_type.create(conn) + conn.commit() - with conn.begin(): - enum_type.create(conn) - - eq_( - inspector.get_enums(), - [ - { - "visible": True, - "labels": ["cat", "dog", "rat"], - "name": "pet", - "schema": "public", - } - ], - ) - - def test_get_table_oid(self, metadata, inspect_fixture): - - inspector, conn = inspect_fixture + res = [ + { + "visible": True, + "labels": ["cat", "dog", "rat"], + "name": "pet", + "schema": "public", + } + ] + eq_(inspector.get_enums(), res) + is_true(inspector.has_type("pet", "*")) + is_true(inspector.has_type("pet")) + is_false(inspector.has_type("pet", "test_schema")) + + enum_type.drop(conn) + conn.commit() + eq_(inspector.get_enums(), res) + is_true(inspector.has_type("pet")) + inspector.clear_cache() + eq_(inspector.get_enums(), []) + is_false(inspector.has_type("pet")) + + def test_get_table_oid(self, metadata, connection): + Table("t1", metadata, Column("col", Integer)) + Table("t1", metadata, Column("col", Integer), schema="test_schema") + metadata.create_all(connection) + insp = inspect(connection) + oid = insp.get_table_oid("t1") + oid_schema = insp.get_table_oid("t1", schema="test_schema") + is_true(isinstance(oid, int)) + is_true(isinstance(oid_schema, int)) + is_true(oid != oid_schema) - with conn.begin(): - Table("some_table", metadata, Column("q", Integer)).create(conn) + with expect_raises(exc.NoSuchTableError): + insp.get_table_oid("does_not_exist") - assert inspector.get_table_oid("some_table") is not None + metadata.tables["t1"].drop(connection) + eq_(insp.get_table_oid("t1"), oid) + insp.clear_cache() + with expect_raises(exc.NoSuchTableError): + insp.get_table_oid("t1") def test_inspect_enums_case_sensitive(self, metadata, connection): sa.event.listen( @@ -1707,77 +1808,146 @@ class ReflectionTest( ) def test_reflect_check_warning(self): - rows = [("some name", "NOTCHECK foobar")] + rows = [("foo", "some name", "NOTCHECK foobar")] conn = mock.Mock( execute=lambda *arg, **kw: mock.MagicMock( fetchall=lambda: rows, __iter__=lambda self: iter(rows) ) ) - with mock.patch.object( - testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1 + with testing.expect_warnings( + "Could not parse CHECK constraint text: 'NOTCHECK foobar'" ): - with testing.expect_warnings( - "Could not parse CHECK constraint text: 'NOTCHECK foobar'" - ): - testing.db.dialect.get_check_constraints(conn, "foo") + testing.db.dialect.get_check_constraints(conn, "foo") def test_reflect_extra_newlines(self): rows = [ - ("some name", "CHECK (\n(a \nIS\n NOT\n\n NULL\n)\n)"), - ("some other name", "CHECK ((b\nIS\nNOT\nNULL))"), - ("some CRLF name", "CHECK ((c\r\n\r\nIS\r\nNOT\r\nNULL))"), - ("some name", "CHECK (c != 'hi\nim a name\n')"), + ("foo", "some name", "CHECK (\n(a \nIS\n NOT\n\n NULL\n)\n)"), + ("foo", "some other name", "CHECK ((b\nIS\nNOT\nNULL))"), + ("foo", "some CRLF name", "CHECK ((c\r\n\r\nIS\r\nNOT\r\nNULL))"), + ("foo", "some name", "CHECK (c != 'hi\nim a name\n')"), ] conn = mock.Mock( execute=lambda *arg, **kw: mock.MagicMock( fetchall=lambda: rows, __iter__=lambda self: iter(rows) ) ) - with mock.patch.object( - testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1 - ): - check_constraints = testing.db.dialect.get_check_constraints( - conn, "foo" - ) - eq_( - check_constraints, - [ - { - "name": "some name", - "sqltext": "a \nIS\n NOT\n\n NULL\n", - }, - {"name": "some other name", "sqltext": "b\nIS\nNOT\nNULL"}, - { - "name": "some CRLF name", - "sqltext": "c\r\n\r\nIS\r\nNOT\r\nNULL", - }, - {"name": "some name", "sqltext": "c != 'hi\nim a name\n'"}, - ], - ) + check_constraints = testing.db.dialect.get_check_constraints( + conn, "foo" + ) + eq_( + check_constraints, + [ + { + "name": "some name", + "sqltext": "a \nIS\n NOT\n\n NULL\n", + }, + {"name": "some other name", "sqltext": "b\nIS\nNOT\nNULL"}, + { + "name": "some CRLF name", + "sqltext": "c\r\n\r\nIS\r\nNOT\r\nNULL", + }, + {"name": "some name", "sqltext": "c != 'hi\nim a name\n'"}, + ], + ) def test_reflect_with_not_valid_check_constraint(self): - rows = [("some name", "CHECK ((a IS NOT NULL)) NOT VALID")] + rows = [("foo", "some name", "CHECK ((a IS NOT NULL)) NOT VALID")] conn = mock.Mock( execute=lambda *arg, **kw: mock.MagicMock( fetchall=lambda: rows, __iter__=lambda self: iter(rows) ) ) - with mock.patch.object( - testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1 - ): - check_constraints = testing.db.dialect.get_check_constraints( - conn, "foo" + check_constraints = testing.db.dialect.get_check_constraints( + conn, "foo" + ) + eq_( + check_constraints, + [ + { + "name": "some name", + "sqltext": "a IS NOT NULL", + "dialect_options": {"not_valid": True}, + } + ], + ) + + def _apply_stm(self, connection, use_map): + if use_map: + return connection.execution_options( + schema_translate_map={ + None: "foo", + testing.config.test_schema: "bar", + } ) - eq_( - check_constraints, - [ - { - "name": "some name", - "sqltext": "a IS NOT NULL", - "dialect_options": {"not_valid": True}, - } - ], + else: + return connection + + @testing.combinations(True, False, argnames="use_map") + @testing.combinations(True, False, argnames="schema") + def test_schema_translate_map(self, metadata, connection, use_map, schema): + schema = testing.config.test_schema if schema else None + Table( + "foo", + metadata, + Column("id", Integer, primary_key=True), + Column("a", Integer, index=True), + Column( + "b", + ForeignKey(f"{schema}.foo.id" if schema else "foo.id"), + unique=True, + ), + CheckConstraint("a>10", name="foo_check"), + comment="comm", + schema=schema, + ) + metadata.create_all(connection) + if use_map: + connection = connection.execution_options( + schema_translate_map={ + None: "foo", + testing.config.test_schema: "bar", + } ) + insp = inspect(connection) + eq_( + [c["name"] for c in insp.get_columns("foo", schema=schema)], + ["id", "a", "b"], + ) + eq_( + [ + i["column_names"] + for i in insp.get_indexes("foo", schema=schema) + ], + [["b"], ["a"]], + ) + eq_( + insp.get_pk_constraint("foo", schema=schema)[ + "constrained_columns" + ], + ["id"], + ) + eq_(insp.get_table_comment("foo", schema=schema), {"text": "comm"}) + eq_( + [ + f["constrained_columns"] + for f in insp.get_foreign_keys("foo", schema=schema) + ], + [["b"]], + ) + eq_( + [ + c["name"] + for c in insp.get_check_constraints("foo", schema=schema) + ], + ["foo_check"], + ) + eq_( + [ + u["column_names"] + for u in insp.get_unique_constraints("foo", schema=schema) + ], + [["b"]], + ) class CustomTypeReflectionTest(fixtures.TestBase): @@ -1804,9 +1974,23 @@ class CustomTypeReflectionTest(fixtures.TestBase): ("my_custom_type(ARG1)", ("ARG1", None)), ("my_custom_type(ARG1, ARG2)", ("ARG1", "ARG2")), ]: - column_info = dialect._get_column_info( - "colname", sch, None, False, {}, {}, "public", None, "", None + row_dict = { + "name": "colname", + "table_name": "tblname", + "format_type": sch, + "default": None, + "not_null": False, + "comment": None, + "generated": "", + "identity_options": None, + } + column_info = dialect._get_columns_info( + [row_dict], {}, {}, "public" ) + assert ("public", "tblname") in column_info + column_info = column_info[("public", "tblname")] + assert len(column_info) == 1 + column_info = column_info[0] assert isinstance(column_info["type"], self.CustomType) eq_(column_info["type"].arg1, args[0]) eq_(column_info["type"].arg2, args[1]) @@ -1951,3 +2135,64 @@ class IdentityReflectionTest(fixtures.TablesTest): exp = default.copy() exp.update(maxvalue=2**15 - 1) eq_(col["identity"], exp) + + +class TestReflectDifficultColTypes(fixtures.TablesTest): + __only_on__ = "postgresql" + __backend__ = True + + def define_tables(metadata): + Table( + "sample_table", + metadata, + Column("c1", Integer, primary_key=True), + Column("c2", Integer, unique=True), + Column("c3", Integer), + Index("sample_table_index", "c2", "c3"), + ) + + def check_int_list(self, row, key): + value = row[key] + is_true(isinstance(value, list)) + is_true(len(value) > 0) + is_true(all(isinstance(v, int) for v in value)) + + def test_pg_index(self, connection): + insp = inspect(connection) + + pgc_oid = insp.get_table_oid("sample_table") + cols = [ + col + for col in pg_catalog.pg_index.c + if testing.db.dialect.server_version_info + >= col.info.get("server_version", (0,)) + ] + + stmt = sa.select(*cols).filter_by(indrelid=pgc_oid) + rows = connection.execute(stmt).mappings().all() + is_true(len(rows) > 0) + cols = [ + col + for col in ["indkey", "indoption", "indclass", "indcollation"] + if testing.db.dialect.server_version_info + >= pg_catalog.pg_index.c[col].info.get("server_version", (0,)) + ] + for row in rows: + for col in cols: + self.check_int_list(row, col) + + def test_pg_constraint(self, connection): + insp = inspect(connection) + + pgc_oid = insp.get_table_oid("sample_table") + cols = [ + col + for col in pg_catalog.pg_constraint.c + if testing.db.dialect.server_version_info + >= col.info.get("server_version", (0,)) + ] + stmt = sa.select(*cols).filter_by(conrelid=pgc_oid) + rows = connection.execute(stmt).mappings().all() + is_true(len(rows) > 0) + for row in rows: + self.check_int_list(row, "conkey") |