summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/aaa_profiling/test_memusage.py22
-rw-r--r--test/dialect/postgresql/test_compiler.py76
-rw-r--r--test/dialect/postgresql/test_reflection.py126
-rw-r--r--test/dialect/postgresql/test_types.py371
-rw-r--r--test/sql/test_types.py45
5 files changed, 507 insertions, 133 deletions
diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py
index e90c428f7..dc6dcf060 100644
--- a/test/aaa_profiling/test_memusage.py
+++ b/test/aaa_profiling/test_memusage.py
@@ -312,20 +312,22 @@ class MemUsageTest(EnsureZeroed):
eng = engines.testing_engine()
for args in (
- (types.Integer,),
- (types.String,),
- (types.PickleType,),
- (types.Enum, "a", "b", "c"),
- (sqlite.DATETIME,),
- (postgresql.ENUM, "a", "b", "c"),
- (types.Interval,),
- (postgresql.INTERVAL,),
- (mysql.VARCHAR,),
+ (types.Integer, {}),
+ (types.String, {}),
+ (types.PickleType, {}),
+ (types.Enum, "a", "b", "c", {}),
+ (sqlite.DATETIME, {}),
+ (postgresql.ENUM, "a", "b", "c", {"name": "pgenum"}),
+ (types.Interval, {}),
+ (postgresql.INTERVAL, {}),
+ (mysql.VARCHAR, {}),
):
@profile_memory()
def go():
- type_ = args[0](*args[1:])
+ kwargs = args[-1]
+ posargs = args[1:-1]
+ type_ = args[0](*posargs, **kwargs)
bp = type_._cached_bind_processor(eng.dialect)
rp = type_._cached_result_processor(eng.dialect, 0)
bp, rp # strong reference
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 25550afe1..9be76130d 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -38,6 +38,7 @@ from sqlalchemy.dialects.postgresql import aggregate_order_by
from sqlalchemy.dialects.postgresql import ARRAY as PG_ARRAY
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.dialects.postgresql import array_agg as pg_array_agg
+from sqlalchemy.dialects.postgresql import DOMAIN
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.dialects.postgresql import TSRANGE
@@ -270,7 +271,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
render_schema_translate=True,
)
- def test_create_type_schema_translate(self):
+ def test_create_enum_schema_translate(self):
e1 = Enum("x", "y", "z", name="somename")
e2 = Enum("x", "y", "z", name="somename", schema="someschema")
schema_translate_map = {None: "foo", "someschema": "bar"}
@@ -289,6 +290,79 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
render_schema_translate=True,
)
+ def test_domain(self):
+ self.assert_compile(
+ postgresql.CreateDomainType(
+ DOMAIN(
+ "x",
+ Integer,
+ default=text("11"),
+ not_null=True,
+ check="VALUE < 0",
+ )
+ ),
+ "CREATE DOMAIN x AS INTEGER DEFAULT 11 NOT NULL CHECK (VALUE < 0)",
+ )
+ self.assert_compile(
+ postgresql.CreateDomainType(
+ DOMAIN(
+ "sOmEnAmE",
+ Text,
+ collation="utf8",
+ constraint_name="a constraint",
+ not_null=True,
+ )
+ ),
+ 'CREATE DOMAIN "sOmEnAmE" AS TEXT COLLATE utf8 CONSTRAINT '
+ '"a constraint" NOT NULL',
+ )
+ self.assert_compile(
+ postgresql.CreateDomainType(
+ DOMAIN(
+ "foo",
+ Text,
+ collation="utf8",
+ default="foobar",
+ constraint_name="no_bar",
+ not_null=True,
+ check="VALUE != 'bar'",
+ )
+ ),
+ "CREATE DOMAIN foo AS TEXT COLLATE utf8 DEFAULT 'foobar' "
+ "CONSTRAINT no_bar NOT NULL CHECK (VALUE != 'bar')",
+ )
+
+ def test_cast_domain_schema(self):
+ """test #6739"""
+ d1 = DOMAIN("somename", Integer)
+ d2 = DOMAIN("somename", Integer, schema="someschema")
+
+ stmt = select(cast(column("foo"), d1), cast(column("bar"), d2))
+ self.assert_compile(
+ stmt,
+ "SELECT CAST(foo AS somename) AS foo, "
+ "CAST(bar AS someschema.somename) AS bar",
+ )
+
+ def test_create_domain_schema_translate(self):
+ d1 = DOMAIN("somename", Integer)
+ d2 = DOMAIN("somename", Integer, schema="someschema")
+ schema_translate_map = {None: "foo", "someschema": "bar"}
+
+ self.assert_compile(
+ postgresql.CreateDomainType(d1),
+ "CREATE DOMAIN foo.somename AS INTEGER ",
+ schema_translate_map=schema_translate_map,
+ render_schema_translate=True,
+ )
+
+ self.assert_compile(
+ postgresql.CreateDomainType(d2),
+ "CREATE DOMAIN bar.somename AS INTEGER ",
+ schema_translate_map=schema_translate_map,
+ render_schema_translate=True,
+ )
+
def test_create_table_with_schema_type_schema_translate(self):
e1 = Enum("x", "y", "z", name="somename")
e2 = Enum("x", "y", "z", name="somename", schema="someschema")
diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py
index 21b4149bc..99bc14d78 100644
--- a/test/dialect/postgresql/test_reflection.py
+++ b/test/dialect/postgresql/test_reflection.py
@@ -410,6 +410,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
"CREATE DOMAIN nullable_domain AS TEXT CHECK "
"(VALUE IN('FOO', 'BAR'))",
"CREATE DOMAIN not_nullable_domain AS TEXT NOT NULL",
+ "CREATE DOMAIN my_int AS int CONSTRAINT b_my_int_one CHECK "
+ "(VALUE > 1) CONSTRAINT a_my_int_two CHECK (VALUE < 42) "
+ "CHECK(VALUE != 22)",
]:
try:
con.exec_driver_sql(ddl)
@@ -468,6 +471,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
con.exec_driver_sql("DROP TABLE nullable_domain_test")
con.exec_driver_sql("DROP DOMAIN nullable_domain")
con.exec_driver_sql("DROP DOMAIN not_nullable_domain")
+ con.exec_driver_sql("DROP DOMAIN my_int")
def test_table_is_reflected(self, connection):
metadata = MetaData()
@@ -579,6 +583,122 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
finally:
base.PGDialect.ischema_names = ischema_names
+ @property
+ def all_domains(self):
+ return {
+ "public": [
+ {
+ "visible": True,
+ "name": "arraydomain",
+ "schema": "public",
+ "nullable": True,
+ "type": "integer[]",
+ "default": None,
+ "constraints": [],
+ },
+ {
+ "visible": True,
+ "name": "enumdomain",
+ "schema": "public",
+ "nullable": True,
+ "type": "testtype",
+ "default": None,
+ "constraints": [],
+ },
+ {
+ "visible": True,
+ "name": "my_int",
+ "schema": "public",
+ "nullable": True,
+ "type": "integer",
+ "default": None,
+ "constraints": [
+ {"check": "VALUE < 42", "name": "a_my_int_two"},
+ {"check": "VALUE > 1", "name": "b_my_int_one"},
+ # autogenerated name by pg
+ {"check": "VALUE <> 22", "name": "my_int_check"},
+ ],
+ },
+ {
+ "visible": True,
+ "name": "not_nullable_domain",
+ "schema": "public",
+ "nullable": False,
+ "type": "text",
+ "default": None,
+ "constraints": [],
+ },
+ {
+ "visible": True,
+ "name": "nullable_domain",
+ "schema": "public",
+ "nullable": True,
+ "type": "text",
+ "default": None,
+ "constraints": [
+ {
+ "check": "VALUE = ANY (ARRAY['FOO'::text, "
+ "'BAR'::text])",
+ # autogenerated name by pg
+ "name": "nullable_domain_check",
+ }
+ ],
+ },
+ {
+ "visible": True,
+ "name": "testdomain",
+ "schema": "public",
+ "nullable": False,
+ "type": "integer",
+ "default": "42",
+ "constraints": [],
+ },
+ ],
+ "test_schema": [
+ {
+ "visible": False,
+ "name": "testdomain",
+ "schema": "test_schema",
+ "nullable": True,
+ "type": "integer",
+ "default": "0",
+ "constraints": [],
+ }
+ ],
+ "SomeSchema": [
+ {
+ "visible": False,
+ "name": "Quoted.Domain",
+ "schema": "SomeSchema",
+ "nullable": True,
+ "type": "integer",
+ "default": "0",
+ "constraints": [],
+ }
+ ],
+ }
+
+ def test_inspect_domains(self, connection):
+ inspector = inspect(connection)
+ eq_(inspector.get_domains(), self.all_domains["public"])
+
+ def test_inspect_domains_schema(self, connection):
+ inspector = inspect(connection)
+ eq_(
+ inspector.get_domains("test_schema"),
+ self.all_domains["test_schema"],
+ )
+ eq_(
+ inspector.get_domains("SomeSchema"), self.all_domains["SomeSchema"]
+ )
+
+ def test_inspect_domains_star(self, connection):
+ inspector = inspect(connection)
+ all_ = [d for dl in self.all_domains.values() for d in dl]
+ all_ += inspector.get_domains("information_schema")
+ exp = sorted(all_, key=lambda d: (d["schema"], d["name"]))
+ eq_(inspector.get_domains("*"), exp)
+
class ReflectionTest(
ReflectionFixtures, AssertsCompiledSQL, fixtures.TestBase
@@ -1800,10 +1920,10 @@ class ReflectionTest(
eq_(
check_constraints,
{
- "cc1": "(a > 1) AND (a < 5)",
- "cc2": "(a = 1) OR ((a > 2) AND (a < 5))",
+ "cc1": "a > 1 AND a < 5",
+ "cc2": "a = 1 OR a > 2 AND a < 5",
"cc3": "is_positive(a)",
- "cc4": "(b)::text <> 'hi\nim a name \nyup\n'::text",
+ "cc4": "b::text <> 'hi\nim a name \nyup\n'::text",
},
)
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 79f029391..5c3935d44 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -38,12 +38,15 @@ from sqlalchemy import util
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.dialects.postgresql import DATERANGE
+from sqlalchemy.dialects.postgresql import DOMAIN
+from sqlalchemy.dialects.postgresql import ENUM
from sqlalchemy.dialects.postgresql import HSTORE
from sqlalchemy.dialects.postgresql import hstore
from sqlalchemy.dialects.postgresql import INT4RANGE
from sqlalchemy.dialects.postgresql import INT8RANGE
from sqlalchemy.dialects.postgresql import JSON
from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.dialects.postgresql import NamedType
from sqlalchemy.dialects.postgresql import NUMRANGE
from sqlalchemy.dialects.postgresql import TSRANGE
from sqlalchemy.dialects.postgresql import TSTZRANGE
@@ -161,7 +164,7 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
eq_(row, ([5], [5], [6], [7], [decimal.Decimal("6.4")]))
-class EnumTest(fixtures.TestBase, AssertsExecutionResults):
+class NamedTypeTest(fixtures.TestBase, AssertsExecutionResults):
__backend__ = True
__only_on__ = "postgresql > 8.3"
@@ -173,16 +176,18 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
"the native_enum flag does not apply to the "
"sqlalchemy.dialects.postgresql.ENUM datatype;"
):
- e1 = postgresql.ENUM("a", "b", "c", native_enum=False)
+ e1 = postgresql.ENUM(
+ "a", "b", "c", name="pgenum", native_enum=False
+ )
- e2 = postgresql.ENUM("a", "b", "c", native_enum=True)
- e3 = postgresql.ENUM("a", "b", "c")
+ e2 = postgresql.ENUM("a", "b", "c", name="pgenum", native_enum=True)
+ e3 = postgresql.ENUM("a", "b", "c", name="pgenum")
is_(e1.native_enum, True)
is_(e2.native_enum, True)
is_(e3.native_enum, True)
- def test_create_table(self, metadata, connection):
+ def test_enum_create_table(self, metadata, connection):
metadata = self.metadata
t1 = Table(
"table",
@@ -202,50 +207,147 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
[(1, "two"), (2, "three"), (3, "three")],
)
+ def test_domain_create_table(self, metadata, connection):
+ metadata = self.metadata
+ Email = DOMAIN(
+ name="email",
+ data_type=Text,
+ check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+ )
+ PosInt = DOMAIN(
+ name="pos_int",
+ data_type=Integer,
+ not_null=True,
+ check=r"VALUE > 0",
+ )
+ t1 = Table(
+ "table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("email", Email),
+ Column("number", PosInt),
+ )
+ t1.create(connection)
+ t1.create(connection, checkfirst=True) # check the create
+ connection.execute(
+ t1.insert(), {"email": "test@example.com", "number": 42}
+ )
+ connection.execute(t1.insert(), {"email": "a@b.c", "number": 1})
+ connection.execute(
+ t1.insert(), {"email": "example@gmail.co.uk", "number": 99}
+ )
+ eq_(
+ connection.execute(t1.select().order_by(t1.c.id)).fetchall(),
+ [
+ (1, "test@example.com", 42),
+ (2, "a@b.c", 1),
+ (3, "example@gmail.co.uk", 99),
+ ],
+ )
+
+ @testing.combinations(
+ (ENUM("one", "two", "three", name="mytype"), "get_enums"),
+ (
+ DOMAIN(
+ name="mytype",
+ data_type=Text,
+ check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+ ),
+ "get_domains",
+ ),
+ argnames="datatype, method",
+ )
+ def test_drops_on_table(
+ self, connection, metadata, datatype: "NamedType", method
+ ):
+ table = Table("e1", metadata, Column("e1", datatype))
+
+ table.create(connection)
+ table.drop(connection)
+
+ assert "mytype" not in [
+ e["name"] for e in getattr(inspect(connection), method)()
+ ]
+ table.create(connection)
+ assert "mytype" in [
+ e["name"] for e in getattr(inspect(connection), method)()
+ ]
+ table.drop(connection)
+ assert "mytype" not in [
+ e["name"] for e in getattr(inspect(connection), method)()
+ ]
+
+ @testing.combinations(
+ (
+ lambda symbol_name: ENUM(
+ "one", "two", "three", name="schema_mytype", schema=symbol_name
+ ),
+ ["two", "three", "three"],
+ "get_enums",
+ ),
+ (
+ lambda symbol_name: DOMAIN(
+ name="schema_mytype",
+ data_type=Text,
+ check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+ schema=symbol_name,
+ ),
+ ["test@example.com", "a@b.c", "example@gmail.co.uk"],
+ "get_domains",
+ ),
+ argnames="datatype,data,method",
+ )
@testing.combinations(None, "foo", argnames="symbol_name")
- def test_create_table_schema_translate_map(self, connection, symbol_name):
+ def test_create_table_schema_translate_map(
+ self, connection, symbol_name, datatype, data, method
+ ):
# note we can't use the fixture here because it will not drop
# from the correct schema
metadata = MetaData()
+ dt = datatype(symbol_name)
+
t1 = Table(
"table",
metadata,
Column("id", Integer, primary_key=True),
- Column(
- "value",
- Enum(
- "one",
- "two",
- "three",
- name="schema_enum",
- schema=symbol_name,
- ),
- ),
+ Column("value", dt),
schema=symbol_name,
)
conn = connection.execution_options(
schema_translate_map={symbol_name: testing.config.test_schema}
)
t1.create(conn)
- assert "schema_enum" in [
+ assert "schema_mytype" in [
e["name"]
- for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+ for e in getattr(inspect(conn), method)(
+ schema=testing.config.test_schema
+ )
]
t1.create(conn, checkfirst=True)
- conn.execute(t1.insert(), dict(value="two"))
- conn.execute(t1.insert(), dict(value="three"))
- conn.execute(t1.insert(), dict(value="three"))
+ conn.execute(
+ t1.insert(),
+ dict(value=data[0]),
+ )
+ conn.execute(t1.insert(), dict(value=data[1]))
+ conn.execute(t1.insert(), dict(value=data[2]))
eq_(
conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
- [(1, "two"), (2, "three"), (3, "three")],
+ [
+ (1, data[0]),
+ (2, data[1]),
+ (3, data[2]),
+ ],
)
t1.drop(conn)
- assert "schema_enum" not in [
+
+ assert "schema_mytype" not in [
e["name"]
- for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+ for e in getattr(inspect(conn), method)(
+ schema=testing.config.test_schema
+ )
]
t1.drop(conn, checkfirst=True)
@@ -256,40 +358,48 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
("override_metadata_schema",),
argnames="test_case",
)
+ @testing.combinations("enum", "domain", argnames="datatype")
@testing.requires.schemas
- def test_schema_inheritance(self, test_case, metadata, connection):
+ def test_schema_inheritance(
+ self, test_case, metadata, connection, datatype
+ ):
"""test #6373"""
metadata.schema = testing.config.test_schema
+ def make_type(**kw):
+ if datatype == "enum":
+ return Enum("four", "five", "six", name="mytype", **kw)
+ elif datatype == "domain":
+ return DOMAIN(
+ name="mytype",
+ data_type=Text,
+ check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+ **kw,
+ )
+ else:
+ assert False
+
if test_case == "metadata_schema_only":
- enum = Enum(
- "four", "five", "six", metadata=metadata, name="myenum"
- )
+ enum = make_type(metadata=metadata)
assert_schema = testing.config.test_schema
elif test_case == "override_metadata_schema":
- enum = Enum(
- "four",
- "five",
- "six",
+ enum = make_type(
metadata=metadata,
schema=testing.config.test_schema_2,
- name="myenum",
)
assert_schema = testing.config.test_schema_2
elif test_case == "inherit_table_schema":
- enum = Enum(
- "four",
- "five",
- "six",
+ enum = make_type(
metadata=metadata,
inherit_schema=True,
- name="myenum",
)
assert_schema = testing.config.test_schema_2
elif test_case == "local_schema":
- enum = Enum("four", "five", "six", name="myenum")
+ enum = make_type()
assert_schema = testing.config.db.dialect.default_schema_name
+ else:
+ assert False
Table(
"t",
@@ -300,27 +410,62 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
metadata.create_all(connection)
- eq_(
- inspect(connection).get_enums(schema=assert_schema),
- [
- {
- "labels": ["four", "five", "six"],
- "name": "myenum",
- "schema": assert_schema,
- "visible": assert_schema
- == testing.config.db.dialect.default_schema_name,
- }
- ],
- )
+ if datatype == "enum":
+ eq_(
+ inspect(connection).get_enums(schema=assert_schema),
+ [
+ {
+ "labels": ["four", "five", "six"],
+ "name": "mytype",
+ "schema": assert_schema,
+ "visible": assert_schema
+ == testing.config.db.dialect.default_schema_name,
+ }
+ ],
+ )
+ elif datatype == "domain":
+
+ def_schame = testing.config.db.dialect.default_schema_name
+ eq_(
+ inspect(connection).get_domains(schema=assert_schema),
+ [
+ {
+ "name": "mytype",
+ "type": "text",
+ "nullable": True,
+ "default": None,
+ "schema": assert_schema,
+ "visible": assert_schema == def_schame,
+ "constraints": [
+ {
+ "name": "mytype_check",
+ "check": r"VALUE ~ '[^@]+@[^@]+\.[^@]+'::text",
+ }
+ ],
+ }
+ ],
+ )
+ else:
+ assert False
- def test_name_required(self, metadata, connection):
- etype = Enum("four", "five", "six", metadata=metadata)
- assert_raises(exc.CompileError, etype.create, connection)
+ @testing.combinations(
+ (ENUM("one", "two", "three", name=None)),
+ (
+ DOMAIN(
+ name=None,
+ data_type=Text,
+ check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+ ),
+ ),
+ argnames="datatype",
+ )
+ def test_name_required(self, metadata, connection, datatype):
+ assert_raises(exc.CompileError, datatype.create, connection)
assert_raises(
- exc.CompileError, etype.compile, dialect=connection.dialect
+ exc.CompileError, datatype.compile, dialect=connection.dialect
)
- def test_unicode_labels(self, connection, metadata):
+ def test_enum_unicode_labels(self, connection, metadata):
t1 = Table(
"table",
metadata,
@@ -426,22 +571,30 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
connection.execute(t1.insert(), {"bar": "Ü"})
eq_(connection.scalar(select(t1.c.bar)), "Ü")
- def test_disable_create(self, metadata, connection):
+ @testing.combinations(
+ (ENUM("one", "two", "three", name="mytype", create_type=False),),
+ (
+ DOMAIN(
+ name="mytype",
+ data_type=Text,
+ check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+ create_type=False,
+ ),
+ ),
+ argnames="datatype",
+ )
+ def test_disable_create(self, metadata, connection, datatype):
metadata = self.metadata
- e1 = postgresql.ENUM(
- "one", "two", "three", name="myenum", create_type=False
- )
-
- t1 = Table("e1", metadata, Column("c1", e1))
+ t1 = Table("e1", metadata, Column("c1", datatype))
# table can be created separately
# without conflict
- e1.create(bind=connection)
+ datatype.create(bind=connection)
t1.create(connection)
t1.drop(connection)
- e1.drop(bind=connection)
+ datatype.drop(bind=connection)
- def test_dont_keep_checking(self, metadata, connection):
+ def test_enum_dont_keep_checking(self, metadata, connection):
metadata = self.metadata
e1 = postgresql.ENUM("one", "two", "three", name="myenum")
@@ -486,7 +639,36 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
RegexSQL("DROP TYPE myenum", dialect="postgresql"),
)
- def test_generate_multiple(self, metadata, connection):
+ @testing.combinations(
+ (
+ Enum(
+ "one",
+ "two",
+ "three",
+ name="mytype",
+ ),
+ "get_enums",
+ ),
+ (
+ ENUM(
+ "one",
+ "two",
+ "three",
+ name="mytype",
+ ),
+ "get_enums",
+ ),
+ (
+ DOMAIN(
+ name="mytype",
+ data_type=Text,
+ check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+ ),
+ "get_domains",
+ ),
+ argnames="datatype, method",
+ )
+ def test_generate_multiple(self, metadata, connection, datatype, method):
"""Test that the same enum twice only generates once
for the create_all() call, without using checkfirst.
@@ -494,15 +676,20 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
now handles this.
"""
- e1 = Enum("one", "two", "three", name="myenum")
- Table("e1", metadata, Column("c1", e1))
+ Table("e1", metadata, Column("c1", datatype))
- Table("e2", metadata, Column("c1", e1))
+ Table("e2", metadata, Column("c1", datatype))
metadata.create_all(connection, checkfirst=False)
+
+ assert "mytype" in [
+ e["name"] for e in getattr(inspect(connection), method)()
+ ]
+
metadata.drop_all(connection, checkfirst=False)
- assert "myenum" not in [
- e["name"] for e in inspect(connection).get_enums()
+
+ assert "mytype" not in [
+ e["name"] for e in getattr(inspect(connection), method)()
]
def test_generate_alone_on_metadata(self, connection, metadata):
@@ -571,23 +758,6 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
for e in inspect(connection).get_enums(schema="test_schema")
]
- def test_drops_on_table(self, connection, metadata):
-
- e1 = Enum("one", "two", "three", name="myenum")
- table = Table("e1", metadata, Column("c1", e1))
-
- table.create(connection)
- table.drop(connection)
- assert "myenum" not in [
- e["name"] for e in inspect(connection).get_enums()
- ]
- table.create(connection)
- assert "myenum" in [e["name"] for e in inspect(connection).get_enums()]
- table.drop(connection)
- assert "myenum" not in [
- e["name"] for e in inspect(connection).get_enums()
- ]
-
def test_create_drop_schema_translate_map(self, connection):
conn = connection.execution_options(
@@ -1445,15 +1615,16 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
array_agg,
)
- element_type = ENUM if with_enum else Integer
+ element = ENUM(name="pgenum") if with_enum else Integer()
+ element_type = type(element)
expr = (
array_agg(
aggregate_order_by(
- column("q", element_type), column("idx", Integer)
+ column("q", element), column("idx", Integer)
)
)
if using_aggregate_order_by
- else array_agg(column("q", element_type))
+ else array_agg(column("q", element))
)
is_(expr.type.__class__, postgresql.ARRAY)
is_(expr.type.item_type.__class__, element_type)
@@ -2081,10 +2252,13 @@ class ArrayRoundTripTest:
],
testing.requires.hstore,
),
- (postgresql.ENUM(AnEnum), enum_values),
+ (postgresql.ENUM(AnEnum, name="pgenum"), enum_values),
(sqltypes.Enum(AnEnum, native_enum=True), enum_values),
(sqltypes.Enum(AnEnum, native_enum=False), enum_values),
- (postgresql.ENUM(AnEnum, native_enum=True), enum_values),
+ (
+ postgresql.ENUM(AnEnum, name="pgenum", native_enum=True),
+ enum_values,
+ ),
(
make_difficult_enum(sqltypes.Enum, native=True),
difficult_enum_values,
@@ -2102,10 +2276,15 @@ class ArrayRoundTripTest:
if not exclude_empty_lists:
elements.extend(
[
- (postgresql.ENUM(AnEnum), empty_list),
+ (postgresql.ENUM(AnEnum, name="pgenum"), empty_list),
(sqltypes.Enum(AnEnum, native_enum=True), empty_list),
(sqltypes.Enum(AnEnum, native_enum=False), empty_list),
- (postgresql.ENUM(AnEnum, native_enum=True), empty_list),
+ (
+ postgresql.ENUM(
+ AnEnum, name="pgenum", native_enum=True
+ ),
+ empty_list,
+ ),
]
)
if not exclude_json:
@@ -2410,7 +2589,7 @@ class ArrayEnum(fixtures.TestBase):
),
Column(
"pyenum_col",
- array_cls(enum_cls(MyEnum)),
+ array_cls(enum_cls(MyEnum, name="pgenum")),
),
)
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 04aa4e000..623688b83 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -111,7 +111,11 @@ def _all_dialects():
def _types_for_mod(mod):
for key in dir(mod):
typ = getattr(mod, key)
- if not isinstance(typ, type) or not issubclass(typ, types.TypeEngine):
+ if (
+ not isinstance(typ, type)
+ or not issubclass(typ, types.TypeEngine)
+ or typ.__dict__.get("__abstract__")
+ ):
continue
yield typ
@@ -143,6 +147,17 @@ def _all_types(omit_special_types=False):
yield typ
+def _get_instance(type_):
+ if issubclass(type_, ARRAY):
+ return type_(String)
+ elif hasattr(type_, "__test_init__"):
+ t1 = type_.__test_init__()
+ is_(isinstance(t1, type_), True)
+ return t1
+ else:
+ return type_()
+
+
class AdaptTest(fixtures.TestBase):
@testing.combinations(((t,) for t in _types_for_mod(types)), id_="n")
def test_uppercase_importable(self, typ):
@@ -240,11 +255,8 @@ class AdaptTest(fixtures.TestBase):
adapt() beyond their defaults.
"""
+ t1 = _get_instance(typ)
- if issubclass(typ, ARRAY):
- t1 = typ(String)
- else:
- t1 = typ()
for cls in target_adaptions:
if (is_down_adaption and issubclass(typ, sqltypes.Emulated)) or (
not is_down_adaption and issubclass(cls, sqltypes.Emulated)
@@ -301,19 +313,13 @@ class AdaptTest(fixtures.TestBase):
@testing.uses_deprecated()
@testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)])
def test_repr(self, typ):
- if issubclass(typ, ARRAY):
- t1 = typ(String)
- else:
- t1 = typ()
+ t1 = _get_instance(typ)
repr(t1)
@testing.uses_deprecated()
@testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)])
def test_str(self, typ):
- if issubclass(typ, ARRAY):
- t1 = typ(String)
- else:
- t1 = typ()
+ t1 = _get_instance(typ)
str(t1)
def test_str_third_party(self):
@@ -400,7 +406,7 @@ class AsGenericTest(fixtures.TestBase):
(pg.JSON(), sa.JSON()),
(pg.ARRAY(sa.String), sa.ARRAY(sa.String)),
(Enum("a", "b", "c"), Enum("a", "b", "c")),
- (pg.ENUM("a", "b", "c"), Enum("a", "b", "c")),
+ (pg.ENUM("a", "b", "c", name="pgenum"), Enum("a", "b", "c")),
(mysql.ENUM("a", "b", "c"), Enum("a", "b", "c")),
(pg.INTERVAL(precision=5), Interval(native=True, second_precision=5)),
(
@@ -419,11 +425,7 @@ class AsGenericTest(fixtures.TestBase):
]
)
def test_as_generic_all_types_heuristic(self, type_):
- if issubclass(type_, ARRAY):
- t1 = type_(String)
- else:
- t1 = type_()
-
+ t1 = _get_instance(type_)
try:
gentype = t1.as_generic()
except NotImplementedError:
@@ -445,10 +447,7 @@ class AsGenericTest(fixtures.TestBase):
]
)
def test_as_generic_all_types_custom(self, type_):
- if issubclass(type_, ARRAY):
- t1 = type_(String)
- else:
- t1 = type_()
+ t1 = _get_instance(type_)
gentype = t1.as_generic(allow_nulltype=False)
assert isinstance(gentype, TypeEngine)