diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-10-14 16:12:54 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-10-14 16:12:54 -0400 |
| commit | 92534dc8f30d173deaa1221a6872fd9b7ceae325 (patch) | |
| tree | ca8d70482cdcf9188e77d05812f0b59ec9ebbe2d /lib/sqlalchemy | |
| parent | 78a38967c4ad94194308f77f60a922236cd75227 (diff) | |
| download | sqlalchemy-92534dc8f30d173deaa1221a6872fd9b7ceae325.tar.gz | |
The MySQL :class:`.mysql.SET` type now features the same auto-quoting
behavior as that of :class:`.mysql.ENUM`. Quotes are not required when
setting up the value, but quotes that are present will be auto-detected
along with a warning. This also helps with Alembic where
the SET type doesn't render with quotes. [ticket:2817]
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 153 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/__init__.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 48 |
3 files changed, 115 insertions, 88 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 3bbd88d52..d0f654fe2 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -984,8 +984,49 @@ class LONGBLOB(sqltypes._Binary): __visit_name__ = 'LONGBLOB' +class _EnumeratedValues(_StringType): + def _init_values(self, values, kw): + self.quoting = kw.pop('quoting', 'auto') + + if self.quoting == 'auto' and len(values): + # What quoting character are we using? + q = None + for e in values: + if len(e) == 0: + self.quoting = 'unquoted' + break + elif q is None: + q = e[0] + + if len(e) == 1 or e[0] != q or e[-1] != q: + self.quoting = 'unquoted' + break + else: + self.quoting = 'quoted' + + if self.quoting == 'quoted': + util.warn_deprecated( + 'Manually quoting %s value literals is deprecated. Supply ' + 'unquoted values and use the quoting= option in cases of ' + 'ambiguity.' % self.__class__.__name__) + + values = self._strip_values(values) + + self._enumerated_values = values + length = max([len(v) for v in values] + [0]) + return values, length -class ENUM(sqltypes.Enum, _StringType): + @classmethod + def _strip_values(cls, values): + strip_values = [] + for a in values: + if a[0:1] == '"' or a[0:1] == "'": + # strip enclosing quotes and unquote interior + a = a[1:-1].replace(a[0] * 2, a[0]) + strip_values.append(a) + return strip_values + +class ENUM(sqltypes.Enum, _EnumeratedValues): """MySQL ENUM type.""" __visit_name__ = 'ENUM' @@ -993,9 +1034,9 @@ class ENUM(sqltypes.Enum, _StringType): def __init__(self, *enums, **kw): """Construct an ENUM. - Example: + E.g.:: - Column('myenum', MSEnum("foo", "bar", "baz")) + Column('myenum', ENUM("foo", "bar", "baz")) :param enums: The range of valid values for this ENUM. Values will be quoted when generating the schema according to the quoting flag (see @@ -1039,33 +1080,8 @@ class ENUM(sqltypes.Enum, _StringType): literals for you. This is a transitional option. """ - self.quoting = kw.pop('quoting', 'auto') - - if self.quoting == 'auto' and len(enums): - # What quoting character are we using? - q = None - for e in enums: - if len(e) == 0: - self.quoting = 'unquoted' - break - elif q is None: - q = e[0] - - if e[0] != q or e[-1] != q: - self.quoting = 'unquoted' - break - else: - self.quoting = 'quoted' - - if self.quoting == 'quoted': - util.warn_deprecated( - 'Manually quoting ENUM value literals is deprecated. Supply ' - 'unquoted values and use the quoting= option in cases of ' - 'ambiguity.') - enums = self._strip_enums(enums) - + values, length = self._init_values(enums, kw) self.strict = kw.pop('strict', False) - length = max([len(v) for v in enums] + [0]) kw.pop('metadata', None) kw.pop('schema', None) kw.pop('name', None) @@ -1073,17 +1089,7 @@ class ENUM(sqltypes.Enum, _StringType): kw.pop('native_enum', None) kw.pop('inherit_schema', None) _StringType.__init__(self, length=length, **kw) - sqltypes.Enum.__init__(self, *enums) - - @classmethod - def _strip_enums(cls, enums): - strip_enums = [] - for a in enums: - if a[0:1] == '"' or a[0:1] == "'": - # strip enclosing quotes and unquote interior - a = a[1:-1].replace(a[0] * 2, a[0]) - strip_enums.append(a) - return strip_enums + sqltypes.Enum.__init__(self, *values) def bind_processor(self, dialect): super_convert = super(ENUM, self).bind_processor(dialect) @@ -1103,7 +1109,7 @@ class ENUM(sqltypes.Enum, _StringType): return sqltypes.Enum.adapt(self, impltype, **kw) -class SET(_StringType): +class SET(_EnumeratedValues): """MySQL SET type.""" __visit_name__ = 'SET' @@ -1111,15 +1117,16 @@ class SET(_StringType): def __init__(self, *values, **kw): """Construct a SET. - Example:: + E.g.:: - Column('myset', MSSet("'foo'", "'bar'", "'baz'")) + Column('myset', SET("foo", "bar", "baz")) :param values: The range of valid values for this SET. Values will be - used exactly as they appear when generating schemas. Strings must - be quoted, as in the example above. Single-quotes are suggested for - ANSI compatibility and are required for portability to servers with - ANSI_QUOTES enabled. + quoted when generating the schema according to the quoting flag (see + below). + + .. versionchanged:: 0.9.0 quoting is applied automatically to + :class:`.mysql.SET` in the same way as for :class:`.mysql.ENUM`. :param charset: Optional, a column-level character set for this string value. Takes precedence to 'ascii' or 'unicode' short-hand. @@ -1138,18 +1145,27 @@ class SET(_StringType): BINARY in schema. This does not affect the type of data stored, only the collation of character data. - """ - self._ddl_values = values + :param quoting: Defaults to 'auto': automatically determine enum value + quoting. If all enum values are surrounded by the same quoting + character, then use 'quoted' mode. Otherwise, use 'unquoted' mode. - strip_values = [] - for a in values: - if a[0:1] == '"' or a[0:1] == "'": - # strip enclosing quotes and unquote interior - a = a[1:-1].replace(a[0] * 2, a[0]) - strip_values.append(a) + 'quoted': values in enums are already quoted, they will be used + directly when generating the schema - this usage is deprecated. + + 'unquoted': values in enums are not quoted, they will be escaped and + surrounded by single quotes when generating the schema. - self.values = strip_values - kw.setdefault('length', max([len(v) for v in strip_values] + [0])) + Previous versions of this type always required manually quoted + values to be supplied; future versions will always quote the string + literals for you. This is a transitional option. + + .. versionadded:: 0.9.0 + + """ + values, length = self._init_values(values, kw) + self.values = tuple(values) + + kw.setdefault('length', length) super(SET, self).__init__(**kw) def result_processor(self, dialect, coltype): @@ -1830,7 +1846,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): if not type_.native_enum: return super(MySQLTypeCompiler, self).visit_enum(type_) else: - return self.visit_ENUM(type_) + return self._visit_enumerated_values("ENUM", type_, type_.enums) def visit_BLOB(self, type_): if type_.length: @@ -1847,16 +1863,21 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def visit_LONGBLOB(self, type_): return "LONGBLOB" - def visit_ENUM(self, type_): + def _visit_enumerated_values(self, name, type_, enumerated_values): quoted_enums = [] - for e in type_.enums: + for e in enumerated_values: quoted_enums.append("'%s'" % e.replace("'", "''")) - return self._extend_string(type_, {}, "ENUM(%s)" % - ",".join(quoted_enums)) + return self._extend_string(type_, {}, "%s(%s)" % ( + name, ",".join(quoted_enums)) + ) + + def visit_ENUM(self, type_): + return self._visit_enumerated_values("ENUM", type_, + type_._enumerated_values) def visit_SET(self, type_): - return self._extend_string(type_, {}, "SET(%s)" % - ",".join(type_._ddl_values)) + return self._visit_enumerated_values("SET", type_, + type_._enumerated_values) def visit_BOOLEAN(self, type): return "BOOL" @@ -2572,8 +2593,8 @@ class MySQLTableDefinitionParser(object): if spec.get(kw, False): type_kw[kw] = spec[kw] - if type_ == 'enum': - type_args = ENUM._strip_enums(type_args) + if issubclass(col_type, _EnumeratedValues): + type_args = _EnumeratedValues._strip_values(type_args) type_instance = col_type(*type_args, **type_kw) diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index a87829499..90512e41a 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -11,7 +11,7 @@ from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\ from .assertions import emits_warning, emits_warning_on, uses_deprecated, \ eq_, ne_, is_, is_not_, startswith_, assert_raises, \ assert_raises_message, AssertsCompiledSQL, ComparesTables, \ - AssertsExecutionResults + AssertsExecutionResults, expect_deprecated from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 96a8bc023..062fffb18 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -92,30 +92,36 @@ def uses_deprecated(*messages): @decorator def decorate(fn, *args, **kw): - # todo: should probably be strict about this, too - filters = [dict(action='ignore', - category=sa_exc.SAPendingDeprecationWarning)] - if not messages: - filters.append(dict(action='ignore', - category=sa_exc.SADeprecationWarning)) - else: - filters.extend( - [dict(action='ignore', - message=message, - category=sa_exc.SADeprecationWarning) - for message in - [(m.startswith('//') and - ('Call to deprecated function ' + m[2:]) or m) - for m in messages]]) - - for f in filters: - warnings.filterwarnings(**f) - try: + with expect_deprecated(*messages): return fn(*args, **kw) - finally: - resetwarnings() return decorate +@contextlib.contextmanager +def expect_deprecated(*messages): + # todo: should probably be strict about this, too + filters = [dict(action='ignore', + category=sa_exc.SAPendingDeprecationWarning)] + if not messages: + filters.append(dict(action='ignore', + category=sa_exc.SADeprecationWarning)) + else: + filters.extend( + [dict(action='ignore', + message=message, + category=sa_exc.SADeprecationWarning) + for message in + [(m.startswith('//') and + ('Call to deprecated function ' + m[2:]) or m) + for m in messages]]) + + for f in filters: + warnings.filterwarnings(**f) + try: + yield + finally: + resetwarnings() + + def global_cleanup_assertions(): """Check things that have to be finalized at the end of a test suite. |
