diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/mysql/enumerated.py')
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/enumerated.py | 71 |
1 files changed, 36 insertions, 35 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/enumerated.py b/lib/sqlalchemy/dialects/mysql/enumerated.py index f63d64e8f..9586eff3f 100644 --- a/lib/sqlalchemy/dialects/mysql/enumerated.py +++ b/lib/sqlalchemy/dialects/mysql/enumerated.py @@ -14,29 +14,30 @@ from ...sql import sqltypes class _EnumeratedValues(_StringType): def _init_values(self, values, kw): - self.quoting = kw.pop('quoting', 'auto') + self.quoting = kw.pop("quoting", "auto") - if self.quoting == 'auto' and len(values): + if self.quoting == "auto" and len(values): # What quoting character are we using? q = None for e in values: if len(e) == 0: - self.quoting = 'unquoted' + self.quoting = "unquoted" break elif q is None: q = e[0] if len(e) == 1 or e[0] != q or e[-1] != q: - self.quoting = 'unquoted' + self.quoting = "unquoted" break else: - self.quoting = 'quoted' + self.quoting = "quoted" - if self.quoting == 'quoted': + if self.quoting == "quoted": util.warn_deprecated( - 'Manually quoting %s value literals is deprecated. Supply ' - 'unquoted values and use the quoting= option in cases of ' - 'ambiguity.' % self.__class__.__name__) + "Manually quoting %s value literals is deprecated. Supply " + "unquoted values and use the quoting= option in cases of " + "ambiguity." % self.__class__.__name__ + ) values = self._strip_values(values) @@ -58,7 +59,7 @@ class _EnumeratedValues(_StringType): class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues): """MySQL ENUM type.""" - __visit_name__ = 'ENUM' + __visit_name__ = "ENUM" native_enum = True @@ -115,7 +116,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues): """ - kw.pop('strict', None) + kw.pop("strict", None) self._enum_init(enums, kw) _StringType.__init__(self, length=self.length, **kw) @@ -145,13 +146,14 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues): def __repr__(self): return util.generic_repr( - self, to_inspect=[ENUM, _StringType, sqltypes.Enum]) + self, to_inspect=[ENUM, _StringType, sqltypes.Enum] + ) class SET(_EnumeratedValues): """MySQL SET type.""" - __visit_name__ = 'SET' + __visit_name__ = "SET" def __init__(self, *values, **kw): """Construct a SET. @@ -216,45 +218,43 @@ class SET(_EnumeratedValues): """ - self.retrieve_as_bitwise = kw.pop('retrieve_as_bitwise', False) + self.retrieve_as_bitwise = kw.pop("retrieve_as_bitwise", False) values, length = self._init_values(values, kw) self.values = tuple(values) - if not self.retrieve_as_bitwise and '' in values: + if not self.retrieve_as_bitwise and "" in values: raise exc.ArgumentError( "Can't use the blank value '' in a SET without " - "setting retrieve_as_bitwise=True") + "setting retrieve_as_bitwise=True" + ) if self.retrieve_as_bitwise: self._bitmap = dict( - (value, 2 ** idx) - for idx, value in enumerate(self.values) + (value, 2 ** idx) for idx, value in enumerate(self.values) ) self._bitmap.update( - (2 ** idx, value) - for idx, value in enumerate(self.values) + (2 ** idx, value) for idx, value in enumerate(self.values) ) - kw.setdefault('length', length) + kw.setdefault("length", length) super(SET, self).__init__(**kw) def column_expression(self, colexpr): if self.retrieve_as_bitwise: return sql.type_coerce( - sql.type_coerce(colexpr, sqltypes.Integer) + 0, - self + sql.type_coerce(colexpr, sqltypes.Integer) + 0, self ) else: return colexpr def result_processor(self, dialect, coltype): if self.retrieve_as_bitwise: + def process(value): if value is not None: value = int(value) - return set( - util.map_bits(self._bitmap.__getitem__, value) - ) + return set(util.map_bits(self._bitmap.__getitem__, value)) else: return None + else: super_convert = super(SET, self).result_processor(dialect, coltype) @@ -263,18 +263,20 @@ class SET(_EnumeratedValues): # MySQLdb returns a string, let's parse if super_convert: value = super_convert(value) - return set(re.findall(r'[^,]+', value)) + return set(re.findall(r"[^,]+", value)) else: # mysql-connector-python does a naive # split(",") which throws in an empty string if value is not None: - value.discard('') + value.discard("") return value + return process def bind_processor(self, dialect): super_convert = super(SET, self).bind_processor(dialect) if self.retrieve_as_bitwise: + def process(value): if value is None: return None @@ -288,24 +290,23 @@ class SET(_EnumeratedValues): for v in value: int_value |= self._bitmap[v] return int_value + else: def process(value): # accept strings and int (actually bitflag) values directly if value is not None and not isinstance( - value, util.int_types + util.string_types): + value, util.int_types + util.string_types + ): value = ",".join(value) if super_convert: return super_convert(value) else: return value + return process def adapt(self, impltype, **kw): - kw['retrieve_as_bitwise'] = self.retrieve_as_bitwise - return util.constructor_copy( - self, impltype, - *self.values, - **kw - ) + kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise + return util.constructor_copy(self, impltype, *self.values, **kw) |