summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2020-08-02 03:18:12 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2020-08-02 03:18:12 +0000
commit0afa28ee3a9699f7a03fbcc8baf01ea873d0c425 (patch)
treea7d300da26240b5b5bdb4b4e6b75e37cf0169a6e
parent5186bea969a0663f535a5a2daa7667a124b5b29f (diff)
parent0721a6bede7222386b2a2508aac5590909fbb148 (diff)
downloadsqlalchemy-0afa28ee3a9699f7a03fbcc8baf01ea873d0c425.tar.gz
Merge "Genericize str() for types"
-rw-r--r--doc/build/changelog/unreleased_14/4262.rst10
-rw-r--r--lib/sqlalchemy/sql/compiler.py24
-rw-r--r--lib/sqlalchemy/sql/type_api.py7
-rw-r--r--test/dialect/mssql/test_types.py9
-rw-r--r--test/sql/test_compiler.py2
-rw-r--r--test/sql/test_metadata.py16
-rw-r--r--test/sql/test_types.py35
7 files changed, 82 insertions, 21 deletions
diff --git a/doc/build/changelog/unreleased_14/4262.rst b/doc/build/changelog/unreleased_14/4262.rst
new file mode 100644
index 000000000..8377daca0
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/4262.rst
@@ -0,0 +1,10 @@
+.. change::
+ :tags: bug, types
+ :tickets: 4262
+
+ Cleaned up the internal ``str()`` for datatypes so that all types produce a
+ string representation without any dialect present, including that it works
+ for third-party dialect types without that dialect being present. The
+ string representation defaults to being the UPPERCASE name of that type
+ with nothing else.
+
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 61e26b003..a8bd1de33 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -4270,6 +4270,14 @@ class GenericTypeCompiler(TypeCompiler):
class StrSQLTypeCompiler(GenericTypeCompiler):
+ def process(self, type_, **kw):
+ try:
+ _compiler_dispatch = type_._compiler_dispatch
+ except AttributeError:
+ return self._visit_unknown(type_, **kw)
+ else:
+ return _compiler_dispatch(self, **kw)
+
def __getattr__(self, key):
if key.startswith("visit_"):
return self._visit_unknown
@@ -4277,7 +4285,21 @@ class StrSQLTypeCompiler(GenericTypeCompiler):
raise AttributeError(key)
def _visit_unknown(self, type_, **kw):
- return "%s" % type_.__class__.__name__
+ if type_.__class__.__name__ == type_.__class__.__name__.upper():
+ return type_.__class__.__name__
+ else:
+ return repr(type_)
+
+ def visit_null(self, type_, **kw):
+ return "NULL"
+
+ def visit_user_defined(self, type_, **kw):
+ try:
+ get_col_spec = type_.get_col_spec
+ except AttributeError:
+ return repr(type_)
+ else:
+ return get_col_spec(**kw)
class IdentifierPreparer(object):
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 2d23c56e1..1284ef515 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -626,12 +626,7 @@ class TypeEngine(Traversible):
@util.preload_module("sqlalchemy.engine.default")
def _default_dialect(self):
default = util.preloaded.engine_default
- if self.__class__.__module__.startswith("sqlalchemy.dialects"):
- tokens = self.__class__.__module__.split(".")[0:3]
- mod = ".".join(tokens)
- return getattr(__import__(mod).dialects, tokens[-1]).dialect()
- else:
- return default.DefaultDialect()
+ return default.StrCompileDialect()
def __str__(self):
if util.py2k:
diff --git a/test/dialect/mssql/test_types.py b/test/dialect/mssql/test_types.py
index e28a42498..399e0ca90 100644
--- a/test/dialect/mssql/test_types.py
+++ b/test/dialect/mssql/test_types.py
@@ -953,9 +953,14 @@ class TypeRoundTripTest(
)
for col, spec in zip(reflected_binary.c, columns):
eq_(
- str(col.type),
+ col.type.compile(dialect=mssql.dialect()),
spec[3],
- "column %s %s != %s" % (col.key, str(col.type), spec[3]),
+ "column %s %s != %s"
+ % (
+ col.key,
+ col.type.compile(dialect=mssql.dialect()),
+ spec[3],
+ ),
)
c1 = testing.db.dialect.type_descriptor(col.type).__class__
c2 = testing.db.dialect.type_descriptor(
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index d79d00555..1d31f1ea5 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -3940,7 +3940,7 @@ class StringifySpecialTest(fixtures.TestBase):
eq_ignore_whitespace(
str(stmt),
- "SELECT CAST(mytable.myid AS MyType) AS myid FROM mytable",
+ "SELECT CAST(mytable.myid AS MyType()) AS myid FROM mytable",
)
def test_within_group(self):
diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py
index 3303eac1d..dc4e342fd 100644
--- a/test/sql/test_metadata.py
+++ b/test/sql/test_metadata.py
@@ -3970,13 +3970,9 @@ class ColumnOptionsTest(fixtures.TestBase):
assert Column(String, default=g2).default is g2
assert Column(String, onupdate=g2).onupdate is g2
- def _null_type_error(self, col):
- t = Table("t", MetaData(), col)
- assert_raises_message(
- exc.CompileError,
- r"\(in table 't', column 'foo'\): Can't generate DDL for NullType",
- schema.CreateTable(t).compile,
- )
+ def _null_type_no_error(self, col):
+ c_str = str(schema.CreateColumn(col).compile())
+ assert "NULL" in c_str
def _no_name_error(self, col):
assert_raises_message(
@@ -3997,13 +3993,13 @@ class ColumnOptionsTest(fixtures.TestBase):
def test_argument_signatures(self):
self._no_name_error(Column())
- self._null_type_error(Column("foo"))
+ self._null_type_no_error(Column("foo"))
self._no_name_error(Column(default="foo"))
self._no_name_error(Column(Sequence("a")))
- self._null_type_error(Column("foo", default="foo"))
+ self._null_type_no_error(Column("foo", default="foo"))
- self._null_type_error(Column("foo", Sequence("a")))
+ self._null_type_no_error(Column("foo", Sequence("a")))
self._no_name_error(Column(ForeignKey("bar.id")))
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 30e5d1fca..fac9fd139 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -294,6 +294,33 @@ class AdaptTest(fixtures.TestBase):
t1 = typ()
repr(t1)
+ @testing.uses_deprecated()
+ @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)])
+ def test_str(self, typ):
+ if issubclass(typ, ARRAY):
+ t1 = typ(String)
+ else:
+ t1 = typ()
+ str(t1)
+
+ def test_str_third_party(self):
+ class TINYINT(types.TypeEngine):
+ __visit_name__ = "TINYINT"
+
+ eq_(str(TINYINT()), "TINYINT")
+
+ def test_str_third_party_uppercase_no_visit_name(self):
+ class TINYINT(types.TypeEngine):
+ pass
+
+ eq_(str(TINYINT()), "TINYINT")
+
+ def test_str_third_party_camelcase_no_visit_name(self):
+ class TinyInt(types.TypeEngine):
+ pass
+
+ eq_(str(TinyInt()), "TinyInt()")
+
def test_adapt_constructor_copy_override_kw(self):
"""test that adapt() can accept kw args that override
the state of the original object.
@@ -2878,10 +2905,16 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
def test_default_compile_mysql_integer(self):
self.assert_compile(
dialects.mysql.INTEGER(display_width=5),
- "INTEGER(5)",
+ "INTEGER",
allow_dialect_select=True,
)
+ self.assert_compile(
+ dialects.mysql.INTEGER(display_width=5),
+ "INTEGER(5)",
+ dialect="mysql",
+ )
+
def test_numeric_plain(self):
self.assert_compile(types.NUMERIC(), "NUMERIC")