diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-01-06 01:14:26 -0500 |
---|---|---|
committer | mike bayer <mike_mp@zzzcomputing.com> | 2019-01-06 17:34:50 +0000 |
commit | 1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch) | |
tree | 28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/dialects/mysql/base.py | |
parent | 404e69426b05a82d905cbb3ad33adafccddb00dd (diff) | |
download | sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz |
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits
applied at all.
The black run will format code consistently, however in
some cases that are prevalent in SQLAlchemy code it produces
too-long lines. The too-long lines will be resolved in the
following commit that will resolve all remaining flake8 issues
including shadowed builtins, long lines, import order, unused
imports, duplicate imports, and docstring issues.
Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/dialects/mysql/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 1397 |
1 files changed, 896 insertions, 501 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 673d4b9ff..7b0d0618c 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -746,85 +746,340 @@ from ...engine import reflection from ...engine import default from ... import types as sqltypes from ...util import topological -from ...types import DATE, BOOLEAN, \ - BLOB, BINARY, VARBINARY +from ...types import DATE, BOOLEAN, BLOB, BINARY, VARBINARY from . import reflection as _reflection -from .types import BIGINT, BIT, CHAR, DECIMAL, DATETIME, \ - DOUBLE, FLOAT, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, MEDIUMINT, \ - MEDIUMTEXT, NCHAR, NUMERIC, NVARCHAR, REAL, SMALLINT, TEXT, TIME, \ - TIMESTAMP, TINYBLOB, TINYINT, TINYTEXT, VARCHAR, YEAR -from .types import _StringType, _IntegerType, _NumericType, \ - _FloatType, _MatchType +from .types import ( + BIGINT, + BIT, + CHAR, + DECIMAL, + DATETIME, + DOUBLE, + FLOAT, + INTEGER, + LONGBLOB, + LONGTEXT, + MEDIUMBLOB, + MEDIUMINT, + MEDIUMTEXT, + NCHAR, + NUMERIC, + NVARCHAR, + REAL, + SMALLINT, + TEXT, + TIME, + TIMESTAMP, + TINYBLOB, + TINYINT, + TINYTEXT, + VARCHAR, + YEAR, +) +from .types import ( + _StringType, + _IntegerType, + _NumericType, + _FloatType, + _MatchType, +) from .enumerated import ENUM, SET from .json import JSON, JSONIndexType, JSONPathType RESERVED_WORDS = set( - ['accessible', 'add', 'all', 'alter', 'analyze', 'and', 'as', 'asc', - 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both', - 'by', 'call', 'cascade', 'case', 'change', 'char', 'character', 'check', - 'collate', 'column', 'condition', 'constraint', 'continue', 'convert', - 'create', 'cross', 'current_date', 'current_time', 'current_timestamp', - 'current_user', 'cursor', 'database', 'databases', 'day_hour', - 'day_microsecond', 'day_minute', 'day_second', 'dec', 'decimal', - 'declare', 'default', 'delayed', 'delete', 'desc', 'describe', - 'deterministic', 'distinct', 'distinctrow', 'div', 'double', 'drop', - 'dual', 'each', 'else', 'elseif', 'enclosed', 'escaped', 'exists', - 'exit', 'explain', 'false', 'fetch', 'float', 'float4', 'float8', - 'for', 'force', 'foreign', 'from', 'fulltext', 'grant', 'group', - 'having', 'high_priority', 'hour_microsecond', 'hour_minute', - 'hour_second', 'if', 'ignore', 'in', 'index', 'infile', 'inner', 'inout', - 'insensitive', 'insert', 'int', 'int1', 'int2', 'int3', 'int4', 'int8', - 'integer', 'interval', 'into', 'is', 'iterate', 'join', 'key', 'keys', - 'kill', 'leading', 'leave', 'left', 'like', 'limit', 'linear', 'lines', - 'load', 'localtime', 'localtimestamp', 'lock', 'long', 'longblob', - 'longtext', 'loop', 'low_priority', 'master_ssl_verify_server_cert', - 'match', 'mediumblob', 'mediumint', 'mediumtext', 'middleint', - 'minute_microsecond', 'minute_second', 'mod', 'modifies', 'natural', - 'not', 'no_write_to_binlog', 'null', 'numeric', 'on', 'optimize', - 'option', 'optionally', 'or', 'order', 'out', 'outer', 'outfile', - 'precision', 'primary', 'procedure', 'purge', 'range', 'read', 'reads', - 'read_only', 'read_write', 'real', 'references', 'regexp', 'release', - 'rename', 'repeat', 'replace', 'require', 'restrict', 'return', - 'revoke', 'right', 'rlike', 'schema', 'schemas', 'second_microsecond', - 'select', 'sensitive', 'separator', 'set', 'show', 'smallint', 'spatial', - 'specific', 'sql', 'sqlexception', 'sqlstate', 'sqlwarning', - 'sql_big_result', 'sql_calc_found_rows', 'sql_small_result', 'ssl', - 'starting', 'straight_join', 'table', 'terminated', 'then', 'tinyblob', - 'tinyint', 'tinytext', 'to', 'trailing', 'trigger', 'true', 'undo', - 'union', 'unique', 'unlock', 'unsigned', 'update', 'usage', 'use', - 'using', 'utc_date', 'utc_time', 'utc_timestamp', 'values', 'varbinary', - 'varchar', 'varcharacter', 'varying', 'when', 'where', 'while', 'with', - - 'write', 'x509', 'xor', 'year_month', 'zerofill', # 5.0 - - 'columns', 'fields', 'privileges', 'soname', 'tables', # 4.1 - - 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', - 'read_only', 'read_write', # 5.1 - - 'general', 'ignore_server_ids', 'master_heartbeat_period', 'maxvalue', - 'resignal', 'signal', 'slow', # 5.5 - - 'get', 'io_after_gtids', 'io_before_gtids', 'master_bind', 'one_shot', - 'partition', 'sql_after_gtids', 'sql_before_gtids', # 5.6 - - 'generated', 'optimizer_costs', 'stored', 'virtual', # 5.7 - - 'admin', 'cume_dist', 'empty', 'except', 'first_value', 'grouping', - 'function', 'groups', 'json_table', 'last_value', 'nth_value', - 'ntile', 'of', 'over', 'percent_rank', 'persist', 'persist_only', - 'rank', 'recursive', 'role', 'row', 'rows', 'row_number', 'system', - 'window', # 8.0 - ]) + [ + "accessible", + "add", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "asensitive", + "before", + "between", + "bigint", + "binary", + "blob", + "both", + "by", + "call", + "cascade", + "case", + "change", + "char", + "character", + "check", + "collate", + "column", + "condition", + "constraint", + "continue", + "convert", + "create", + "cross", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "databases", + "day_hour", + "day_microsecond", + "day_minute", + "day_second", + "dec", + "decimal", + "declare", + "default", + "delayed", + "delete", + "desc", + "describe", + "deterministic", + "distinct", + "distinctrow", + "div", + "double", + "drop", + "dual", + "each", + "else", + "elseif", + "enclosed", + "escaped", + "exists", + "exit", + "explain", + "false", + "fetch", + "float", + "float4", + "float8", + "for", + "force", + "foreign", + "from", + "fulltext", + "grant", + "group", + "having", + "high_priority", + "hour_microsecond", + "hour_minute", + "hour_second", + "if", + "ignore", + "in", + "index", + "infile", + "inner", + "inout", + "insensitive", + "insert", + "int", + "int1", + "int2", + "int3", + "int4", + "int8", + "integer", + "interval", + "into", + "is", + "iterate", + "join", + "key", + "keys", + "kill", + "leading", + "leave", + "left", + "like", + "limit", + "linear", + "lines", + "load", + "localtime", + "localtimestamp", + "lock", + "long", + "longblob", + "longtext", + "loop", + "low_priority", + "master_ssl_verify_server_cert", + "match", + "mediumblob", + "mediumint", + "mediumtext", + "middleint", + "minute_microsecond", + "minute_second", + "mod", + "modifies", + "natural", + "not", + "no_write_to_binlog", + "null", + "numeric", + "on", + "optimize", + "option", + "optionally", + "or", + "order", + "out", + "outer", + "outfile", + "precision", + "primary", + "procedure", + "purge", + "range", + "read", + "reads", + "read_only", + "read_write", + "real", + "references", + "regexp", + "release", + "rename", + "repeat", + "replace", + "require", + "restrict", + "return", + "revoke", + "right", + "rlike", + "schema", + "schemas", + "second_microsecond", + "select", + "sensitive", + "separator", + "set", + "show", + "smallint", + "spatial", + "specific", + "sql", + "sqlexception", + "sqlstate", + "sqlwarning", + "sql_big_result", + "sql_calc_found_rows", + "sql_small_result", + "ssl", + "starting", + "straight_join", + "table", + "terminated", + "then", + "tinyblob", + "tinyint", + "tinytext", + "to", + "trailing", + "trigger", + "true", + "undo", + "union", + "unique", + "unlock", + "unsigned", + "update", + "usage", + "use", + "using", + "utc_date", + "utc_time", + "utc_timestamp", + "values", + "varbinary", + "varchar", + "varcharacter", + "varying", + "when", + "where", + "while", + "with", + "write", + "x509", + "xor", + "year_month", + "zerofill", # 5.0 + "columns", + "fields", + "privileges", + "soname", + "tables", # 4.1 + "accessible", + "linear", + "master_ssl_verify_server_cert", + "range", + "read_only", + "read_write", # 5.1 + "general", + "ignore_server_ids", + "master_heartbeat_period", + "maxvalue", + "resignal", + "signal", + "slow", # 5.5 + "get", + "io_after_gtids", + "io_before_gtids", + "master_bind", + "one_shot", + "partition", + "sql_after_gtids", + "sql_before_gtids", # 5.6 + "generated", + "optimizer_costs", + "stored", + "virtual", # 5.7 + "admin", + "cume_dist", + "empty", + "except", + "first_value", + "grouping", + "function", + "groups", + "json_table", + "last_value", + "nth_value", + "ntile", + "of", + "over", + "percent_rank", + "persist", + "persist_only", + "rank", + "recursive", + "role", + "row", + "rows", + "row_number", + "system", + "window", # 8.0 + ] +) AUTOCOMMIT_RE = re.compile( - r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)', - re.I | re.UNICODE) + r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)", + re.I | re.UNICODE, +) SET_RE = re.compile( - r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w', - re.I | re.UNICODE) + r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE +) # old names @@ -870,52 +1125,50 @@ colspecs = { sqltypes.MatchType: _MatchType, sqltypes.JSON: JSON, sqltypes.JSON.JSONIndexType: JSONIndexType, - sqltypes.JSON.JSONPathType: JSONPathType - + sqltypes.JSON.JSONPathType: JSONPathType, } # Everything 3.23 through 5.1 excepting OpenGIS types. ischema_names = { - 'bigint': BIGINT, - 'binary': BINARY, - 'bit': BIT, - 'blob': BLOB, - 'boolean': BOOLEAN, - 'char': CHAR, - 'date': DATE, - 'datetime': DATETIME, - 'decimal': DECIMAL, - 'double': DOUBLE, - 'enum': ENUM, - 'fixed': DECIMAL, - 'float': FLOAT, - 'int': INTEGER, - 'integer': INTEGER, - 'json': JSON, - 'longblob': LONGBLOB, - 'longtext': LONGTEXT, - 'mediumblob': MEDIUMBLOB, - 'mediumint': MEDIUMINT, - 'mediumtext': MEDIUMTEXT, - 'nchar': NCHAR, - 'nvarchar': NVARCHAR, - 'numeric': NUMERIC, - 'set': SET, - 'smallint': SMALLINT, - 'text': TEXT, - 'time': TIME, - 'timestamp': TIMESTAMP, - 'tinyblob': TINYBLOB, - 'tinyint': TINYINT, - 'tinytext': TINYTEXT, - 'varbinary': VARBINARY, - 'varchar': VARCHAR, - 'year': YEAR, + "bigint": BIGINT, + "binary": BINARY, + "bit": BIT, + "blob": BLOB, + "boolean": BOOLEAN, + "char": CHAR, + "date": DATE, + "datetime": DATETIME, + "decimal": DECIMAL, + "double": DOUBLE, + "enum": ENUM, + "fixed": DECIMAL, + "float": FLOAT, + "int": INTEGER, + "integer": INTEGER, + "json": JSON, + "longblob": LONGBLOB, + "longtext": LONGTEXT, + "mediumblob": MEDIUMBLOB, + "mediumint": MEDIUMINT, + "mediumtext": MEDIUMTEXT, + "nchar": NCHAR, + "nvarchar": NVARCHAR, + "numeric": NUMERIC, + "set": SET, + "smallint": SMALLINT, + "text": TEXT, + "time": TIME, + "timestamp": TIMESTAMP, + "tinyblob": TINYBLOB, + "tinyint": TINYINT, + "tinytext": TINYTEXT, + "varbinary": VARBINARY, + "varchar": VARCHAR, + "year": YEAR, } class MySQLExecutionContext(default.DefaultExecutionContext): - def should_autocommit_text(self, statement): return AUTOCOMMIT_RE.match(statement) @@ -932,7 +1185,7 @@ class MySQLCompiler(compiler.SQLCompiler): """Overridden from base SQLCompiler value""" extract_map = compiler.SQLCompiler.extract_map.copy() - extract_map.update({'milliseconds': 'millisecond'}) + extract_map.update({"milliseconds": "millisecond"}) def visit_random_func(self, fn, **kw): return "rand%s" % self.function_argspec(fn) @@ -943,12 +1196,14 @@ class MySQLCompiler(compiler.SQLCompiler): def visit_json_getitem_op_binary(self, binary, operator, **kw): return "JSON_EXTRACT(%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_json_path_getitem_op_binary(self, binary, operator, **kw): return "JSON_EXTRACT(%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_on_duplicate_key_update(self, on_duplicate, **kw): if on_duplicate._parameter_ordering: @@ -958,7 +1213,8 @@ class MySQLCompiler(compiler.SQLCompiler): ] ordered_keys = set(parameter_ordering) cols = [ - self.statement.table.c[key] for key in parameter_ordering + self.statement.table.c[key] + for key in parameter_ordering if key in self.statement.table.c ] + [ c for c in self.statement.table.c if c.key not in ordered_keys @@ -979,9 +1235,11 @@ class MySQLCompiler(compiler.SQLCompiler): val = val._clone() val.type = column.type value_text = self.process(val.self_group(), use_schema=False) - elif isinstance(val, elements.ColumnClause) \ - and val.table is on_duplicate.inserted_alias: - value_text = 'VALUES(' + self.preparer.quote(column.name) + ')' + elif ( + isinstance(val, elements.ColumnClause) + and val.table is on_duplicate.inserted_alias + ): + value_text = "VALUES(" + self.preparer.quote(column.name) + ")" else: value_text = self.process(val.self_group(), use_schema=False) name_text = self.preparer.quote(column.name) @@ -990,22 +1248,27 @@ class MySQLCompiler(compiler.SQLCompiler): non_matching = set(on_duplicate.update) - set(c.key for c in cols) if non_matching: util.warn( - 'Additional column names not matching ' - "any column keys in table '%s': %s" % ( + "Additional column names not matching " + "any column keys in table '%s': %s" + % ( self.statement.table.name, - (', '.join("'%s'" % c for c in non_matching)) + (", ".join("'%s'" % c for c in non_matching)), ) ) - return 'ON DUPLICATE KEY UPDATE ' + ', '.join(clauses) + return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses) def visit_concat_op_binary(self, binary, operator, **kw): - return "concat(%s, %s)" % (self.process(binary.left, **kw), - self.process(binary.right, **kw)) + return "concat(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def visit_match_op_binary(self, binary, operator, **kw): - return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % \ - (self.process(binary.left, **kw), self.process(binary.right, **kw)) + return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def get_from_hint_text(self, table, text): return text @@ -1016,26 +1279,35 @@ class MySQLCompiler(compiler.SQLCompiler): if isinstance(type_, sqltypes.TypeDecorator): return self.visit_typeclause(typeclause, type_.impl, **kw) elif isinstance(type_, sqltypes.Integer): - if getattr(type_, 'unsigned', False): - return 'UNSIGNED INTEGER' + if getattr(type_, "unsigned", False): + return "UNSIGNED INTEGER" else: - return 'SIGNED INTEGER' + return "SIGNED INTEGER" elif isinstance(type_, sqltypes.TIMESTAMP): - return 'DATETIME' - elif isinstance(type_, (sqltypes.DECIMAL, sqltypes.DateTime, - sqltypes.Date, sqltypes.Time)): + return "DATETIME" + elif isinstance( + type_, + ( + sqltypes.DECIMAL, + sqltypes.DateTime, + sqltypes.Date, + sqltypes.Time, + ), + ): return self.dialect.type_compiler.process(type_) - elif isinstance(type_, sqltypes.String) \ - and not isinstance(type_, (ENUM, SET)): + elif isinstance(type_, sqltypes.String) and not isinstance( + type_, (ENUM, SET) + ): adapted = CHAR._adapt_string_for_cast(type_) return self.dialect.type_compiler.process(adapted) elif isinstance(type_, sqltypes._Binary): - return 'BINARY' + return "BINARY" elif isinstance(type_, sqltypes.JSON): return "JSON" elif isinstance(type_, sqltypes.NUMERIC): - return self.dialect.type_compiler.process( - type_).replace('NUMERIC', 'DECIMAL') + return self.dialect.type_compiler.process(type_).replace( + "NUMERIC", "DECIMAL" + ) else: return None @@ -1044,23 +1316,25 @@ class MySQLCompiler(compiler.SQLCompiler): if not self.dialect._supports_cast: util.warn( "Current MySQL version does not support " - "CAST; the CAST will be skipped.") + "CAST; the CAST will be skipped." + ) return self.process(cast.clause.self_group(), **kw) type_ = self.process(cast.typeclause) if type_ is None: util.warn( "Datatype %s does not support CAST on MySQL; " - "the CAST will be skipped." % - self.dialect.type_compiler.process(cast.typeclause.type)) + "the CAST will be skipped." + % self.dialect.type_compiler.process(cast.typeclause.type) + ) return self.process(cast.clause.self_group(), **kw) - return 'CAST(%s AS %s)' % (self.process(cast.clause, **kw), type_) + return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_) def render_literal_value(self, value, type_): value = super(MySQLCompiler, self).render_literal_value(value, type_) if self.dialect._backslash_escapes: - value = value.replace('\\', '\\\\') + value = value.replace("\\", "\\\\") return value # override native_boolean=False behavior here, as @@ -1096,12 +1370,15 @@ class MySQLCompiler(compiler.SQLCompiler): else: join_type = " INNER JOIN " - return ''.join( - (self.process(join.left, asfrom=True, **kwargs), - join_type, - self.process(join.right, asfrom=True, **kwargs), - " ON ", - self.process(join.onclause, **kwargs))) + return "".join( + ( + self.process(join.left, asfrom=True, **kwargs), + join_type, + self.process(join.right, asfrom=True, **kwargs), + " ON ", + self.process(join.onclause, **kwargs), + ) + ) def for_update_clause(self, select, **kw): if select._for_update_arg.read: @@ -1118,11 +1395,13 @@ class MySQLCompiler(compiler.SQLCompiler): # The latter is more readable for offsets but we're stuck with the # former until we can refine dialects by server revision. - limit_clause, offset_clause = select._limit_clause, \ - select._offset_clause + limit_clause, offset_clause = ( + select._limit_clause, + select._offset_clause, + ) if limit_clause is None and offset_clause is None: - return '' + return "" elif offset_clause is not None: # As suggested by the MySQL docs, need to apply an # artificial limit if one wasn't provided @@ -1134,35 +1413,38 @@ class MySQLCompiler(compiler.SQLCompiler): # but also is consistent with the usage of the upper # bound as part of MySQL's "syntax" for OFFSET with # no LIMIT - return ' \n LIMIT %s, %s' % ( + return " \n LIMIT %s, %s" % ( self.process(offset_clause, **kw), - "18446744073709551615") + "18446744073709551615", + ) else: - return ' \n LIMIT %s, %s' % ( + return " \n LIMIT %s, %s" % ( self.process(offset_clause, **kw), - self.process(limit_clause, **kw)) + self.process(limit_clause, **kw), + ) else: # No offset provided, so just use the limit - return ' \n LIMIT %s' % (self.process(limit_clause, **kw),) + return " \n LIMIT %s" % (self.process(limit_clause, **kw),) def update_limit_clause(self, update_stmt): - limit = update_stmt.kwargs.get('%s_limit' % self.dialect.name, None) + limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None) if limit: return "LIMIT %s" % limit else: return None - def update_tables_clause(self, update_stmt, from_table, - extra_froms, **kw): - return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) - for t in [from_table] + list(extra_froms)) + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + return ", ".join( + t._compiler_dispatch(self, asfrom=True, **kw) + for t in [from_table] + list(extra_froms) + ) - def update_from_clause(self, update_stmt, from_table, - extra_froms, from_hints, **kw): + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): return None - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: @@ -1171,24 +1453,27 @@ class MySQLCompiler(compiler.SQLCompiler): self, asfrom=True, iscrud=True, ashint=ashint ) - def delete_extra_from_clause(self, delete_stmt, from_table, - extra_froms, from_hints, **kw): + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the DELETE .. USING clause specific to MySQL.""" - return "USING " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "USING " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) def visit_empty_set_expr(self, element_types): return ( "SELECT %(outer)s FROM (SELECT %(inner)s) " - "as _empty_set WHERE 1!=1" % { + "as _empty_set WHERE 1!=1" + % { "inner": ", ".join( "1 AS _in_%s" % idx - for idx, type_ in enumerate(element_types)), + for idx, type_ in enumerate(element_types) + ), "outer": ", ".join( - "_in_%s" % idx - for idx, type_ in enumerate(element_types)) + "_in_%s" % idx for idx, type_ in enumerate(element_types) + ), } ) @@ -1200,35 +1485,39 @@ class MySQLDDLCompiler(compiler.DDLCompiler): colspec = [ self.preparer.format_column(column), self.dialect.type_compiler.process( - column.type, type_expression=column) + column.type, type_expression=column + ), ] is_timestamp = isinstance(column.type, sqltypes.TIMESTAMP) if not column.nullable: - colspec.append('NOT NULL') + colspec.append("NOT NULL") # see: http://docs.sqlalchemy.org/en/latest/dialects/ # mysql.html#mysql_timestamp_null elif column.nullable and is_timestamp: - colspec.append('NULL') + colspec.append("NULL") default = self.get_column_default_string(column) if default is not None: - colspec.append('DEFAULT ' + default) + colspec.append("DEFAULT " + default) comment = column.comment if comment is not None: literal = self.sql_compiler.render_literal_value( - comment, sqltypes.String()) - colspec.append('COMMENT ' + literal) + comment, sqltypes.String() + ) + colspec.append("COMMENT " + literal) - if column.table is not None \ - and column is column.table._autoincrement_column and \ - column.server_default is None: - colspec.append('AUTO_INCREMENT') + if ( + column.table is not None + and column is column.table._autoincrement_column + and column.server_default is None + ): + colspec.append("AUTO_INCREMENT") - return ' '.join(colspec) + return " ".join(colspec) def post_create_table(self, table): """Build table-level CREATE options like ENGINE and COLLATE.""" @@ -1236,76 +1525,94 @@ class MySQLDDLCompiler(compiler.DDLCompiler): table_opts = [] opts = dict( - ( - k[len(self.dialect.name) + 1:].upper(), - v - ) + (k[len(self.dialect.name) + 1 :].upper(), v) for k, v in table.kwargs.items() - if k.startswith('%s_' % self.dialect.name) + if k.startswith("%s_" % self.dialect.name) ) if table.comment is not None: - opts['COMMENT'] = table.comment + opts["COMMENT"] = table.comment partition_options = [ - 'PARTITION_BY', 'PARTITIONS', 'SUBPARTITIONS', - 'SUBPARTITION_BY' + "PARTITION_BY", + "PARTITIONS", + "SUBPARTITIONS", + "SUBPARTITION_BY", ] nonpart_options = set(opts).difference(partition_options) part_options = set(opts).intersection(partition_options) - for opt in topological.sort([ - ('DEFAULT_CHARSET', 'COLLATE'), - ('DEFAULT_CHARACTER_SET', 'COLLATE'), - ], nonpart_options): + for opt in topological.sort( + [ + ("DEFAULT_CHARSET", "COLLATE"), + ("DEFAULT_CHARACTER_SET", "COLLATE"), + ], + nonpart_options, + ): arg = opts[opt] if opt in _reflection._options_of_type_string: arg = self.sql_compiler.render_literal_value( - arg, sqltypes.String()) - - if opt in ('DATA_DIRECTORY', 'INDEX_DIRECTORY', - 'DEFAULT_CHARACTER_SET', 'CHARACTER_SET', - 'DEFAULT_CHARSET', - 'DEFAULT_COLLATE'): - opt = opt.replace('_', ' ') + arg, sqltypes.String() + ) - joiner = '=' - if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET', - 'CHARACTER SET', 'COLLATE'): - joiner = ' ' + if opt in ( + "DATA_DIRECTORY", + "INDEX_DIRECTORY", + "DEFAULT_CHARACTER_SET", + "CHARACTER_SET", + "DEFAULT_CHARSET", + "DEFAULT_COLLATE", + ): + opt = opt.replace("_", " ") + + joiner = "=" + if opt in ( + "TABLESPACE", + "DEFAULT CHARACTER SET", + "CHARACTER SET", + "COLLATE", + ): + joiner = " " table_opts.append(joiner.join((opt, arg))) - for opt in topological.sort([ - ('PARTITION_BY', 'PARTITIONS'), - ('PARTITION_BY', 'SUBPARTITION_BY'), - ('PARTITION_BY', 'SUBPARTITIONS'), - ('PARTITIONS', 'SUBPARTITIONS'), - ('PARTITIONS', 'SUBPARTITION_BY'), - ('SUBPARTITION_BY', 'SUBPARTITIONS') - ], part_options): + for opt in topological.sort( + [ + ("PARTITION_BY", "PARTITIONS"), + ("PARTITION_BY", "SUBPARTITION_BY"), + ("PARTITION_BY", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITION_BY"), + ("SUBPARTITION_BY", "SUBPARTITIONS"), + ], + part_options, + ): arg = opts[opt] if opt in _reflection._options_of_type_string: arg = self.sql_compiler.render_literal_value( - arg, sqltypes.String()) + arg, sqltypes.String() + ) - opt = opt.replace('_', ' ') - joiner = ' ' + opt = opt.replace("_", " ") + joiner = " " table_opts.append(joiner.join((opt, arg))) - return ' '.join(table_opts) + return " ".join(table_opts) def visit_create_index(self, create, **kw): index = create.element self._verify_index_table(index) preparer = self.preparer table = preparer.format_table(index.table) - columns = [self.sql_compiler.process(expr, include_table=False, - literal_binds=True) - for expr in index.expressions] + columns = [ + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ] name = self._prepared_index_name(index) @@ -1313,53 +1620,54 @@ class MySQLDDLCompiler(compiler.DDLCompiler): if index.unique: text += "UNIQUE " - index_prefix = index.kwargs.get('mysql_prefix', None) + index_prefix = index.kwargs.get("mysql_prefix", None) if index_prefix: - text += index_prefix + ' ' + text += index_prefix + " " text += "INDEX %s ON %s " % (name, table) - length = index.dialect_options['mysql']['length'] + length = index.dialect_options["mysql"]["length"] if length is not None: if isinstance(length, dict): # length value can be a (column_name --> integer value) # mapping specifying the prefix length for each column of the # index - columns = ', '.join( - '%s(%d)' % (expr, length[col.name]) if col.name in length - else - ( - '%s(%d)' % (expr, length[expr]) if expr in length - else '%s' % expr + columns = ", ".join( + "%s(%d)" % (expr, length[col.name]) + if col.name in length + else ( + "%s(%d)" % (expr, length[expr]) + if expr in length + else "%s" % expr ) for col, expr in zip(index.expressions, columns) ) else: # or can be an integer value specifying the same # prefix length for all columns of the index - columns = ', '.join( - '%s(%d)' % (col, length) - for col in columns + columns = ", ".join( + "%s(%d)" % (col, length) for col in columns ) else: - columns = ', '.join(columns) - text += '(%s)' % columns + columns = ", ".join(columns) + text += "(%s)" % columns - parser = index.dialect_options['mysql']['with_parser'] + parser = index.dialect_options["mysql"]["with_parser"] if parser is not None: - text += " WITH PARSER %s" % (parser, ) + text += " WITH PARSER %s" % (parser,) - using = index.dialect_options['mysql']['using'] + using = index.dialect_options["mysql"]["using"] if using is not None: text += " USING %s" % (preparer.quote(using)) return text def visit_primary_key_constraint(self, constraint): - text = super(MySQLDDLCompiler, self).\ - visit_primary_key_constraint(constraint) - using = constraint.dialect_options['mysql']['using'] + text = super(MySQLDDLCompiler, self).visit_primary_key_constraint( + constraint + ) + using = constraint.dialect_options["mysql"]["using"] if using: text += " USING %s" % (self.preparer.quote(using)) return text @@ -1368,9 +1676,9 @@ class MySQLDDLCompiler(compiler.DDLCompiler): index = drop.element return "\nDROP INDEX %s ON %s" % ( - self._prepared_index_name(index, - include_schema=False), - self.preparer.format_table(index.table)) + self._prepared_index_name(index, include_schema=False), + self.preparer.format_table(index.table), + ) def visit_drop_constraint(self, drop): constraint = drop.element @@ -1386,29 +1694,33 @@ class MySQLDDLCompiler(compiler.DDLCompiler): else: qual = "" const = self.preparer.format_constraint(constraint) - return "ALTER TABLE %s DROP %s%s" % \ - (self.preparer.format_table(constraint.table), - qual, const) + return "ALTER TABLE %s DROP %s%s" % ( + self.preparer.format_table(constraint.table), + qual, + const, + ) def define_constraint_match(self, constraint): if constraint.match is not None: raise exc.CompileError( "MySQL ignores the 'MATCH' keyword while at the same time " - "causes ON UPDATE/ON DELETE clauses to be ignored.") + "causes ON UPDATE/ON DELETE clauses to be ignored." + ) return "" def visit_set_table_comment(self, create): return "ALTER TABLE %s COMMENT %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( - create.element.comment, sqltypes.String()) + create.element.comment, sqltypes.String() + ), ) def visit_set_column_comment(self, create): return "ALTER TABLE %s CHANGE %s %s" % ( self.preparer.format_table(create.element.table), self.preparer.format_column(create.element), - self.get_column_specification(create.element) + self.get_column_specification(create.element), ) @@ -1420,9 +1732,9 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return spec if type_.unsigned: - spec += ' UNSIGNED' + spec += " UNSIGNED" if type_.zerofill: - spec += ' ZEROFILL' + spec += " ZEROFILL" return spec def _extend_string(self, type_, defaults, spec): @@ -1434,28 +1746,30 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def attr(name): return getattr(type_, name, defaults.get(name)) - if attr('charset'): - charset = 'CHARACTER SET %s' % attr('charset') - elif attr('ascii'): - charset = 'ASCII' - elif attr('unicode'): - charset = 'UNICODE' + if attr("charset"): + charset = "CHARACTER SET %s" % attr("charset") + elif attr("ascii"): + charset = "ASCII" + elif attr("unicode"): + charset = "UNICODE" else: charset = None - if attr('collation'): - collation = 'COLLATE %s' % type_.collation - elif attr('binary'): - collation = 'BINARY' + if attr("collation"): + collation = "COLLATE %s" % type_.collation + elif attr("binary"): + collation = "BINARY" else: collation = None - if attr('national'): + if attr("national"): # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. - return ' '.join([c for c in ('NATIONAL', spec, collation) - if c is not None]) - return ' '.join([c for c in (spec, charset, collation) - if c is not None]) + return " ".join( + [c for c in ("NATIONAL", spec, collation) if c is not None] + ) + return " ".join( + [c for c in (spec, charset, collation) if c is not None] + ) def _mysql_type(self, type_): return isinstance(type_, (_StringType, _NumericType)) @@ -1464,95 +1778,113 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): if type_.precision is None: return self._extend_numeric(type_, "NUMERIC") elif type_.scale is None: - return self._extend_numeric(type_, - "NUMERIC(%(precision)s)" % - {'precision': type_.precision}) + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s)" % {"precision": type_.precision}, + ) else: - return self._extend_numeric(type_, - "NUMERIC(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return self._extend_numeric(type_, "DECIMAL") elif type_.scale is None: - return self._extend_numeric(type_, - "DECIMAL(%(precision)s)" % - {'precision': type_.precision}) + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s)" % {"precision": type_.precision}, + ) else: - return self._extend_numeric(type_, - "DECIMAL(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) def visit_DOUBLE(self, type_, **kw): if type_.precision is not None and type_.scale is not None: - return self._extend_numeric(type_, - "DOUBLE(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "DOUBLE(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) else: - return self._extend_numeric(type_, 'DOUBLE') + return self._extend_numeric(type_, "DOUBLE") def visit_REAL(self, type_, **kw): if type_.precision is not None and type_.scale is not None: - return self._extend_numeric(type_, - "REAL(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "REAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) else: - return self._extend_numeric(type_, 'REAL') + return self._extend_numeric(type_, "REAL") def visit_FLOAT(self, type_, **kw): - if self._mysql_type(type_) and \ - type_.scale is not None and \ - type_.precision is not None: + if ( + self._mysql_type(type_) + and type_.scale is not None + and type_.precision is not None + ): return self._extend_numeric( - type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale)) + type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale) + ) elif type_.precision is not None: - return self._extend_numeric(type_, - "FLOAT(%s)" % (type_.precision,)) + return self._extend_numeric( + type_, "FLOAT(%s)" % (type_.precision,) + ) else: return self._extend_numeric(type_, "FLOAT") def visit_INTEGER(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( - type_, "INTEGER(%(display_width)s)" % - {'display_width': type_.display_width}) + type_, + "INTEGER(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "INTEGER") def visit_BIGINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( - type_, "BIGINT(%(display_width)s)" % - {'display_width': type_.display_width}) + type_, + "BIGINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "BIGINT") def visit_MEDIUMINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( - type_, "MEDIUMINT(%(display_width)s)" % - {'display_width': type_.display_width}) + type_, + "MEDIUMINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "MEDIUMINT") def visit_TINYINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: - return self._extend_numeric(type_, - "TINYINT(%s)" % type_.display_width) + return self._extend_numeric( + type_, "TINYINT(%s)" % type_.display_width + ) else: return self._extend_numeric(type_, "TINYINT") def visit_SMALLINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: - return self._extend_numeric(type_, - "SMALLINT(%(display_width)s)" % - {'display_width': type_.display_width} - ) + return self._extend_numeric( + type_, + "SMALLINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "SMALLINT") @@ -1563,7 +1895,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return "BIT" def visit_DATETIME(self, type_, **kw): - if getattr(type_, 'fsp', None): + if getattr(type_, "fsp", None): return "DATETIME(%d)" % type_.fsp else: return "DATETIME" @@ -1572,13 +1904,13 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return "DATE" def visit_TIME(self, type_, **kw): - if getattr(type_, 'fsp', None): + if getattr(type_, "fsp", None): return "TIME(%d)" % type_.fsp else: return "TIME" def visit_TIMESTAMP(self, type_, **kw): - if getattr(type_, 'fsp', None): + if getattr(type_, "fsp", None): return "TIMESTAMP(%d)" % type_.fsp else: return "TIMESTAMP" @@ -1606,17 +1938,17 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def visit_VARCHAR(self, type_, **kw): if type_.length: - return self._extend_string( - type_, {}, "VARCHAR(%d)" % type_.length) + return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) else: raise exc.CompileError( - "VARCHAR requires a length on dialect %s" % - self.dialect.name) + "VARCHAR requires a length on dialect %s" % self.dialect.name + ) def visit_CHAR(self, type_, **kw): if type_.length: - return self._extend_string(type_, {}, "CHAR(%(length)s)" % - {'length': type_.length}) + return self._extend_string( + type_, {}, "CHAR(%(length)s)" % {"length": type_.length} + ) else: return self._extend_string(type_, {}, "CHAR") @@ -1625,22 +1957,26 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): # of "NVARCHAR". if type_.length: return self._extend_string( - type_, {'national': True}, - "VARCHAR(%(length)s)" % {'length': type_.length}) + type_, + {"national": True}, + "VARCHAR(%(length)s)" % {"length": type_.length}, + ) else: raise exc.CompileError( - "NVARCHAR requires a length on dialect %s" % - self.dialect.name) + "NVARCHAR requires a length on dialect %s" % self.dialect.name + ) def visit_NCHAR(self, type_, **kw): # We'll actually generate the equiv. # "NATIONAL CHAR" instead of "NCHAR". if type_.length: return self._extend_string( - type_, {'national': True}, - "CHAR(%(length)s)" % {'length': type_.length}) + type_, + {"national": True}, + "CHAR(%(length)s)" % {"length": type_.length}, + ) else: - return self._extend_string(type_, {'national': True}, "CHAR") + return self._extend_string(type_, {"national": True}, "CHAR") def visit_VARBINARY(self, type_, **kw): return "VARBINARY(%d)" % type_.length @@ -1676,17 +2012,19 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): quoted_enums = [] for e in enumerated_values: quoted_enums.append("'%s'" % e.replace("'", "''")) - return self._extend_string(type_, {}, "%s(%s)" % ( - name, ",".join(quoted_enums)) + return self._extend_string( + type_, {}, "%s(%s)" % (name, ",".join(quoted_enums)) ) def visit_ENUM(self, type_, **kw): - return self._visit_enumerated_values("ENUM", type_, - type_._enumerated_values) + return self._visit_enumerated_values( + "ENUM", type_, type_._enumerated_values + ) def visit_SET(self, type_, **kw): - return self._visit_enumerated_values("SET", type_, - type_._enumerated_values) + return self._visit_enumerated_values( + "SET", type_, type_._enumerated_values + ) def visit_BOOLEAN(self, type, **kw): return "BOOL" @@ -1703,9 +2041,8 @@ class MySQLIdentifierPreparer(compiler.IdentifierPreparer): quote = '"' super(MySQLIdentifierPreparer, self).__init__( - dialect, - initial_quote=quote, - escape_quote=quote) + dialect, initial_quote=quote, escape_quote=quote + ) def _quote_free_identifiers(self, *ids): """Unilaterally identifier-quote any number of strings.""" @@ -1719,7 +2056,7 @@ class MySQLDialect(default.DefaultDialect): Not used directly in application code. """ - name = 'mysql' + name = "mysql" supports_alter = True # MySQL has no true "boolean" type; we @@ -1738,7 +2075,7 @@ class MySQLDialect(default.DefaultDialect): supports_comments = True inline_comments = True - default_paramstyle = 'format' + default_paramstyle = "format" colspecs = colspecs cte_follows_insert = True @@ -1756,26 +2093,28 @@ class MySQLDialect(default.DefaultDialect): _server_ansiquotes = False construct_arguments = [ - (sa_schema.Table, { - "*": None - }), - (sql.Update, { - "limit": None - }), - (sa_schema.PrimaryKeyConstraint, { - "using": None - }), - (sa_schema.Index, { - "using": None, - "length": None, - "prefix": None, - "with_parser": None - }) + (sa_schema.Table, {"*": None}), + (sql.Update, {"limit": None}), + (sa_schema.PrimaryKeyConstraint, {"using": None}), + ( + sa_schema.Index, + { + "using": None, + "length": None, + "prefix": None, + "with_parser": None, + }, + ), ] - def __init__(self, isolation_level=None, json_serializer=None, - json_deserializer=None, **kwargs): - kwargs.pop('use_ansiquotes', None) # legacy + def __init__( + self, + isolation_level=None, + json_serializer=None, + json_deserializer=None, + **kwargs + ): + kwargs.pop("use_ansiquotes", None) # legacy default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level self._json_serializer = json_serializer @@ -1783,22 +2122,30 @@ class MySQLDialect(default.DefaultDialect): def on_connect(self): if self.isolation_level is not None: + def connect(conn): self.set_isolation_level(conn, self.isolation_level) + return connect else: return None - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + ] + ) def set_isolation_level(self, connection, level): - level = level.replace('_', ' ') + level = level.replace("_", " ") # adjust for ConnectionFairy being present # allows attribute set e.g. "connection.autocommit = True" # to work properly - if hasattr(connection, 'connection'): + if hasattr(connection, "connection"): connection = connection.connection self._set_isolation_level(connection, level) @@ -1807,8 +2154,8 @@ class MySQLDialect(default.DefaultDialect): if level not in self._isolation_lookup: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) cursor = connection.cursor() cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level) @@ -1818,9 +2165,9 @@ class MySQLDialect(default.DefaultDialect): def get_isolation_level(self, connection): cursor = connection.cursor() if self._is_mysql and self.server_version_info >= (5, 7, 20): - cursor.execute('SELECT @@transaction_isolation') + cursor.execute("SELECT @@transaction_isolation") else: - cursor.execute('SELECT @@tx_isolation') + cursor.execute("SELECT @@tx_isolation") val = cursor.fetchone()[0] cursor.close() if util.py3k and isinstance(val, bytes): @@ -1840,7 +2187,7 @@ class MySQLDialect(default.DefaultDialect): val = val.decode() version = [] - r = re.compile(r'[.\-]') + r = re.compile(r"[.\-]") for n in r.split(val): try: version.append(int(n)) @@ -1885,29 +2232,38 @@ class MySQLDialect(default.DefaultDialect): connection.execute(sql.text("XA END :xid"), xid=xid) connection.execute(sql.text("XA PREPARE :xid"), xid=xid) - def do_rollback_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: connection.execute(sql.text("XA END :xid"), xid=xid) connection.execute(sql.text("XA ROLLBACK :xid"), xid=xid) - def do_commit_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: self.do_prepare_twophase(connection, xid) connection.execute(sql.text("XA COMMIT :xid"), xid=xid) def do_recover_twophase(self, connection): resultset = connection.execute("XA RECOVER") - return [row['data'][0:row['gtrid_length']] for row in resultset] + return [row["data"][0 : row["gtrid_length"]] for row in resultset] def is_disconnect(self, e, connection, cursor): - if isinstance(e, (self.dbapi.OperationalError, - self.dbapi.ProgrammingError)): - return self._extract_error_code(e) in \ - (2006, 2013, 2014, 2045, 2055) + if isinstance( + e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError) + ): + return self._extract_error_code(e) in ( + 2006, + 2013, + 2014, + 2045, + 2055, + ) elif isinstance( - e, (self.dbapi.InterfaceError, self.dbapi.InternalError)): + e, (self.dbapi.InterfaceError, self.dbapi.InternalError) + ): # if underlying connection is closed, # this is the error you get return "(0, '')" in str(e) @@ -1944,7 +2300,7 @@ class MySQLDialect(default.DefaultDialect): raise NotImplementedError() def _get_default_schema_name(self, connection): - return connection.execute('SELECT DATABASE()').scalar() + return connection.execute("SELECT DATABASE()").scalar() def has_table(self, connection, table_name, schema=None): # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly @@ -1957,15 +2313,19 @@ class MySQLDialect(default.DefaultDialect): # full_name = self.identifier_preparer.format_table(table, # use_schema=True) - full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( - schema, table_name)) + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers( + schema, table_name + ) + ) st = "DESCRIBE %s" % full_name rs = None try: try: rs = connection.execution_options( - skip_user_error_events=True).execute(st) + skip_user_error_events=True + ).execute(st) have = rs.fetchone() is not None rs.close() return have @@ -1986,12 +2346,13 @@ class MySQLDialect(default.DefaultDialect): # if ansiquotes == True, build a new IdentifierPreparer # with the new setting self.identifier_preparer = self.preparer( - self, server_ansiquotes=self._server_ansiquotes) + self, server_ansiquotes=self._server_ansiquotes + ) default.DefaultDialect.initialize(self, connection) self._needs_correct_for_88718 = ( - not self._is_mariadb and self.server_version_info >= (8, ) + not self._is_mariadb and self.server_version_info >= (8,) ) self._warn_for_known_db_issues() @@ -2007,20 +2368,23 @@ class MySQLDialect(default.DefaultDialect): "additional issue prevents proper migrations of columns " "with CHECK constraints (MDEV-11114). Please upgrade to " "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 " - "series, to avoid these issues." % (mdb_version, )) + "series, to avoid these issues." % (mdb_version,) + ) @property def _is_mariadb(self): - return 'MariaDB' in self.server_version_info + return "MariaDB" in self.server_version_info @property def _is_mysql(self): - return 'MariaDB' not in self.server_version_info + return "MariaDB" not in self.server_version_info @property def _is_mariadb_102(self): - return self._is_mariadb and \ - self._mariadb_normalized_version_info > (10, 2) + return self._is_mariadb and self._mariadb_normalized_version_info > ( + 10, + 2, + ) @property def _mariadb_normalized_version_info(self): @@ -2028,15 +2392,17 @@ class MySQLDialect(default.DefaultDialect): # the string "5.5"; now that we use @@version we no longer see this. if self._is_mariadb: - idx = self.server_version_info.index('MariaDB') - return self.server_version_info[idx - 3: idx] + idx = self.server_version_info.index("MariaDB") + return self.server_version_info[idx - 3 : idx] else: return self.server_version_info @property def _supports_cast(self): - return self.server_version_info is None or \ - self.server_version_info >= (4, 0, 2) + return ( + self.server_version_info is None + or self.server_version_info >= (4, 0, 2) + ) @reflection.cache def get_schema_names(self, connection, **kw): @@ -2054,18 +2420,23 @@ class MySQLDialect(default.DefaultDialect): charset = self._connection_charset if self.server_version_info < (5, 0, 2): rp = connection.execute( - "SHOW TABLES FROM %s" % - self.identifier_preparer.quote_identifier(current_schema)) - return [row[0] for - row in self._compat_fetchall(rp, charset=charset)] + "SHOW TABLES FROM %s" + % self.identifier_preparer.quote_identifier(current_schema) + ) + return [ + row[0] for row in self._compat_fetchall(rp, charset=charset) + ] else: rp = connection.execute( - "SHOW FULL TABLES FROM %s" % - self.identifier_preparer.quote_identifier(current_schema)) + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(current_schema) + ) - return [row[0] - for row in self._compat_fetchall(rp, charset=charset) - if row[1] == 'BASE TABLE'] + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] == "BASE TABLE" + ] @reflection.cache def get_view_names(self, connection, schema=None, **kw): @@ -2077,72 +2448,77 @@ class MySQLDialect(default.DefaultDialect): return self.get_table_names(connection, schema) charset = self._connection_charset rp = connection.execute( - "SHOW FULL TABLES FROM %s" % - self.identifier_preparer.quote_identifier(schema)) - return [row[0] - for row in self._compat_fetchall(rp, charset=charset) - if row[1] in ('VIEW', 'SYSTEM VIEW')] + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(schema) + ) + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] in ("VIEW", "SYSTEM VIEW") + ] @reflection.cache def get_table_options(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return parsed_state.table_options @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return parsed_state.columns @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) for key in parsed_state.keys: - if key['type'] == 'PRIMARY': + if key["type"] == "PRIMARY": # There can be only one. - cols = [s[0] for s in key['columns']] - return {'constrained_columns': cols, 'name': None} - return {'constrained_columns': [], 'name': None} + cols = [s[0] for s in key["columns"]] + return {"constrained_columns": cols, "name": None} + return {"constrained_columns": [], "name": None} @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) default_schema = None fkeys = [] for spec in parsed_state.fk_constraints: - ref_name = spec['table'][-1] - ref_schema = len(spec['table']) > 1 and \ - spec['table'][-2] or schema + ref_name = spec["table"][-1] + ref_schema = len(spec["table"]) > 1 and spec["table"][-2] or schema if not ref_schema: if default_schema is None: - default_schema = \ - connection.dialect.default_schema_name + default_schema = connection.dialect.default_schema_name if schema == default_schema: ref_schema = schema - loc_names = spec['local'] - ref_names = spec['foreign'] + loc_names = spec["local"] + ref_names = spec["foreign"] con_kw = {} - for opt in ('onupdate', 'ondelete'): + for opt in ("onupdate", "ondelete"): if spec.get(opt, False): con_kw[opt] = spec[opt] fkey_d = { - 'name': spec['name'], - 'constrained_columns': loc_names, - 'referred_schema': ref_schema, - 'referred_table': ref_name, - 'referred_columns': ref_names, - 'options': con_kw + "name": spec["name"], + "constrained_columns": loc_names, + "referred_schema": ref_schema, + "referred_table": ref_name, + "referred_columns": ref_names, + "options": con_kw, } fkeys.append(fkey_d) @@ -2172,25 +2548,26 @@ class MySQLDialect(default.DefaultDialect): default_schema_name = connection.dialect.default_schema_name col_tuples = [ ( - lower(rec['referred_schema'] or default_schema_name), - lower(rec['referred_table']), - col_name + lower(rec["referred_schema"] or default_schema_name), + lower(rec["referred_table"]), + col_name, ) for rec in fkeys - for col_name in rec['referred_columns'] + for col_name in rec["referred_columns"] ] if col_tuples: correct_for_wrong_fk_case = connection.execute( - sql.text(""" + sql.text( + """ select table_schema, table_name, column_name from information_schema.columns where (table_schema, table_name, lower(column_name)) in :table_data; - """).bindparams( - sql.bindparam("table_data", expanding=True) - ), table_data=col_tuples + """ + ).bindparams(sql.bindparam("table_data", expanding=True)), + table_data=col_tuples, ) # in casing=0, table name and schema name come back in their @@ -2208,109 +2585,117 @@ class MySQLDialect(default.DefaultDialect): d[(lower(schema), lower(tname))][cname.lower()] = cname for fkey in fkeys: - fkey['referred_columns'] = [ + fkey["referred_columns"] = [ d[ ( lower( - fkey['referred_schema'] or - default_schema_name), - lower(fkey['referred_table']) + fkey["referred_schema"] or default_schema_name + ), + lower(fkey["referred_table"]), ) ][col.lower()] - for col in fkey['referred_columns'] + for col in fkey["referred_columns"] ] @reflection.cache - def get_check_constraints( - self, connection, table_name, schema=None, **kw): + def get_check_constraints(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return [ - {"name": spec['name'], "sqltext": spec['sqltext']} + {"name": spec["name"], "sqltext": spec["sqltext"]} for spec in parsed_state.ck_constraints ] @reflection.cache def get_table_comment(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) - return {"text": parsed_state.table_options.get('mysql_comment', None)} + connection, table_name, schema, **kw + ) + return {"text": parsed_state.table_options.get("mysql_comment", None)} @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) indexes = [] for spec in parsed_state.keys: dialect_options = {} unique = False - flavor = spec['type'] - if flavor == 'PRIMARY': + flavor = spec["type"] + if flavor == "PRIMARY": continue - if flavor == 'UNIQUE': + if flavor == "UNIQUE": unique = True - elif flavor in ('FULLTEXT', 'SPATIAL'): + elif flavor in ("FULLTEXT", "SPATIAL"): dialect_options["mysql_prefix"] = flavor elif flavor is None: pass else: self.logger.info( - "Converting unknown KEY type %s to a plain KEY", flavor) + "Converting unknown KEY type %s to a plain KEY", flavor + ) pass - if spec['parser']: - dialect_options['mysql_with_parser'] = spec['parser'] + if spec["parser"]: + dialect_options["mysql_with_parser"] = spec["parser"] index_d = {} if dialect_options: index_d["dialect_options"] = dialect_options - index_d['name'] = spec['name'] - index_d['column_names'] = [s[0] for s in spec['columns']] - index_d['unique'] = unique + index_d["name"] = spec["name"] + index_d["column_names"] = [s[0] for s in spec["columns"]] + index_d["unique"] = unique if flavor: - index_d['type'] = flavor + index_d["type"] = flavor indexes.append(index_d) return indexes @reflection.cache - def get_unique_constraints(self, connection, table_name, - schema=None, **kw): + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return [ { - 'name': key['name'], - 'column_names': [col[0] for col in key['columns']], - 'duplicates_index': key['name'], + "name": key["name"], + "column_names": [col[0] for col in key["columns"]], + "duplicates_index": key["name"], } for key in parsed_state.keys - if key['type'] == 'UNIQUE' + if key["type"] == "UNIQUE" ] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): charset = self._connection_charset - full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( - schema, view_name)) - sql = self._show_create_table(connection, None, charset, - full_name=full_name) + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers(schema, view_name) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) return sql - def _parsed_state_or_create(self, connection, table_name, - schema=None, **kw): + def _parsed_state_or_create( + self, connection, table_name, schema=None, **kw + ): return self._setup_parser( connection, table_name, schema, - info_cache=kw.get('info_cache', None) + info_cache=kw.get("info_cache", None), ) @util.memoized_property @@ -2321,7 +2706,7 @@ class MySQLDialect(default.DefaultDialect): retrieved server version information first. """ - if (self.server_version_info < (4, 1) and self._server_ansiquotes): + if self.server_version_info < (4, 1) and self._server_ansiquotes: # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 preparer = self.preparer(self, server_ansiquotes=False) else: @@ -2332,14 +2717,19 @@ class MySQLDialect(default.DefaultDialect): def _setup_parser(self, connection, table_name, schema=None, **kw): charset = self._connection_charset parser = self._tabledef_parser - full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( - schema, table_name)) - sql = self._show_create_table(connection, None, charset, - full_name=full_name) - if re.match(r'^CREATE (?:ALGORITHM)?.* VIEW', sql): + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers( + schema, table_name + ) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) + if re.match(r"^CREATE (?:ALGORITHM)?.* VIEW", sql): # Adapt views to something table-like. - columns = self._describe_table(connection, None, charset, - full_name=full_name) + columns = self._describe_table( + connection, None, charset, full_name=full_name + ) sql = parser._describe_to_create(table_name, columns) return parser.parse(sql, charset) @@ -2356,17 +2746,18 @@ class MySQLDialect(default.DefaultDialect): # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html charset = self._connection_charset - row = self._compat_first(connection.execute( - "SHOW VARIABLES LIKE 'lower_case_table_names'"), - charset=charset) + row = self._compat_first( + connection.execute("SHOW VARIABLES LIKE 'lower_case_table_names'"), + charset=charset, + ) if not row: cs = 0 else: # 4.0.15 returns OFF or ON according to [ticket:489] # 3.23 doesn't, 4.0.27 doesn't.. - if row[1] == 'OFF': + if row[1] == "OFF": cs = 0 - elif row[1] == 'ON': + elif row[1] == "ON": cs = 1 else: cs = int(row[1]) @@ -2384,7 +2775,7 @@ class MySQLDialect(default.DefaultDialect): pass else: charset = self._connection_charset - rs = connection.execute('SHOW COLLATION') + rs = connection.execute("SHOW COLLATION") for row in self._compat_fetchall(rs, charset): collations[row[0]] = row[1] return collations @@ -2392,33 +2783,36 @@ class MySQLDialect(default.DefaultDialect): def _detect_sql_mode(self, connection): row = self._compat_first( connection.execute("SHOW VARIABLES LIKE 'sql_mode'"), - charset=self._connection_charset) + charset=self._connection_charset, + ) if not row: util.warn( "Could not retrieve SQL_MODE; please ensure the " - "MySQL user has permissions to SHOW VARIABLES") - self._sql_mode = '' + "MySQL user has permissions to SHOW VARIABLES" + ) + self._sql_mode = "" else: - self._sql_mode = row[1] or '' + self._sql_mode = row[1] or "" def _detect_ansiquotes(self, connection): """Detect and adjust for the ANSI_QUOTES sql mode.""" mode = self._sql_mode if not mode: - mode = '' + mode = "" elif mode.isdigit(): mode_no = int(mode) - mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or '' + mode = (mode_no | 4 == mode_no) and "ANSI_QUOTES" or "" - self._server_ansiquotes = 'ANSI_QUOTES' in mode + self._server_ansiquotes = "ANSI_QUOTES" in mode # as of MySQL 5.0.1 - self._backslash_escapes = 'NO_BACKSLASH_ESCAPES' not in mode + self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode - def _show_create_table(self, connection, table, charset=None, - full_name=None): + def _show_create_table( + self, connection, table, charset=None, full_name=None + ): """Run SHOW CREATE TABLE for a ``Table``.""" if full_name is None: @@ -2428,7 +2822,8 @@ class MySQLDialect(default.DefaultDialect): rp = None try: rp = connection.execution_options( - skip_user_error_events=True).execute(st) + skip_user_error_events=True + ).execute(st) except exc.DBAPIError as e: if self._extract_error_code(e.orig) == 1146: raise exc.NoSuchTableError(full_name) @@ -2441,8 +2836,7 @@ class MySQLDialect(default.DefaultDialect): return sql - def _describe_table(self, connection, table, charset=None, - full_name=None): + def _describe_table(self, connection, table, charset=None, full_name=None): """Run DESCRIBE for a ``Table`` and return processed rows.""" if full_name is None: @@ -2453,7 +2847,8 @@ class MySQLDialect(default.DefaultDialect): try: try: rp = connection.execution_options( - skip_user_error_events=True).execute(st) + skip_user_error_events=True + ).execute(st) except exc.DBAPIError as e: code = self._extract_error_code(e.orig) if code == 1146: @@ -2486,11 +2881,11 @@ class _DecodingRowProxy(object): # seem to come up in DDL queries. _encoding_compat = { - 'koi8r': 'koi8_r', - 'koi8u': 'koi8_u', - 'utf16': 'utf-16-be', # MySQL's uft16 is always bigendian - 'utf8mb4': 'utf8', # real utf8 - 'eucjpms': 'ujis', + "koi8r": "koi8_r", + "koi8u": "koi8_u", + "utf16": "utf-16-be", # MySQL's uft16 is always bigendian + "utf8mb4": "utf8", # real utf8 + "eucjpms": "ujis", } def __init__(self, rowproxy, charset): |