summaryrefslogtreecommitdiff
path: root/test/dialect/postgresql/test_reflection.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/dialect/postgresql/test_reflection.py')
-rw-r--r--test/dialect/postgresql/test_reflection.py441
1 files changed, 343 insertions, 98 deletions
diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py
index cbb1809e4..00e5dc5b9 100644
--- a/test/dialect/postgresql/test_reflection.py
+++ b/test/dialect/postgresql/test_reflection.py
@@ -27,17 +27,21 @@ from sqlalchemy.dialects.postgresql import base as postgresql
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.dialects.postgresql import INTERVAL
+from sqlalchemy.dialects.postgresql import pg_catalog
from sqlalchemy.dialects.postgresql import TSRANGE
+from sqlalchemy.engine import ObjectKind
+from sqlalchemy.engine import ObjectScope
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import mock
-from sqlalchemy.testing.assertions import assert_raises
from sqlalchemy.testing.assertions import assert_warns
from sqlalchemy.testing.assertions import AssertsExecutionResults
from sqlalchemy.testing.assertions import eq_
+from sqlalchemy.testing.assertions import expect_raises
from sqlalchemy.testing.assertions import is_
+from sqlalchemy.testing.assertions import is_false
from sqlalchemy.testing.assertions import is_true
@@ -231,17 +235,36 @@ class MaterializedViewReflectionTest(
connection.execute(target.insert(), {"id": 89, "data": "d1"})
materialized_view = sa.DDL(
- "CREATE MATERIALIZED VIEW test_mview AS " "SELECT * FROM testtable"
+ "CREATE MATERIALIZED VIEW test_mview AS SELECT * FROM testtable"
)
plain_view = sa.DDL(
- "CREATE VIEW test_regview AS " "SELECT * FROM testtable"
+ "CREATE VIEW test_regview AS SELECT data FROM testtable"
)
sa.event.listen(testtable, "after_create", plain_view)
sa.event.listen(testtable, "after_create", materialized_view)
sa.event.listen(
testtable,
+ "after_create",
+ sa.DDL("COMMENT ON VIEW test_regview IS 'regular view comment'"),
+ )
+ sa.event.listen(
+ testtable,
+ "after_create",
+ sa.DDL(
+ "COMMENT ON MATERIALIZED VIEW test_mview "
+ "IS 'materialized view comment'"
+ ),
+ )
+ sa.event.listen(
+ testtable,
+ "after_create",
+ sa.DDL("CREATE INDEX mat_index ON test_mview(data DESC)"),
+ )
+
+ sa.event.listen(
+ testtable,
"before_drop",
sa.DDL("DROP MATERIALIZED VIEW test_mview"),
)
@@ -249,6 +272,12 @@ class MaterializedViewReflectionTest(
testtable, "before_drop", sa.DDL("DROP VIEW test_regview")
)
+ def test_has_type(self, connection):
+ insp = inspect(connection)
+ is_true(insp.has_type("test_mview"))
+ is_true(insp.has_type("test_regview"))
+ is_true(insp.has_type("testtable"))
+
def test_mview_is_reflected(self, connection):
metadata = MetaData()
table = Table("test_mview", metadata, autoload_with=connection)
@@ -265,49 +294,99 @@ class MaterializedViewReflectionTest(
def test_get_view_names(self, inspect_fixture):
insp, conn = inspect_fixture
- eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
+ eq_(set(insp.get_view_names()), set(["test_regview"]))
- def test_get_view_names_plain(self, connection):
+ def test_get_materialized_view_names(self, inspect_fixture):
+ insp, conn = inspect_fixture
+ eq_(set(insp.get_materialized_view_names()), set(["test_mview"]))
+
+ def test_get_view_names_reflection_cache_ok(self, connection):
insp = inspect(connection)
+ eq_(set(insp.get_view_names()), set(["test_regview"]))
eq_(
- set(insp.get_view_names(include=("plain",))), set(["test_regview"])
+ set(insp.get_materialized_view_names()),
+ set(["test_mview"]),
+ )
+ eq_(
+ set(insp.get_view_names()).union(
+ insp.get_materialized_view_names()
+ ),
+ set(["test_regview", "test_mview"]),
)
- def test_get_view_names_plain_string(self, connection):
+ def test_get_view_definition(self, connection):
insp = inspect(connection)
- eq_(set(insp.get_view_names(include="plain")), set(["test_regview"]))
- def test_get_view_names_materialized(self, connection):
- insp = inspect(connection)
+ def normalize(definition):
+ return re.sub(r"[\n\t ]+", " ", definition.strip())
+
eq_(
- set(insp.get_view_names(include=("materialized",))),
- set(["test_mview"]),
+ normalize(insp.get_view_definition("test_mview")),
+ "SELECT testtable.id, testtable.data FROM testtable;",
+ )
+ eq_(
+ normalize(insp.get_view_definition("test_regview")),
+ "SELECT testtable.data FROM testtable;",
)
- def test_get_view_names_reflection_cache_ok(self, connection):
+ def test_get_view_comment(self, connection):
insp = inspect(connection)
eq_(
- set(insp.get_view_names(include=("plain",))), set(["test_regview"])
+ insp.get_table_comment("test_regview"),
+ {"text": "regular view comment"},
)
eq_(
- set(insp.get_view_names(include=("materialized",))),
- set(["test_mview"]),
+ insp.get_table_comment("test_mview"),
+ {"text": "materialized view comment"},
)
- eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
- def test_get_view_names_empty(self, connection):
+ def test_get_multi_view_comment(self, connection):
insp = inspect(connection)
- assert_raises(ValueError, insp.get_view_names, include=())
+ eq_(
+ insp.get_multi_table_comment(),
+ {(None, "testtable"): {"text": None}},
+ )
+ plain = {(None, "test_regview"): {"text": "regular view comment"}}
+ mat = {(None, "test_mview"): {"text": "materialized view comment"}}
+ eq_(insp.get_multi_table_comment(kind=ObjectKind.VIEW), plain)
+ eq_(
+ insp.get_multi_table_comment(kind=ObjectKind.MATERIALIZED_VIEW),
+ mat,
+ )
+ eq_(
+ insp.get_multi_table_comment(kind=ObjectKind.ANY_VIEW),
+ {**plain, **mat},
+ )
+ eq_(
+ insp.get_multi_table_comment(
+ kind=ObjectKind.ANY_VIEW, scope=ObjectScope.TEMPORARY
+ ),
+ {},
+ )
- def test_get_view_definition(self, connection):
+ def test_get_multi_view_indexes(self, connection):
insp = inspect(connection)
+ eq_(insp.get_multi_indexes(), {(None, "testtable"): []})
+
+ exp = {
+ "name": "mat_index",
+ "unique": False,
+ "column_names": ["data"],
+ "column_sorting": {"data": ("desc",)},
+ }
+ if connection.dialect.server_version_info >= (11, 0):
+ exp["include_columns"] = []
+ exp["dialect_options"] = {"postgresql_include": []}
+ plain = {(None, "test_regview"): []}
+ mat = {(None, "test_mview"): [exp]}
+ eq_(insp.get_multi_indexes(kind=ObjectKind.VIEW), plain)
+ eq_(insp.get_multi_indexes(kind=ObjectKind.MATERIALIZED_VIEW), mat)
+ eq_(insp.get_multi_indexes(kind=ObjectKind.ANY_VIEW), {**plain, **mat})
eq_(
- re.sub(
- r"[\n\t ]+",
- " ",
- insp.get_view_definition("test_mview").strip(),
+ insp.get_multi_indexes(
+ kind=ObjectKind.ANY_VIEW, scope=ObjectScope.TEMPORARY
),
- "SELECT testtable.id, testtable.data FROM testtable;",
+ {},
)
@@ -993,9 +1072,9 @@ class ReflectionTest(
go,
[
"Skipped unsupported reflection of "
- "expression-based index idx1",
+ "expression-based index idx1 of table party",
"Skipped unsupported reflection of "
- "expression-based index idx3",
+ "expression-based index idx3 of table party",
],
)
@@ -1016,7 +1095,7 @@ class ReflectionTest(
metadata.create_all(connection)
- ind = connection.dialect.get_indexes(connection, t1, None)
+ ind = connection.dialect.get_indexes(connection, t1.name, None)
partial_definitions = []
for ix in ind:
@@ -1337,6 +1416,9 @@ class ReflectionTest(
}
],
)
+ is_true(inspector.has_type("mood", "test_schema"))
+ is_true(inspector.has_type("mood", "*"))
+ is_false(inspector.has_type("mood"))
def test_inspect_enums(self, metadata, inspect_fixture):
@@ -1345,30 +1427,49 @@ class ReflectionTest(
enum_type = postgresql.ENUM(
"cat", "dog", "rat", name="pet", metadata=metadata
)
+ enum_type.create(conn)
+ conn.commit()
- with conn.begin():
- enum_type.create(conn)
-
- eq_(
- inspector.get_enums(),
- [
- {
- "visible": True,
- "labels": ["cat", "dog", "rat"],
- "name": "pet",
- "schema": "public",
- }
- ],
- )
-
- def test_get_table_oid(self, metadata, inspect_fixture):
-
- inspector, conn = inspect_fixture
+ res = [
+ {
+ "visible": True,
+ "labels": ["cat", "dog", "rat"],
+ "name": "pet",
+ "schema": "public",
+ }
+ ]
+ eq_(inspector.get_enums(), res)
+ is_true(inspector.has_type("pet", "*"))
+ is_true(inspector.has_type("pet"))
+ is_false(inspector.has_type("pet", "test_schema"))
+
+ enum_type.drop(conn)
+ conn.commit()
+ eq_(inspector.get_enums(), res)
+ is_true(inspector.has_type("pet"))
+ inspector.clear_cache()
+ eq_(inspector.get_enums(), [])
+ is_false(inspector.has_type("pet"))
+
+ def test_get_table_oid(self, metadata, connection):
+ Table("t1", metadata, Column("col", Integer))
+ Table("t1", metadata, Column("col", Integer), schema="test_schema")
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ oid = insp.get_table_oid("t1")
+ oid_schema = insp.get_table_oid("t1", schema="test_schema")
+ is_true(isinstance(oid, int))
+ is_true(isinstance(oid_schema, int))
+ is_true(oid != oid_schema)
- with conn.begin():
- Table("some_table", metadata, Column("q", Integer)).create(conn)
+ with expect_raises(exc.NoSuchTableError):
+ insp.get_table_oid("does_not_exist")
- assert inspector.get_table_oid("some_table") is not None
+ metadata.tables["t1"].drop(connection)
+ eq_(insp.get_table_oid("t1"), oid)
+ insp.clear_cache()
+ with expect_raises(exc.NoSuchTableError):
+ insp.get_table_oid("t1")
def test_inspect_enums_case_sensitive(self, metadata, connection):
sa.event.listen(
@@ -1707,77 +1808,146 @@ class ReflectionTest(
)
def test_reflect_check_warning(self):
- rows = [("some name", "NOTCHECK foobar")]
+ rows = [("foo", "some name", "NOTCHECK foobar")]
conn = mock.Mock(
execute=lambda *arg, **kw: mock.MagicMock(
fetchall=lambda: rows, __iter__=lambda self: iter(rows)
)
)
- with mock.patch.object(
- testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1
+ with testing.expect_warnings(
+ "Could not parse CHECK constraint text: 'NOTCHECK foobar'"
):
- with testing.expect_warnings(
- "Could not parse CHECK constraint text: 'NOTCHECK foobar'"
- ):
- testing.db.dialect.get_check_constraints(conn, "foo")
+ testing.db.dialect.get_check_constraints(conn, "foo")
def test_reflect_extra_newlines(self):
rows = [
- ("some name", "CHECK (\n(a \nIS\n NOT\n\n NULL\n)\n)"),
- ("some other name", "CHECK ((b\nIS\nNOT\nNULL))"),
- ("some CRLF name", "CHECK ((c\r\n\r\nIS\r\nNOT\r\nNULL))"),
- ("some name", "CHECK (c != 'hi\nim a name\n')"),
+ ("foo", "some name", "CHECK (\n(a \nIS\n NOT\n\n NULL\n)\n)"),
+ ("foo", "some other name", "CHECK ((b\nIS\nNOT\nNULL))"),
+ ("foo", "some CRLF name", "CHECK ((c\r\n\r\nIS\r\nNOT\r\nNULL))"),
+ ("foo", "some name", "CHECK (c != 'hi\nim a name\n')"),
]
conn = mock.Mock(
execute=lambda *arg, **kw: mock.MagicMock(
fetchall=lambda: rows, __iter__=lambda self: iter(rows)
)
)
- with mock.patch.object(
- testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1
- ):
- check_constraints = testing.db.dialect.get_check_constraints(
- conn, "foo"
- )
- eq_(
- check_constraints,
- [
- {
- "name": "some name",
- "sqltext": "a \nIS\n NOT\n\n NULL\n",
- },
- {"name": "some other name", "sqltext": "b\nIS\nNOT\nNULL"},
- {
- "name": "some CRLF name",
- "sqltext": "c\r\n\r\nIS\r\nNOT\r\nNULL",
- },
- {"name": "some name", "sqltext": "c != 'hi\nim a name\n'"},
- ],
- )
+ check_constraints = testing.db.dialect.get_check_constraints(
+ conn, "foo"
+ )
+ eq_(
+ check_constraints,
+ [
+ {
+ "name": "some name",
+ "sqltext": "a \nIS\n NOT\n\n NULL\n",
+ },
+ {"name": "some other name", "sqltext": "b\nIS\nNOT\nNULL"},
+ {
+ "name": "some CRLF name",
+ "sqltext": "c\r\n\r\nIS\r\nNOT\r\nNULL",
+ },
+ {"name": "some name", "sqltext": "c != 'hi\nim a name\n'"},
+ ],
+ )
def test_reflect_with_not_valid_check_constraint(self):
- rows = [("some name", "CHECK ((a IS NOT NULL)) NOT VALID")]
+ rows = [("foo", "some name", "CHECK ((a IS NOT NULL)) NOT VALID")]
conn = mock.Mock(
execute=lambda *arg, **kw: mock.MagicMock(
fetchall=lambda: rows, __iter__=lambda self: iter(rows)
)
)
- with mock.patch.object(
- testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1
- ):
- check_constraints = testing.db.dialect.get_check_constraints(
- conn, "foo"
+ check_constraints = testing.db.dialect.get_check_constraints(
+ conn, "foo"
+ )
+ eq_(
+ check_constraints,
+ [
+ {
+ "name": "some name",
+ "sqltext": "a IS NOT NULL",
+ "dialect_options": {"not_valid": True},
+ }
+ ],
+ )
+
+ def _apply_stm(self, connection, use_map):
+ if use_map:
+ return connection.execution_options(
+ schema_translate_map={
+ None: "foo",
+ testing.config.test_schema: "bar",
+ }
)
- eq_(
- check_constraints,
- [
- {
- "name": "some name",
- "sqltext": "a IS NOT NULL",
- "dialect_options": {"not_valid": True},
- }
- ],
+ else:
+ return connection
+
+ @testing.combinations(True, False, argnames="use_map")
+ @testing.combinations(True, False, argnames="schema")
+ def test_schema_translate_map(self, metadata, connection, use_map, schema):
+ schema = testing.config.test_schema if schema else None
+ Table(
+ "foo",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("a", Integer, index=True),
+ Column(
+ "b",
+ ForeignKey(f"{schema}.foo.id" if schema else "foo.id"),
+ unique=True,
+ ),
+ CheckConstraint("a>10", name="foo_check"),
+ comment="comm",
+ schema=schema,
+ )
+ metadata.create_all(connection)
+ if use_map:
+ connection = connection.execution_options(
+ schema_translate_map={
+ None: "foo",
+ testing.config.test_schema: "bar",
+ }
)
+ insp = inspect(connection)
+ eq_(
+ [c["name"] for c in insp.get_columns("foo", schema=schema)],
+ ["id", "a", "b"],
+ )
+ eq_(
+ [
+ i["column_names"]
+ for i in insp.get_indexes("foo", schema=schema)
+ ],
+ [["b"], ["a"]],
+ )
+ eq_(
+ insp.get_pk_constraint("foo", schema=schema)[
+ "constrained_columns"
+ ],
+ ["id"],
+ )
+ eq_(insp.get_table_comment("foo", schema=schema), {"text": "comm"})
+ eq_(
+ [
+ f["constrained_columns"]
+ for f in insp.get_foreign_keys("foo", schema=schema)
+ ],
+ [["b"]],
+ )
+ eq_(
+ [
+ c["name"]
+ for c in insp.get_check_constraints("foo", schema=schema)
+ ],
+ ["foo_check"],
+ )
+ eq_(
+ [
+ u["column_names"]
+ for u in insp.get_unique_constraints("foo", schema=schema)
+ ],
+ [["b"]],
+ )
class CustomTypeReflectionTest(fixtures.TestBase):
@@ -1804,9 +1974,23 @@ class CustomTypeReflectionTest(fixtures.TestBase):
("my_custom_type(ARG1)", ("ARG1", None)),
("my_custom_type(ARG1, ARG2)", ("ARG1", "ARG2")),
]:
- column_info = dialect._get_column_info(
- "colname", sch, None, False, {}, {}, "public", None, "", None
+ row_dict = {
+ "name": "colname",
+ "table_name": "tblname",
+ "format_type": sch,
+ "default": None,
+ "not_null": False,
+ "comment": None,
+ "generated": "",
+ "identity_options": None,
+ }
+ column_info = dialect._get_columns_info(
+ [row_dict], {}, {}, "public"
)
+ assert ("public", "tblname") in column_info
+ column_info = column_info[("public", "tblname")]
+ assert len(column_info) == 1
+ column_info = column_info[0]
assert isinstance(column_info["type"], self.CustomType)
eq_(column_info["type"].arg1, args[0])
eq_(column_info["type"].arg2, args[1])
@@ -1951,3 +2135,64 @@ class IdentityReflectionTest(fixtures.TablesTest):
exp = default.copy()
exp.update(maxvalue=2**15 - 1)
eq_(col["identity"], exp)
+
+
+class TestReflectDifficultColTypes(fixtures.TablesTest):
+ __only_on__ = "postgresql"
+ __backend__ = True
+
+ def define_tables(metadata):
+ Table(
+ "sample_table",
+ metadata,
+ Column("c1", Integer, primary_key=True),
+ Column("c2", Integer, unique=True),
+ Column("c3", Integer),
+ Index("sample_table_index", "c2", "c3"),
+ )
+
+ def check_int_list(self, row, key):
+ value = row[key]
+ is_true(isinstance(value, list))
+ is_true(len(value) > 0)
+ is_true(all(isinstance(v, int) for v in value))
+
+ def test_pg_index(self, connection):
+ insp = inspect(connection)
+
+ pgc_oid = insp.get_table_oid("sample_table")
+ cols = [
+ col
+ for col in pg_catalog.pg_index.c
+ if testing.db.dialect.server_version_info
+ >= col.info.get("server_version", (0,))
+ ]
+
+ stmt = sa.select(*cols).filter_by(indrelid=pgc_oid)
+ rows = connection.execute(stmt).mappings().all()
+ is_true(len(rows) > 0)
+ cols = [
+ col
+ for col in ["indkey", "indoption", "indclass", "indcollation"]
+ if testing.db.dialect.server_version_info
+ >= pg_catalog.pg_index.c[col].info.get("server_version", (0,))
+ ]
+ for row in rows:
+ for col in cols:
+ self.check_int_list(row, col)
+
+ def test_pg_constraint(self, connection):
+ insp = inspect(connection)
+
+ pgc_oid = insp.get_table_oid("sample_table")
+ cols = [
+ col
+ for col in pg_catalog.pg_constraint.c
+ if testing.db.dialect.server_version_info
+ >= col.info.get("server_version", (0,))
+ ]
+ stmt = sa.select(*cols).filter_by(conrelid=pgc_oid)
+ rows = connection.execute(stmt).mappings().all()
+ is_true(len(rows) > 0)
+ for row in rows:
+ self.check_int_list(row, "conkey")