diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/aaa_profiling/test_memusage.py | 22 | ||||
-rw-r--r-- | test/dialect/postgresql/test_compiler.py | 76 | ||||
-rw-r--r-- | test/dialect/postgresql/test_reflection.py | 126 | ||||
-rw-r--r-- | test/dialect/postgresql/test_types.py | 371 | ||||
-rw-r--r-- | test/sql/test_types.py | 45 |
5 files changed, 507 insertions, 133 deletions
diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index e90c428f7..dc6dcf060 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -312,20 +312,22 @@ class MemUsageTest(EnsureZeroed): eng = engines.testing_engine() for args in ( - (types.Integer,), - (types.String,), - (types.PickleType,), - (types.Enum, "a", "b", "c"), - (sqlite.DATETIME,), - (postgresql.ENUM, "a", "b", "c"), - (types.Interval,), - (postgresql.INTERVAL,), - (mysql.VARCHAR,), + (types.Integer, {}), + (types.String, {}), + (types.PickleType, {}), + (types.Enum, "a", "b", "c", {}), + (sqlite.DATETIME, {}), + (postgresql.ENUM, "a", "b", "c", {"name": "pgenum"}), + (types.Interval, {}), + (postgresql.INTERVAL, {}), + (mysql.VARCHAR, {}), ): @profile_memory() def go(): - type_ = args[0](*args[1:]) + kwargs = args[-1] + posargs = args[1:-1] + type_ = args[0](*posargs, **kwargs) bp = type_._cached_bind_processor(eng.dialect) rp = type_._cached_result_processor(eng.dialect, 0) bp, rp # strong reference diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 25550afe1..9be76130d 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -38,6 +38,7 @@ from sqlalchemy.dialects.postgresql import aggregate_order_by from sqlalchemy.dialects.postgresql import ARRAY as PG_ARRAY from sqlalchemy.dialects.postgresql import array from sqlalchemy.dialects.postgresql import array_agg as pg_array_agg +from sqlalchemy.dialects.postgresql import DOMAIN from sqlalchemy.dialects.postgresql import ExcludeConstraint from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import TSRANGE @@ -270,7 +271,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): render_schema_translate=True, ) - def test_create_type_schema_translate(self): + def test_create_enum_schema_translate(self): e1 = Enum("x", "y", "z", name="somename") e2 = Enum("x", "y", "z", name="somename", schema="someschema") schema_translate_map = {None: "foo", "someschema": "bar"} @@ -289,6 +290,79 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): render_schema_translate=True, ) + def test_domain(self): + self.assert_compile( + postgresql.CreateDomainType( + DOMAIN( + "x", + Integer, + default=text("11"), + not_null=True, + check="VALUE < 0", + ) + ), + "CREATE DOMAIN x AS INTEGER DEFAULT 11 NOT NULL CHECK (VALUE < 0)", + ) + self.assert_compile( + postgresql.CreateDomainType( + DOMAIN( + "sOmEnAmE", + Text, + collation="utf8", + constraint_name="a constraint", + not_null=True, + ) + ), + 'CREATE DOMAIN "sOmEnAmE" AS TEXT COLLATE utf8 CONSTRAINT ' + '"a constraint" NOT NULL', + ) + self.assert_compile( + postgresql.CreateDomainType( + DOMAIN( + "foo", + Text, + collation="utf8", + default="foobar", + constraint_name="no_bar", + not_null=True, + check="VALUE != 'bar'", + ) + ), + "CREATE DOMAIN foo AS TEXT COLLATE utf8 DEFAULT 'foobar' " + "CONSTRAINT no_bar NOT NULL CHECK (VALUE != 'bar')", + ) + + def test_cast_domain_schema(self): + """test #6739""" + d1 = DOMAIN("somename", Integer) + d2 = DOMAIN("somename", Integer, schema="someschema") + + stmt = select(cast(column("foo"), d1), cast(column("bar"), d2)) + self.assert_compile( + stmt, + "SELECT CAST(foo AS somename) AS foo, " + "CAST(bar AS someschema.somename) AS bar", + ) + + def test_create_domain_schema_translate(self): + d1 = DOMAIN("somename", Integer) + d2 = DOMAIN("somename", Integer, schema="someschema") + schema_translate_map = {None: "foo", "someschema": "bar"} + + self.assert_compile( + postgresql.CreateDomainType(d1), + "CREATE DOMAIN foo.somename AS INTEGER ", + schema_translate_map=schema_translate_map, + render_schema_translate=True, + ) + + self.assert_compile( + postgresql.CreateDomainType(d2), + "CREATE DOMAIN bar.somename AS INTEGER ", + schema_translate_map=schema_translate_map, + render_schema_translate=True, + ) + def test_create_table_with_schema_type_schema_translate(self): e1 = Enum("x", "y", "z", name="somename") e2 = Enum("x", "y", "z", name="somename", schema="someschema") diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index 21b4149bc..99bc14d78 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -410,6 +410,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): "CREATE DOMAIN nullable_domain AS TEXT CHECK " "(VALUE IN('FOO', 'BAR'))", "CREATE DOMAIN not_nullable_domain AS TEXT NOT NULL", + "CREATE DOMAIN my_int AS int CONSTRAINT b_my_int_one CHECK " + "(VALUE > 1) CONSTRAINT a_my_int_two CHECK (VALUE < 42) " + "CHECK(VALUE != 22)", ]: try: con.exec_driver_sql(ddl) @@ -468,6 +471,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): con.exec_driver_sql("DROP TABLE nullable_domain_test") con.exec_driver_sql("DROP DOMAIN nullable_domain") con.exec_driver_sql("DROP DOMAIN not_nullable_domain") + con.exec_driver_sql("DROP DOMAIN my_int") def test_table_is_reflected(self, connection): metadata = MetaData() @@ -579,6 +583,122 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): finally: base.PGDialect.ischema_names = ischema_names + @property + def all_domains(self): + return { + "public": [ + { + "visible": True, + "name": "arraydomain", + "schema": "public", + "nullable": True, + "type": "integer[]", + "default": None, + "constraints": [], + }, + { + "visible": True, + "name": "enumdomain", + "schema": "public", + "nullable": True, + "type": "testtype", + "default": None, + "constraints": [], + }, + { + "visible": True, + "name": "my_int", + "schema": "public", + "nullable": True, + "type": "integer", + "default": None, + "constraints": [ + {"check": "VALUE < 42", "name": "a_my_int_two"}, + {"check": "VALUE > 1", "name": "b_my_int_one"}, + # autogenerated name by pg + {"check": "VALUE <> 22", "name": "my_int_check"}, + ], + }, + { + "visible": True, + "name": "not_nullable_domain", + "schema": "public", + "nullable": False, + "type": "text", + "default": None, + "constraints": [], + }, + { + "visible": True, + "name": "nullable_domain", + "schema": "public", + "nullable": True, + "type": "text", + "default": None, + "constraints": [ + { + "check": "VALUE = ANY (ARRAY['FOO'::text, " + "'BAR'::text])", + # autogenerated name by pg + "name": "nullable_domain_check", + } + ], + }, + { + "visible": True, + "name": "testdomain", + "schema": "public", + "nullable": False, + "type": "integer", + "default": "42", + "constraints": [], + }, + ], + "test_schema": [ + { + "visible": False, + "name": "testdomain", + "schema": "test_schema", + "nullable": True, + "type": "integer", + "default": "0", + "constraints": [], + } + ], + "SomeSchema": [ + { + "visible": False, + "name": "Quoted.Domain", + "schema": "SomeSchema", + "nullable": True, + "type": "integer", + "default": "0", + "constraints": [], + } + ], + } + + def test_inspect_domains(self, connection): + inspector = inspect(connection) + eq_(inspector.get_domains(), self.all_domains["public"]) + + def test_inspect_domains_schema(self, connection): + inspector = inspect(connection) + eq_( + inspector.get_domains("test_schema"), + self.all_domains["test_schema"], + ) + eq_( + inspector.get_domains("SomeSchema"), self.all_domains["SomeSchema"] + ) + + def test_inspect_domains_star(self, connection): + inspector = inspect(connection) + all_ = [d for dl in self.all_domains.values() for d in dl] + all_ += inspector.get_domains("information_schema") + exp = sorted(all_, key=lambda d: (d["schema"], d["name"])) + eq_(inspector.get_domains("*"), exp) + class ReflectionTest( ReflectionFixtures, AssertsCompiledSQL, fixtures.TestBase @@ -1800,10 +1920,10 @@ class ReflectionTest( eq_( check_constraints, { - "cc1": "(a > 1) AND (a < 5)", - "cc2": "(a = 1) OR ((a > 2) AND (a < 5))", + "cc1": "a > 1 AND a < 5", + "cc2": "a = 1 OR a > 2 AND a < 5", "cc3": "is_positive(a)", - "cc4": "(b)::text <> 'hi\nim a name \nyup\n'::text", + "cc4": "b::text <> 'hi\nim a name \nyup\n'::text", }, ) diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 79f029391..5c3935d44 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -38,12 +38,15 @@ from sqlalchemy import util from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import array from sqlalchemy.dialects.postgresql import DATERANGE +from sqlalchemy.dialects.postgresql import DOMAIN +from sqlalchemy.dialects.postgresql import ENUM from sqlalchemy.dialects.postgresql import HSTORE from sqlalchemy.dialects.postgresql import hstore from sqlalchemy.dialects.postgresql import INT4RANGE from sqlalchemy.dialects.postgresql import INT8RANGE from sqlalchemy.dialects.postgresql import JSON from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import NamedType from sqlalchemy.dialects.postgresql import NUMRANGE from sqlalchemy.dialects.postgresql import TSRANGE from sqlalchemy.dialects.postgresql import TSTZRANGE @@ -161,7 +164,7 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): eq_(row, ([5], [5], [6], [7], [decimal.Decimal("6.4")])) -class EnumTest(fixtures.TestBase, AssertsExecutionResults): +class NamedTypeTest(fixtures.TestBase, AssertsExecutionResults): __backend__ = True __only_on__ = "postgresql > 8.3" @@ -173,16 +176,18 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): "the native_enum flag does not apply to the " "sqlalchemy.dialects.postgresql.ENUM datatype;" ): - e1 = postgresql.ENUM("a", "b", "c", native_enum=False) + e1 = postgresql.ENUM( + "a", "b", "c", name="pgenum", native_enum=False + ) - e2 = postgresql.ENUM("a", "b", "c", native_enum=True) - e3 = postgresql.ENUM("a", "b", "c") + e2 = postgresql.ENUM("a", "b", "c", name="pgenum", native_enum=True) + e3 = postgresql.ENUM("a", "b", "c", name="pgenum") is_(e1.native_enum, True) is_(e2.native_enum, True) is_(e3.native_enum, True) - def test_create_table(self, metadata, connection): + def test_enum_create_table(self, metadata, connection): metadata = self.metadata t1 = Table( "table", @@ -202,50 +207,147 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): [(1, "two"), (2, "three"), (3, "three")], ) + def test_domain_create_table(self, metadata, connection): + metadata = self.metadata + Email = DOMAIN( + name="email", + data_type=Text, + check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + ) + PosInt = DOMAIN( + name="pos_int", + data_type=Integer, + not_null=True, + check=r"VALUE > 0", + ) + t1 = Table( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("email", Email), + Column("number", PosInt), + ) + t1.create(connection) + t1.create(connection, checkfirst=True) # check the create + connection.execute( + t1.insert(), {"email": "test@example.com", "number": 42} + ) + connection.execute(t1.insert(), {"email": "a@b.c", "number": 1}) + connection.execute( + t1.insert(), {"email": "example@gmail.co.uk", "number": 99} + ) + eq_( + connection.execute(t1.select().order_by(t1.c.id)).fetchall(), + [ + (1, "test@example.com", 42), + (2, "a@b.c", 1), + (3, "example@gmail.co.uk", 99), + ], + ) + + @testing.combinations( + (ENUM("one", "two", "three", name="mytype"), "get_enums"), + ( + DOMAIN( + name="mytype", + data_type=Text, + check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + ), + "get_domains", + ), + argnames="datatype, method", + ) + def test_drops_on_table( + self, connection, metadata, datatype: "NamedType", method + ): + table = Table("e1", metadata, Column("e1", datatype)) + + table.create(connection) + table.drop(connection) + + assert "mytype" not in [ + e["name"] for e in getattr(inspect(connection), method)() + ] + table.create(connection) + assert "mytype" in [ + e["name"] for e in getattr(inspect(connection), method)() + ] + table.drop(connection) + assert "mytype" not in [ + e["name"] for e in getattr(inspect(connection), method)() + ] + + @testing.combinations( + ( + lambda symbol_name: ENUM( + "one", "two", "three", name="schema_mytype", schema=symbol_name + ), + ["two", "three", "three"], + "get_enums", + ), + ( + lambda symbol_name: DOMAIN( + name="schema_mytype", + data_type=Text, + check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + schema=symbol_name, + ), + ["test@example.com", "a@b.c", "example@gmail.co.uk"], + "get_domains", + ), + argnames="datatype,data,method", + ) @testing.combinations(None, "foo", argnames="symbol_name") - def test_create_table_schema_translate_map(self, connection, symbol_name): + def test_create_table_schema_translate_map( + self, connection, symbol_name, datatype, data, method + ): # note we can't use the fixture here because it will not drop # from the correct schema metadata = MetaData() + dt = datatype(symbol_name) + t1 = Table( "table", metadata, Column("id", Integer, primary_key=True), - Column( - "value", - Enum( - "one", - "two", - "three", - name="schema_enum", - schema=symbol_name, - ), - ), + Column("value", dt), schema=symbol_name, ) conn = connection.execution_options( schema_translate_map={symbol_name: testing.config.test_schema} ) t1.create(conn) - assert "schema_enum" in [ + assert "schema_mytype" in [ e["name"] - for e in inspect(conn).get_enums(schema=testing.config.test_schema) + for e in getattr(inspect(conn), method)( + schema=testing.config.test_schema + ) ] t1.create(conn, checkfirst=True) - conn.execute(t1.insert(), dict(value="two")) - conn.execute(t1.insert(), dict(value="three")) - conn.execute(t1.insert(), dict(value="three")) + conn.execute( + t1.insert(), + dict(value=data[0]), + ) + conn.execute(t1.insert(), dict(value=data[1])) + conn.execute(t1.insert(), dict(value=data[2])) eq_( conn.execute(t1.select().order_by(t1.c.id)).fetchall(), - [(1, "two"), (2, "three"), (3, "three")], + [ + (1, data[0]), + (2, data[1]), + (3, data[2]), + ], ) t1.drop(conn) - assert "schema_enum" not in [ + + assert "schema_mytype" not in [ e["name"] - for e in inspect(conn).get_enums(schema=testing.config.test_schema) + for e in getattr(inspect(conn), method)( + schema=testing.config.test_schema + ) ] t1.drop(conn, checkfirst=True) @@ -256,40 +358,48 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ("override_metadata_schema",), argnames="test_case", ) + @testing.combinations("enum", "domain", argnames="datatype") @testing.requires.schemas - def test_schema_inheritance(self, test_case, metadata, connection): + def test_schema_inheritance( + self, test_case, metadata, connection, datatype + ): """test #6373""" metadata.schema = testing.config.test_schema + def make_type(**kw): + if datatype == "enum": + return Enum("four", "five", "six", name="mytype", **kw) + elif datatype == "domain": + return DOMAIN( + name="mytype", + data_type=Text, + check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + **kw, + ) + else: + assert False + if test_case == "metadata_schema_only": - enum = Enum( - "four", "five", "six", metadata=metadata, name="myenum" - ) + enum = make_type(metadata=metadata) assert_schema = testing.config.test_schema elif test_case == "override_metadata_schema": - enum = Enum( - "four", - "five", - "six", + enum = make_type( metadata=metadata, schema=testing.config.test_schema_2, - name="myenum", ) assert_schema = testing.config.test_schema_2 elif test_case == "inherit_table_schema": - enum = Enum( - "four", - "five", - "six", + enum = make_type( metadata=metadata, inherit_schema=True, - name="myenum", ) assert_schema = testing.config.test_schema_2 elif test_case == "local_schema": - enum = Enum("four", "five", "six", name="myenum") + enum = make_type() assert_schema = testing.config.db.dialect.default_schema_name + else: + assert False Table( "t", @@ -300,27 +410,62 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): metadata.create_all(connection) - eq_( - inspect(connection).get_enums(schema=assert_schema), - [ - { - "labels": ["four", "five", "six"], - "name": "myenum", - "schema": assert_schema, - "visible": assert_schema - == testing.config.db.dialect.default_schema_name, - } - ], - ) + if datatype == "enum": + eq_( + inspect(connection).get_enums(schema=assert_schema), + [ + { + "labels": ["four", "five", "six"], + "name": "mytype", + "schema": assert_schema, + "visible": assert_schema + == testing.config.db.dialect.default_schema_name, + } + ], + ) + elif datatype == "domain": + + def_schame = testing.config.db.dialect.default_schema_name + eq_( + inspect(connection).get_domains(schema=assert_schema), + [ + { + "name": "mytype", + "type": "text", + "nullable": True, + "default": None, + "schema": assert_schema, + "visible": assert_schema == def_schame, + "constraints": [ + { + "name": "mytype_check", + "check": r"VALUE ~ '[^@]+@[^@]+\.[^@]+'::text", + } + ], + } + ], + ) + else: + assert False - def test_name_required(self, metadata, connection): - etype = Enum("four", "five", "six", metadata=metadata) - assert_raises(exc.CompileError, etype.create, connection) + @testing.combinations( + (ENUM("one", "two", "three", name=None)), + ( + DOMAIN( + name=None, + data_type=Text, + check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + ), + ), + argnames="datatype", + ) + def test_name_required(self, metadata, connection, datatype): + assert_raises(exc.CompileError, datatype.create, connection) assert_raises( - exc.CompileError, etype.compile, dialect=connection.dialect + exc.CompileError, datatype.compile, dialect=connection.dialect ) - def test_unicode_labels(self, connection, metadata): + def test_enum_unicode_labels(self, connection, metadata): t1 = Table( "table", metadata, @@ -426,22 +571,30 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): connection.execute(t1.insert(), {"bar": "Ü"}) eq_(connection.scalar(select(t1.c.bar)), "Ü") - def test_disable_create(self, metadata, connection): + @testing.combinations( + (ENUM("one", "two", "three", name="mytype", create_type=False),), + ( + DOMAIN( + name="mytype", + data_type=Text, + check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + create_type=False, + ), + ), + argnames="datatype", + ) + def test_disable_create(self, metadata, connection, datatype): metadata = self.metadata - e1 = postgresql.ENUM( - "one", "two", "three", name="myenum", create_type=False - ) - - t1 = Table("e1", metadata, Column("c1", e1)) + t1 = Table("e1", metadata, Column("c1", datatype)) # table can be created separately # without conflict - e1.create(bind=connection) + datatype.create(bind=connection) t1.create(connection) t1.drop(connection) - e1.drop(bind=connection) + datatype.drop(bind=connection) - def test_dont_keep_checking(self, metadata, connection): + def test_enum_dont_keep_checking(self, metadata, connection): metadata = self.metadata e1 = postgresql.ENUM("one", "two", "three", name="myenum") @@ -486,7 +639,36 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): RegexSQL("DROP TYPE myenum", dialect="postgresql"), ) - def test_generate_multiple(self, metadata, connection): + @testing.combinations( + ( + Enum( + "one", + "two", + "three", + name="mytype", + ), + "get_enums", + ), + ( + ENUM( + "one", + "two", + "three", + name="mytype", + ), + "get_enums", + ), + ( + DOMAIN( + name="mytype", + data_type=Text, + check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + ), + "get_domains", + ), + argnames="datatype, method", + ) + def test_generate_multiple(self, metadata, connection, datatype, method): """Test that the same enum twice only generates once for the create_all() call, without using checkfirst. @@ -494,15 +676,20 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): now handles this. """ - e1 = Enum("one", "two", "three", name="myenum") - Table("e1", metadata, Column("c1", e1)) + Table("e1", metadata, Column("c1", datatype)) - Table("e2", metadata, Column("c1", e1)) + Table("e2", metadata, Column("c1", datatype)) metadata.create_all(connection, checkfirst=False) + + assert "mytype" in [ + e["name"] for e in getattr(inspect(connection), method)() + ] + metadata.drop_all(connection, checkfirst=False) - assert "myenum" not in [ - e["name"] for e in inspect(connection).get_enums() + + assert "mytype" not in [ + e["name"] for e in getattr(inspect(connection), method)() ] def test_generate_alone_on_metadata(self, connection, metadata): @@ -571,23 +758,6 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): for e in inspect(connection).get_enums(schema="test_schema") ] - def test_drops_on_table(self, connection, metadata): - - e1 = Enum("one", "two", "three", name="myenum") - table = Table("e1", metadata, Column("c1", e1)) - - table.create(connection) - table.drop(connection) - assert "myenum" not in [ - e["name"] for e in inspect(connection).get_enums() - ] - table.create(connection) - assert "myenum" in [e["name"] for e in inspect(connection).get_enums()] - table.drop(connection) - assert "myenum" not in [ - e["name"] for e in inspect(connection).get_enums() - ] - def test_create_drop_schema_translate_map(self, connection): conn = connection.execution_options( @@ -1445,15 +1615,16 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): array_agg, ) - element_type = ENUM if with_enum else Integer + element = ENUM(name="pgenum") if with_enum else Integer() + element_type = type(element) expr = ( array_agg( aggregate_order_by( - column("q", element_type), column("idx", Integer) + column("q", element), column("idx", Integer) ) ) if using_aggregate_order_by - else array_agg(column("q", element_type)) + else array_agg(column("q", element)) ) is_(expr.type.__class__, postgresql.ARRAY) is_(expr.type.item_type.__class__, element_type) @@ -2081,10 +2252,13 @@ class ArrayRoundTripTest: ], testing.requires.hstore, ), - (postgresql.ENUM(AnEnum), enum_values), + (postgresql.ENUM(AnEnum, name="pgenum"), enum_values), (sqltypes.Enum(AnEnum, native_enum=True), enum_values), (sqltypes.Enum(AnEnum, native_enum=False), enum_values), - (postgresql.ENUM(AnEnum, native_enum=True), enum_values), + ( + postgresql.ENUM(AnEnum, name="pgenum", native_enum=True), + enum_values, + ), ( make_difficult_enum(sqltypes.Enum, native=True), difficult_enum_values, @@ -2102,10 +2276,15 @@ class ArrayRoundTripTest: if not exclude_empty_lists: elements.extend( [ - (postgresql.ENUM(AnEnum), empty_list), + (postgresql.ENUM(AnEnum, name="pgenum"), empty_list), (sqltypes.Enum(AnEnum, native_enum=True), empty_list), (sqltypes.Enum(AnEnum, native_enum=False), empty_list), - (postgresql.ENUM(AnEnum, native_enum=True), empty_list), + ( + postgresql.ENUM( + AnEnum, name="pgenum", native_enum=True + ), + empty_list, + ), ] ) if not exclude_json: @@ -2410,7 +2589,7 @@ class ArrayEnum(fixtures.TestBase): ), Column( "pyenum_col", - array_cls(enum_cls(MyEnum)), + array_cls(enum_cls(MyEnum, name="pgenum")), ), ) diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 04aa4e000..623688b83 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -111,7 +111,11 @@ def _all_dialects(): def _types_for_mod(mod): for key in dir(mod): typ = getattr(mod, key) - if not isinstance(typ, type) or not issubclass(typ, types.TypeEngine): + if ( + not isinstance(typ, type) + or not issubclass(typ, types.TypeEngine) + or typ.__dict__.get("__abstract__") + ): continue yield typ @@ -143,6 +147,17 @@ def _all_types(omit_special_types=False): yield typ +def _get_instance(type_): + if issubclass(type_, ARRAY): + return type_(String) + elif hasattr(type_, "__test_init__"): + t1 = type_.__test_init__() + is_(isinstance(t1, type_), True) + return t1 + else: + return type_() + + class AdaptTest(fixtures.TestBase): @testing.combinations(((t,) for t in _types_for_mod(types)), id_="n") def test_uppercase_importable(self, typ): @@ -240,11 +255,8 @@ class AdaptTest(fixtures.TestBase): adapt() beyond their defaults. """ + t1 = _get_instance(typ) - if issubclass(typ, ARRAY): - t1 = typ(String) - else: - t1 = typ() for cls in target_adaptions: if (is_down_adaption and issubclass(typ, sqltypes.Emulated)) or ( not is_down_adaption and issubclass(cls, sqltypes.Emulated) @@ -301,19 +313,13 @@ class AdaptTest(fixtures.TestBase): @testing.uses_deprecated() @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)]) def test_repr(self, typ): - if issubclass(typ, ARRAY): - t1 = typ(String) - else: - t1 = typ() + t1 = _get_instance(typ) repr(t1) @testing.uses_deprecated() @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)]) def test_str(self, typ): - if issubclass(typ, ARRAY): - t1 = typ(String) - else: - t1 = typ() + t1 = _get_instance(typ) str(t1) def test_str_third_party(self): @@ -400,7 +406,7 @@ class AsGenericTest(fixtures.TestBase): (pg.JSON(), sa.JSON()), (pg.ARRAY(sa.String), sa.ARRAY(sa.String)), (Enum("a", "b", "c"), Enum("a", "b", "c")), - (pg.ENUM("a", "b", "c"), Enum("a", "b", "c")), + (pg.ENUM("a", "b", "c", name="pgenum"), Enum("a", "b", "c")), (mysql.ENUM("a", "b", "c"), Enum("a", "b", "c")), (pg.INTERVAL(precision=5), Interval(native=True, second_precision=5)), ( @@ -419,11 +425,7 @@ class AsGenericTest(fixtures.TestBase): ] ) def test_as_generic_all_types_heuristic(self, type_): - if issubclass(type_, ARRAY): - t1 = type_(String) - else: - t1 = type_() - + t1 = _get_instance(type_) try: gentype = t1.as_generic() except NotImplementedError: @@ -445,10 +447,7 @@ class AsGenericTest(fixtures.TestBase): ] ) def test_as_generic_all_types_custom(self, type_): - if issubclass(type_, ARRAY): - t1 = type_(String) - else: - t1 = type_() + t1 = _get_instance(type_) gentype = t1.as_generic(allow_nulltype=False) assert isinstance(gentype, TypeEngine) |