summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2013-10-14 16:12:54 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2013-10-14 16:12:54 -0400
commit92534dc8f30d173deaa1221a6872fd9b7ceae325 (patch)
treeca8d70482cdcf9188e77d05812f0b59ec9ebbe2d /lib/sqlalchemy
parent78a38967c4ad94194308f77f60a922236cd75227 (diff)
downloadsqlalchemy-92534dc8f30d173deaa1221a6872fd9b7ceae325.tar.gz
The MySQL :class:`.mysql.SET` type now features the same auto-quoting
behavior as that of :class:`.mysql.ENUM`. Quotes are not required when setting up the value, but quotes that are present will be auto-detected along with a warning. This also helps with Alembic where the SET type doesn't render with quotes. [ticket:2817]
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py153
-rw-r--r--lib/sqlalchemy/testing/__init__.py2
-rw-r--r--lib/sqlalchemy/testing/assertions.py48
3 files changed, 115 insertions, 88 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 3bbd88d52..d0f654fe2 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -984,8 +984,49 @@ class LONGBLOB(sqltypes._Binary):
__visit_name__ = 'LONGBLOB'
+class _EnumeratedValues(_StringType):
+ def _init_values(self, values, kw):
+ self.quoting = kw.pop('quoting', 'auto')
+
+ if self.quoting == 'auto' and len(values):
+ # What quoting character are we using?
+ q = None
+ for e in values:
+ if len(e) == 0:
+ self.quoting = 'unquoted'
+ break
+ elif q is None:
+ q = e[0]
+
+ if len(e) == 1 or e[0] != q or e[-1] != q:
+ self.quoting = 'unquoted'
+ break
+ else:
+ self.quoting = 'quoted'
+
+ if self.quoting == 'quoted':
+ util.warn_deprecated(
+ 'Manually quoting %s value literals is deprecated. Supply '
+ 'unquoted values and use the quoting= option in cases of '
+ 'ambiguity.' % self.__class__.__name__)
+
+ values = self._strip_values(values)
+
+ self._enumerated_values = values
+ length = max([len(v) for v in values] + [0])
+ return values, length
-class ENUM(sqltypes.Enum, _StringType):
+ @classmethod
+ def _strip_values(cls, values):
+ strip_values = []
+ for a in values:
+ if a[0:1] == '"' or a[0:1] == "'":
+ # strip enclosing quotes and unquote interior
+ a = a[1:-1].replace(a[0] * 2, a[0])
+ strip_values.append(a)
+ return strip_values
+
+class ENUM(sqltypes.Enum, _EnumeratedValues):
"""MySQL ENUM type."""
__visit_name__ = 'ENUM'
@@ -993,9 +1034,9 @@ class ENUM(sqltypes.Enum, _StringType):
def __init__(self, *enums, **kw):
"""Construct an ENUM.
- Example:
+ E.g.::
- Column('myenum', MSEnum("foo", "bar", "baz"))
+ Column('myenum', ENUM("foo", "bar", "baz"))
:param enums: The range of valid values for this ENUM. Values will be
quoted when generating the schema according to the quoting flag (see
@@ -1039,33 +1080,8 @@ class ENUM(sqltypes.Enum, _StringType):
literals for you. This is a transitional option.
"""
- self.quoting = kw.pop('quoting', 'auto')
-
- if self.quoting == 'auto' and len(enums):
- # What quoting character are we using?
- q = None
- for e in enums:
- if len(e) == 0:
- self.quoting = 'unquoted'
- break
- elif q is None:
- q = e[0]
-
- if e[0] != q or e[-1] != q:
- self.quoting = 'unquoted'
- break
- else:
- self.quoting = 'quoted'
-
- if self.quoting == 'quoted':
- util.warn_deprecated(
- 'Manually quoting ENUM value literals is deprecated. Supply '
- 'unquoted values and use the quoting= option in cases of '
- 'ambiguity.')
- enums = self._strip_enums(enums)
-
+ values, length = self._init_values(enums, kw)
self.strict = kw.pop('strict', False)
- length = max([len(v) for v in enums] + [0])
kw.pop('metadata', None)
kw.pop('schema', None)
kw.pop('name', None)
@@ -1073,17 +1089,7 @@ class ENUM(sqltypes.Enum, _StringType):
kw.pop('native_enum', None)
kw.pop('inherit_schema', None)
_StringType.__init__(self, length=length, **kw)
- sqltypes.Enum.__init__(self, *enums)
-
- @classmethod
- def _strip_enums(cls, enums):
- strip_enums = []
- for a in enums:
- if a[0:1] == '"' or a[0:1] == "'":
- # strip enclosing quotes and unquote interior
- a = a[1:-1].replace(a[0] * 2, a[0])
- strip_enums.append(a)
- return strip_enums
+ sqltypes.Enum.__init__(self, *values)
def bind_processor(self, dialect):
super_convert = super(ENUM, self).bind_processor(dialect)
@@ -1103,7 +1109,7 @@ class ENUM(sqltypes.Enum, _StringType):
return sqltypes.Enum.adapt(self, impltype, **kw)
-class SET(_StringType):
+class SET(_EnumeratedValues):
"""MySQL SET type."""
__visit_name__ = 'SET'
@@ -1111,15 +1117,16 @@ class SET(_StringType):
def __init__(self, *values, **kw):
"""Construct a SET.
- Example::
+ E.g.::
- Column('myset', MSSet("'foo'", "'bar'", "'baz'"))
+ Column('myset', SET("foo", "bar", "baz"))
:param values: The range of valid values for this SET. Values will be
- used exactly as they appear when generating schemas. Strings must
- be quoted, as in the example above. Single-quotes are suggested for
- ANSI compatibility and are required for portability to servers with
- ANSI_QUOTES enabled.
+ quoted when generating the schema according to the quoting flag (see
+ below).
+
+ .. versionchanged:: 0.9.0 quoting is applied automatically to
+ :class:`.mysql.SET` in the same way as for :class:`.mysql.ENUM`.
:param charset: Optional, a column-level character set for this string
value. Takes precedence to 'ascii' or 'unicode' short-hand.
@@ -1138,18 +1145,27 @@ class SET(_StringType):
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
- """
- self._ddl_values = values
+ :param quoting: Defaults to 'auto': automatically determine enum value
+ quoting. If all enum values are surrounded by the same quoting
+ character, then use 'quoted' mode. Otherwise, use 'unquoted' mode.
- strip_values = []
- for a in values:
- if a[0:1] == '"' or a[0:1] == "'":
- # strip enclosing quotes and unquote interior
- a = a[1:-1].replace(a[0] * 2, a[0])
- strip_values.append(a)
+ 'quoted': values in enums are already quoted, they will be used
+ directly when generating the schema - this usage is deprecated.
+
+ 'unquoted': values in enums are not quoted, they will be escaped and
+ surrounded by single quotes when generating the schema.
- self.values = strip_values
- kw.setdefault('length', max([len(v) for v in strip_values] + [0]))
+ Previous versions of this type always required manually quoted
+ values to be supplied; future versions will always quote the string
+ literals for you. This is a transitional option.
+
+ .. versionadded:: 0.9.0
+
+ """
+ values, length = self._init_values(values, kw)
+ self.values = tuple(values)
+
+ kw.setdefault('length', length)
super(SET, self).__init__(**kw)
def result_processor(self, dialect, coltype):
@@ -1830,7 +1846,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
if not type_.native_enum:
return super(MySQLTypeCompiler, self).visit_enum(type_)
else:
- return self.visit_ENUM(type_)
+ return self._visit_enumerated_values("ENUM", type_, type_.enums)
def visit_BLOB(self, type_):
if type_.length:
@@ -1847,16 +1863,21 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
def visit_LONGBLOB(self, type_):
return "LONGBLOB"
- def visit_ENUM(self, type_):
+ def _visit_enumerated_values(self, name, type_, enumerated_values):
quoted_enums = []
- for e in type_.enums:
+ for e in enumerated_values:
quoted_enums.append("'%s'" % e.replace("'", "''"))
- return self._extend_string(type_, {}, "ENUM(%s)" %
- ",".join(quoted_enums))
+ return self._extend_string(type_, {}, "%s(%s)" % (
+ name, ",".join(quoted_enums))
+ )
+
+ def visit_ENUM(self, type_):
+ return self._visit_enumerated_values("ENUM", type_,
+ type_._enumerated_values)
def visit_SET(self, type_):
- return self._extend_string(type_, {}, "SET(%s)" %
- ",".join(type_._ddl_values))
+ return self._visit_enumerated_values("SET", type_,
+ type_._enumerated_values)
def visit_BOOLEAN(self, type):
return "BOOL"
@@ -2572,8 +2593,8 @@ class MySQLTableDefinitionParser(object):
if spec.get(kw, False):
type_kw[kw] = spec[kw]
- if type_ == 'enum':
- type_args = ENUM._strip_enums(type_args)
+ if issubclass(col_type, _EnumeratedValues):
+ type_args = _EnumeratedValues._strip_values(type_args)
type_instance = col_type(*type_args, **type_kw)
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index a87829499..90512e41a 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -11,7 +11,7 @@ from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
eq_, ne_, is_, is_not_, startswith_, assert_raises, \
assert_raises_message, AssertsCompiledSQL, ComparesTables, \
- AssertsExecutionResults
+ AssertsExecutionResults, expect_deprecated
from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index 96a8bc023..062fffb18 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -92,30 +92,36 @@ def uses_deprecated(*messages):
@decorator
def decorate(fn, *args, **kw):
- # todo: should probably be strict about this, too
- filters = [dict(action='ignore',
- category=sa_exc.SAPendingDeprecationWarning)]
- if not messages:
- filters.append(dict(action='ignore',
- category=sa_exc.SADeprecationWarning))
- else:
- filters.extend(
- [dict(action='ignore',
- message=message,
- category=sa_exc.SADeprecationWarning)
- for message in
- [(m.startswith('//') and
- ('Call to deprecated function ' + m[2:]) or m)
- for m in messages]])
-
- for f in filters:
- warnings.filterwarnings(**f)
- try:
+ with expect_deprecated(*messages):
return fn(*args, **kw)
- finally:
- resetwarnings()
return decorate
+@contextlib.contextmanager
+def expect_deprecated(*messages):
+ # todo: should probably be strict about this, too
+ filters = [dict(action='ignore',
+ category=sa_exc.SAPendingDeprecationWarning)]
+ if not messages:
+ filters.append(dict(action='ignore',
+ category=sa_exc.SADeprecationWarning))
+ else:
+ filters.extend(
+ [dict(action='ignore',
+ message=message,
+ category=sa_exc.SADeprecationWarning)
+ for message in
+ [(m.startswith('//') and
+ ('Call to deprecated function ' + m[2:]) or m)
+ for m in messages]])
+
+ for f in filters:
+ warnings.filterwarnings(**f)
+ try:
+ yield
+ finally:
+ resetwarnings()
+
+
def global_cleanup_assertions():
"""Check things that have to be finalized at the end of a test suite.