diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2020-08-02 03:18:12 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-08-02 03:18:12 +0000 |
commit | 0afa28ee3a9699f7a03fbcc8baf01ea873d0c425 (patch) | |
tree | a7d300da26240b5b5bdb4b4e6b75e37cf0169a6e | |
parent | 5186bea969a0663f535a5a2daa7667a124b5b29f (diff) | |
parent | 0721a6bede7222386b2a2508aac5590909fbb148 (diff) | |
download | sqlalchemy-0afa28ee3a9699f7a03fbcc8baf01ea873d0c425.tar.gz |
Merge "Genericize str() for types"
-rw-r--r-- | doc/build/changelog/unreleased_14/4262.rst | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 24 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 7 | ||||
-rw-r--r-- | test/dialect/mssql/test_types.py | 9 | ||||
-rw-r--r-- | test/sql/test_compiler.py | 2 | ||||
-rw-r--r-- | test/sql/test_metadata.py | 16 | ||||
-rw-r--r-- | test/sql/test_types.py | 35 |
7 files changed, 82 insertions, 21 deletions
diff --git a/doc/build/changelog/unreleased_14/4262.rst b/doc/build/changelog/unreleased_14/4262.rst new file mode 100644 index 000000000..8377daca0 --- /dev/null +++ b/doc/build/changelog/unreleased_14/4262.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, types + :tickets: 4262 + + Cleaned up the internal ``str()`` for datatypes so that all types produce a + string representation without any dialect present, including that it works + for third-party dialect types without that dialect being present. The + string representation defaults to being the UPPERCASE name of that type + with nothing else. + diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 61e26b003..a8bd1de33 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -4270,6 +4270,14 @@ class GenericTypeCompiler(TypeCompiler): class StrSQLTypeCompiler(GenericTypeCompiler): + def process(self, type_, **kw): + try: + _compiler_dispatch = type_._compiler_dispatch + except AttributeError: + return self._visit_unknown(type_, **kw) + else: + return _compiler_dispatch(self, **kw) + def __getattr__(self, key): if key.startswith("visit_"): return self._visit_unknown @@ -4277,7 +4285,21 @@ class StrSQLTypeCompiler(GenericTypeCompiler): raise AttributeError(key) def _visit_unknown(self, type_, **kw): - return "%s" % type_.__class__.__name__ + if type_.__class__.__name__ == type_.__class__.__name__.upper(): + return type_.__class__.__name__ + else: + return repr(type_) + + def visit_null(self, type_, **kw): + return "NULL" + + def visit_user_defined(self, type_, **kw): + try: + get_col_spec = type_.get_col_spec + except AttributeError: + return repr(type_) + else: + return get_col_spec(**kw) class IdentifierPreparer(object): diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 2d23c56e1..1284ef515 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -626,12 +626,7 @@ class TypeEngine(Traversible): @util.preload_module("sqlalchemy.engine.default") def _default_dialect(self): default = util.preloaded.engine_default - if self.__class__.__module__.startswith("sqlalchemy.dialects"): - tokens = self.__class__.__module__.split(".")[0:3] - mod = ".".join(tokens) - return getattr(__import__(mod).dialects, tokens[-1]).dialect() - else: - return default.DefaultDialect() + return default.StrCompileDialect() def __str__(self): if util.py2k: diff --git a/test/dialect/mssql/test_types.py b/test/dialect/mssql/test_types.py index e28a42498..399e0ca90 100644 --- a/test/dialect/mssql/test_types.py +++ b/test/dialect/mssql/test_types.py @@ -953,9 +953,14 @@ class TypeRoundTripTest( ) for col, spec in zip(reflected_binary.c, columns): eq_( - str(col.type), + col.type.compile(dialect=mssql.dialect()), spec[3], - "column %s %s != %s" % (col.key, str(col.type), spec[3]), + "column %s %s != %s" + % ( + col.key, + col.type.compile(dialect=mssql.dialect()), + spec[3], + ), ) c1 = testing.db.dialect.type_descriptor(col.type).__class__ c2 = testing.db.dialect.type_descriptor( diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index d79d00555..1d31f1ea5 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -3940,7 +3940,7 @@ class StringifySpecialTest(fixtures.TestBase): eq_ignore_whitespace( str(stmt), - "SELECT CAST(mytable.myid AS MyType) AS myid FROM mytable", + "SELECT CAST(mytable.myid AS MyType()) AS myid FROM mytable", ) def test_within_group(self): diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 3303eac1d..dc4e342fd 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -3970,13 +3970,9 @@ class ColumnOptionsTest(fixtures.TestBase): assert Column(String, default=g2).default is g2 assert Column(String, onupdate=g2).onupdate is g2 - def _null_type_error(self, col): - t = Table("t", MetaData(), col) - assert_raises_message( - exc.CompileError, - r"\(in table 't', column 'foo'\): Can't generate DDL for NullType", - schema.CreateTable(t).compile, - ) + def _null_type_no_error(self, col): + c_str = str(schema.CreateColumn(col).compile()) + assert "NULL" in c_str def _no_name_error(self, col): assert_raises_message( @@ -3997,13 +3993,13 @@ class ColumnOptionsTest(fixtures.TestBase): def test_argument_signatures(self): self._no_name_error(Column()) - self._null_type_error(Column("foo")) + self._null_type_no_error(Column("foo")) self._no_name_error(Column(default="foo")) self._no_name_error(Column(Sequence("a"))) - self._null_type_error(Column("foo", default="foo")) + self._null_type_no_error(Column("foo", default="foo")) - self._null_type_error(Column("foo", Sequence("a"))) + self._null_type_no_error(Column("foo", Sequence("a"))) self._no_name_error(Column(ForeignKey("bar.id"))) diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 30e5d1fca..fac9fd139 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -294,6 +294,33 @@ class AdaptTest(fixtures.TestBase): t1 = typ() repr(t1) + @testing.uses_deprecated() + @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)]) + def test_str(self, typ): + if issubclass(typ, ARRAY): + t1 = typ(String) + else: + t1 = typ() + str(t1) + + def test_str_third_party(self): + class TINYINT(types.TypeEngine): + __visit_name__ = "TINYINT" + + eq_(str(TINYINT()), "TINYINT") + + def test_str_third_party_uppercase_no_visit_name(self): + class TINYINT(types.TypeEngine): + pass + + eq_(str(TINYINT()), "TINYINT") + + def test_str_third_party_camelcase_no_visit_name(self): + class TinyInt(types.TypeEngine): + pass + + eq_(str(TinyInt()), "TinyInt()") + def test_adapt_constructor_copy_override_kw(self): """test that adapt() can accept kw args that override the state of the original object. @@ -2878,10 +2905,16 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_default_compile_mysql_integer(self): self.assert_compile( dialects.mysql.INTEGER(display_width=5), - "INTEGER(5)", + "INTEGER", allow_dialect_select=True, ) + self.assert_compile( + dialects.mysql.INTEGER(display_width=5), + "INTEGER(5)", + dialect="mysql", + ) + def test_numeric_plain(self): self.assert_compile(types.NUMERIC(), "NUMERIC") |