diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/mysql/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 1397 |
1 files changed, 896 insertions, 501 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 673d4b9ff..7b0d0618c 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -746,85 +746,340 @@ from ...engine import reflection from ...engine import default from ... import types as sqltypes from ...util import topological -from ...types import DATE, BOOLEAN, \ - BLOB, BINARY, VARBINARY +from ...types import DATE, BOOLEAN, BLOB, BINARY, VARBINARY from . import reflection as _reflection -from .types import BIGINT, BIT, CHAR, DECIMAL, DATETIME, \ - DOUBLE, FLOAT, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, MEDIUMINT, \ - MEDIUMTEXT, NCHAR, NUMERIC, NVARCHAR, REAL, SMALLINT, TEXT, TIME, \ - TIMESTAMP, TINYBLOB, TINYINT, TINYTEXT, VARCHAR, YEAR -from .types import _StringType, _IntegerType, _NumericType, \ - _FloatType, _MatchType +from .types import ( + BIGINT, + BIT, + CHAR, + DECIMAL, + DATETIME, + DOUBLE, + FLOAT, + INTEGER, + LONGBLOB, + LONGTEXT, + MEDIUMBLOB, + MEDIUMINT, + MEDIUMTEXT, + NCHAR, + NUMERIC, + NVARCHAR, + REAL, + SMALLINT, + TEXT, + TIME, + TIMESTAMP, + TINYBLOB, + TINYINT, + TINYTEXT, + VARCHAR, + YEAR, +) +from .types import ( + _StringType, + _IntegerType, + _NumericType, + _FloatType, + _MatchType, +) from .enumerated import ENUM, SET from .json import JSON, JSONIndexType, JSONPathType RESERVED_WORDS = set( - ['accessible', 'add', 'all', 'alter', 'analyze', 'and', 'as', 'asc', - 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both', - 'by', 'call', 'cascade', 'case', 'change', 'char', 'character', 'check', - 'collate', 'column', 'condition', 'constraint', 'continue', 'convert', - 'create', 'cross', 'current_date', 'current_time', 'current_timestamp', - 'current_user', 'cursor', 'database', 'databases', 'day_hour', - 'day_microsecond', 'day_minute', 'day_second', 'dec', 'decimal', - 'declare', 'default', 'delayed', 'delete', 'desc', 'describe', - 'deterministic', 'distinct', 'distinctrow', 'div', 'double', 'drop', - 'dual', 'each', 'else', 'elseif', 'enclosed', 'escaped', 'exists', - 'exit', 'explain', 'false', 'fetch', 'float', 'float4', 'float8', - 'for', 'force', 'foreign', 'from', 'fulltext', 'grant', 'group', - 'having', 'high_priority', 'hour_microsecond', 'hour_minute', - 'hour_second', 'if', 'ignore', 'in', 'index', 'infile', 'inner', 'inout', - 'insensitive', 'insert', 'int', 'int1', 'int2', 'int3', 'int4', 'int8', - 'integer', 'interval', 'into', 'is', 'iterate', 'join', 'key', 'keys', - 'kill', 'leading', 'leave', 'left', 'like', 'limit', 'linear', 'lines', - 'load', 'localtime', 'localtimestamp', 'lock', 'long', 'longblob', - 'longtext', 'loop', 'low_priority', 'master_ssl_verify_server_cert', - 'match', 'mediumblob', 'mediumint', 'mediumtext', 'middleint', - 'minute_microsecond', 'minute_second', 'mod', 'modifies', 'natural', - 'not', 'no_write_to_binlog', 'null', 'numeric', 'on', 'optimize', - 'option', 'optionally', 'or', 'order', 'out', 'outer', 'outfile', - 'precision', 'primary', 'procedure', 'purge', 'range', 'read', 'reads', - 'read_only', 'read_write', 'real', 'references', 'regexp', 'release', - 'rename', 'repeat', 'replace', 'require', 'restrict', 'return', - 'revoke', 'right', 'rlike', 'schema', 'schemas', 'second_microsecond', - 'select', 'sensitive', 'separator', 'set', 'show', 'smallint', 'spatial', - 'specific', 'sql', 'sqlexception', 'sqlstate', 'sqlwarning', - 'sql_big_result', 'sql_calc_found_rows', 'sql_small_result', 'ssl', - 'starting', 'straight_join', 'table', 'terminated', 'then', 'tinyblob', - 'tinyint', 'tinytext', 'to', 'trailing', 'trigger', 'true', 'undo', - 'union', 'unique', 'unlock', 'unsigned', 'update', 'usage', 'use', - 'using', 'utc_date', 'utc_time', 'utc_timestamp', 'values', 'varbinary', - 'varchar', 'varcharacter', 'varying', 'when', 'where', 'while', 'with', - - 'write', 'x509', 'xor', 'year_month', 'zerofill', # 5.0 - - 'columns', 'fields', 'privileges', 'soname', 'tables', # 4.1 - - 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', - 'read_only', 'read_write', # 5.1 - - 'general', 'ignore_server_ids', 'master_heartbeat_period', 'maxvalue', - 'resignal', 'signal', 'slow', # 5.5 - - 'get', 'io_after_gtids', 'io_before_gtids', 'master_bind', 'one_shot', - 'partition', 'sql_after_gtids', 'sql_before_gtids', # 5.6 - - 'generated', 'optimizer_costs', 'stored', 'virtual', # 5.7 - - 'admin', 'cume_dist', 'empty', 'except', 'first_value', 'grouping', - 'function', 'groups', 'json_table', 'last_value', 'nth_value', - 'ntile', 'of', 'over', 'percent_rank', 'persist', 'persist_only', - 'rank', 'recursive', 'role', 'row', 'rows', 'row_number', 'system', - 'window', # 8.0 - ]) + [ + "accessible", + "add", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "asensitive", + "before", + "between", + "bigint", + "binary", + "blob", + "both", + "by", + "call", + "cascade", + "case", + "change", + "char", + "character", + "check", + "collate", + "column", + "condition", + "constraint", + "continue", + "convert", + "create", + "cross", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "databases", + "day_hour", + "day_microsecond", + "day_minute", + "day_second", + "dec", + "decimal", + "declare", + "default", + "delayed", + "delete", + "desc", + "describe", + "deterministic", + "distinct", + "distinctrow", + "div", + "double", + "drop", + "dual", + "each", + "else", + "elseif", + "enclosed", + "escaped", + "exists", + "exit", + "explain", + "false", + "fetch", + "float", + "float4", + "float8", + "for", + "force", + "foreign", + "from", + "fulltext", + "grant", + "group", + "having", + "high_priority", + "hour_microsecond", + "hour_minute", + "hour_second", + "if", + "ignore", + "in", + "index", + "infile", + "inner", + "inout", + "insensitive", + "insert", + "int", + "int1", + "int2", + "int3", + "int4", + "int8", + "integer", + "interval", + "into", + "is", + "iterate", + "join", + "key", + "keys", + "kill", + "leading", + "leave", + "left", + "like", + "limit", + "linear", + "lines", + "load", + "localtime", + "localtimestamp", + "lock", + "long", + "longblob", + "longtext", + "loop", + "low_priority", + "master_ssl_verify_server_cert", + "match", + "mediumblob", + "mediumint", + "mediumtext", + "middleint", + "minute_microsecond", + "minute_second", + "mod", + "modifies", + "natural", + "not", + "no_write_to_binlog", + "null", + "numeric", + "on", + "optimize", + "option", + "optionally", + "or", + "order", + "out", + "outer", + "outfile", + "precision", + "primary", + "procedure", + "purge", + "range", + "read", + "reads", + "read_only", + "read_write", + "real", + "references", + "regexp", + "release", + "rename", + "repeat", + "replace", + "require", + "restrict", + "return", + "revoke", + "right", + "rlike", + "schema", + "schemas", + "second_microsecond", + "select", + "sensitive", + "separator", + "set", + "show", + "smallint", + "spatial", + "specific", + "sql", + "sqlexception", + "sqlstate", + "sqlwarning", + "sql_big_result", + "sql_calc_found_rows", + "sql_small_result", + "ssl", + "starting", + "straight_join", + "table", + "terminated", + "then", + "tinyblob", + "tinyint", + "tinytext", + "to", + "trailing", + "trigger", + "true", + "undo", + "union", + "unique", + "unlock", + "unsigned", + "update", + "usage", + "use", + "using", + "utc_date", + "utc_time", + "utc_timestamp", + "values", + "varbinary", + "varchar", + "varcharacter", + "varying", + "when", + "where", + "while", + "with", + "write", + "x509", + "xor", + "year_month", + "zerofill", # 5.0 + "columns", + "fields", + "privileges", + "soname", + "tables", # 4.1 + "accessible", + "linear", + "master_ssl_verify_server_cert", + "range", + "read_only", + "read_write", # 5.1 + "general", + "ignore_server_ids", + "master_heartbeat_period", + "maxvalue", + "resignal", + "signal", + "slow", # 5.5 + "get", + "io_after_gtids", + "io_before_gtids", + "master_bind", + "one_shot", + "partition", + "sql_after_gtids", + "sql_before_gtids", # 5.6 + "generated", + "optimizer_costs", + "stored", + "virtual", # 5.7 + "admin", + "cume_dist", + "empty", + "except", + "first_value", + "grouping", + "function", + "groups", + "json_table", + "last_value", + "nth_value", + "ntile", + "of", + "over", + "percent_rank", + "persist", + "persist_only", + "rank", + "recursive", + "role", + "row", + "rows", + "row_number", + "system", + "window", # 8.0 + ] +) AUTOCOMMIT_RE = re.compile( - r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)', - re.I | re.UNICODE) + r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)", + re.I | re.UNICODE, +) SET_RE = re.compile( - r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w', - re.I | re.UNICODE) + r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE +) # old names @@ -870,52 +1125,50 @@ colspecs = { sqltypes.MatchType: _MatchType, sqltypes.JSON: JSON, sqltypes.JSON.JSONIndexType: JSONIndexType, - sqltypes.JSON.JSONPathType: JSONPathType - + sqltypes.JSON.JSONPathType: JSONPathType, } # Everything 3.23 through 5.1 excepting OpenGIS types. ischema_names = { - 'bigint': BIGINT, - 'binary': BINARY, - 'bit': BIT, - 'blob': BLOB, - 'boolean': BOOLEAN, - 'char': CHAR, - 'date': DATE, - 'datetime': DATETIME, - 'decimal': DECIMAL, - 'double': DOUBLE, - 'enum': ENUM, - 'fixed': DECIMAL, - 'float': FLOAT, - 'int': INTEGER, - 'integer': INTEGER, - 'json': JSON, - 'longblob': LONGBLOB, - 'longtext': LONGTEXT, - 'mediumblob': MEDIUMBLOB, - 'mediumint': MEDIUMINT, - 'mediumtext': MEDIUMTEXT, - 'nchar': NCHAR, - 'nvarchar': NVARCHAR, - 'numeric': NUMERIC, - 'set': SET, - 'smallint': SMALLINT, - 'text': TEXT, - 'time': TIME, - 'timestamp': TIMESTAMP, - 'tinyblob': TINYBLOB, - 'tinyint': TINYINT, - 'tinytext': TINYTEXT, - 'varbinary': VARBINARY, - 'varchar': VARCHAR, - 'year': YEAR, + "bigint": BIGINT, + "binary": BINARY, + "bit": BIT, + "blob": BLOB, + "boolean": BOOLEAN, + "char": CHAR, + "date": DATE, + "datetime": DATETIME, + "decimal": DECIMAL, + "double": DOUBLE, + "enum": ENUM, + "fixed": DECIMAL, + "float": FLOAT, + "int": INTEGER, + "integer": INTEGER, + "json": JSON, + "longblob": LONGBLOB, + "longtext": LONGTEXT, + "mediumblob": MEDIUMBLOB, + "mediumint": MEDIUMINT, + "mediumtext": MEDIUMTEXT, + "nchar": NCHAR, + "nvarchar": NVARCHAR, + "numeric": NUMERIC, + "set": SET, + "smallint": SMALLINT, + "text": TEXT, + "time": TIME, + "timestamp": TIMESTAMP, + "tinyblob": TINYBLOB, + "tinyint": TINYINT, + "tinytext": TINYTEXT, + "varbinary": VARBINARY, + "varchar": VARCHAR, + "year": YEAR, } class MySQLExecutionContext(default.DefaultExecutionContext): - def should_autocommit_text(self, statement): return AUTOCOMMIT_RE.match(statement) @@ -932,7 +1185,7 @@ class MySQLCompiler(compiler.SQLCompiler): """Overridden from base SQLCompiler value""" extract_map = compiler.SQLCompiler.extract_map.copy() - extract_map.update({'milliseconds': 'millisecond'}) + extract_map.update({"milliseconds": "millisecond"}) def visit_random_func(self, fn, **kw): return "rand%s" % self.function_argspec(fn) @@ -943,12 +1196,14 @@ class MySQLCompiler(compiler.SQLCompiler): def visit_json_getitem_op_binary(self, binary, operator, **kw): return "JSON_EXTRACT(%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_json_path_getitem_op_binary(self, binary, operator, **kw): return "JSON_EXTRACT(%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_on_duplicate_key_update(self, on_duplicate, **kw): if on_duplicate._parameter_ordering: @@ -958,7 +1213,8 @@ class MySQLCompiler(compiler.SQLCompiler): ] ordered_keys = set(parameter_ordering) cols = [ - self.statement.table.c[key] for key in parameter_ordering + self.statement.table.c[key] + for key in parameter_ordering if key in self.statement.table.c ] + [ c for c in self.statement.table.c if c.key not in ordered_keys @@ -979,9 +1235,11 @@ class MySQLCompiler(compiler.SQLCompiler): val = val._clone() val.type = column.type value_text = self.process(val.self_group(), use_schema=False) - elif isinstance(val, elements.ColumnClause) \ - and val.table is on_duplicate.inserted_alias: - value_text = 'VALUES(' + self.preparer.quote(column.name) + ')' + elif ( + isinstance(val, elements.ColumnClause) + and val.table is on_duplicate.inserted_alias + ): + value_text = "VALUES(" + self.preparer.quote(column.name) + ")" else: value_text = self.process(val.self_group(), use_schema=False) name_text = self.preparer.quote(column.name) @@ -990,22 +1248,27 @@ class MySQLCompiler(compiler.SQLCompiler): non_matching = set(on_duplicate.update) - set(c.key for c in cols) if non_matching: util.warn( - 'Additional column names not matching ' - "any column keys in table '%s': %s" % ( + "Additional column names not matching " + "any column keys in table '%s': %s" + % ( self.statement.table.name, - (', '.join("'%s'" % c for c in non_matching)) + (", ".join("'%s'" % c for c in non_matching)), ) ) - return 'ON DUPLICATE KEY UPDATE ' + ', '.join(clauses) + return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses) def visit_concat_op_binary(self, binary, operator, **kw): - return "concat(%s, %s)" % (self.process(binary.left, **kw), - self.process(binary.right, **kw)) + return "concat(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def visit_match_op_binary(self, binary, operator, **kw): - return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % \ - (self.process(binary.left, **kw), self.process(binary.right, **kw)) + return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def get_from_hint_text(self, table, text): return text @@ -1016,26 +1279,35 @@ class MySQLCompiler(compiler.SQLCompiler): if isinstance(type_, sqltypes.TypeDecorator): return self.visit_typeclause(typeclause, type_.impl, **kw) elif isinstance(type_, sqltypes.Integer): - if getattr(type_, 'unsigned', False): - return 'UNSIGNED INTEGER' + if getattr(type_, "unsigned", False): + return "UNSIGNED INTEGER" else: - return 'SIGNED INTEGER' + return "SIGNED INTEGER" elif isinstance(type_, sqltypes.TIMESTAMP): - return 'DATETIME' - elif isinstance(type_, (sqltypes.DECIMAL, sqltypes.DateTime, - sqltypes.Date, sqltypes.Time)): + return "DATETIME" + elif isinstance( + type_, + ( + sqltypes.DECIMAL, + sqltypes.DateTime, + sqltypes.Date, + sqltypes.Time, + ), + ): return self.dialect.type_compiler.process(type_) - elif isinstance(type_, sqltypes.String) \ - and not isinstance(type_, (ENUM, SET)): + elif isinstance(type_, sqltypes.String) and not isinstance( + type_, (ENUM, SET) + ): adapted = CHAR._adapt_string_for_cast(type_) return self.dialect.type_compiler.process(adapted) elif isinstance(type_, sqltypes._Binary): - return 'BINARY' + return "BINARY" elif isinstance(type_, sqltypes.JSON): return "JSON" elif isinstance(type_, sqltypes.NUMERIC): - return self.dialect.type_compiler.process( - type_).replace('NUMERIC', 'DECIMAL') + return self.dialect.type_compiler.process(type_).replace( + "NUMERIC", "DECIMAL" + ) else: return None @@ -1044,23 +1316,25 @@ class MySQLCompiler(compiler.SQLCompiler): if not self.dialect._supports_cast: util.warn( "Current MySQL version does not support " - "CAST; the CAST will be skipped.") + "CAST; the CAST will be skipped." + ) return self.process(cast.clause.self_group(), **kw) type_ = self.process(cast.typeclause) if type_ is None: util.warn( "Datatype %s does not support CAST on MySQL; " - "the CAST will be skipped." % - self.dialect.type_compiler.process(cast.typeclause.type)) + "the CAST will be skipped." + % self.dialect.type_compiler.process(cast.typeclause.type) + ) return self.process(cast.clause.self_group(), **kw) - return 'CAST(%s AS %s)' % (self.process(cast.clause, **kw), type_) + return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_) def render_literal_value(self, value, type_): value = super(MySQLCompiler, self).render_literal_value(value, type_) if self.dialect._backslash_escapes: - value = value.replace('\\', '\\\\') + value = value.replace("\\", "\\\\") return value # override native_boolean=False behavior here, as @@ -1096,12 +1370,15 @@ class MySQLCompiler(compiler.SQLCompiler): else: join_type = " INNER JOIN " - return ''.join( - (self.process(join.left, asfrom=True, **kwargs), - join_type, - self.process(join.right, asfrom=True, **kwargs), - " ON ", - self.process(join.onclause, **kwargs))) + return "".join( + ( + self.process(join.left, asfrom=True, **kwargs), + join_type, + self.process(join.right, asfrom=True, **kwargs), + " ON ", + self.process(join.onclause, **kwargs), + ) + ) def for_update_clause(self, select, **kw): if select._for_update_arg.read: @@ -1118,11 +1395,13 @@ class MySQLCompiler(compiler.SQLCompiler): # The latter is more readable for offsets but we're stuck with the # former until we can refine dialects by server revision. - limit_clause, offset_clause = select._limit_clause, \ - select._offset_clause + limit_clause, offset_clause = ( + select._limit_clause, + select._offset_clause, + ) if limit_clause is None and offset_clause is None: - return '' + return "" elif offset_clause is not None: # As suggested by the MySQL docs, need to apply an # artificial limit if one wasn't provided @@ -1134,35 +1413,38 @@ class MySQLCompiler(compiler.SQLCompiler): # but also is consistent with the usage of the upper # bound as part of MySQL's "syntax" for OFFSET with # no LIMIT - return ' \n LIMIT %s, %s' % ( + return " \n LIMIT %s, %s" % ( self.process(offset_clause, **kw), - "18446744073709551615") + "18446744073709551615", + ) else: - return ' \n LIMIT %s, %s' % ( + return " \n LIMIT %s, %s" % ( self.process(offset_clause, **kw), - self.process(limit_clause, **kw)) + self.process(limit_clause, **kw), + ) else: # No offset provided, so just use the limit - return ' \n LIMIT %s' % (self.process(limit_clause, **kw),) + return " \n LIMIT %s" % (self.process(limit_clause, **kw),) def update_limit_clause(self, update_stmt): - limit = update_stmt.kwargs.get('%s_limit' % self.dialect.name, None) + limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None) if limit: return "LIMIT %s" % limit else: return None - def update_tables_clause(self, update_stmt, from_table, - extra_froms, **kw): - return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) - for t in [from_table] + list(extra_froms)) + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + return ", ".join( + t._compiler_dispatch(self, asfrom=True, **kw) + for t in [from_table] + list(extra_froms) + ) - def update_from_clause(self, update_stmt, from_table, - extra_froms, from_hints, **kw): + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): return None - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: @@ -1171,24 +1453,27 @@ class MySQLCompiler(compiler.SQLCompiler): self, asfrom=True, iscrud=True, ashint=ashint ) - def delete_extra_from_clause(self, delete_stmt, from_table, - extra_froms, from_hints, **kw): + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the DELETE .. USING clause specific to MySQL.""" - return "USING " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "USING " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) def visit_empty_set_expr(self, element_types): return ( "SELECT %(outer)s FROM (SELECT %(inner)s) " - "as _empty_set WHERE 1!=1" % { + "as _empty_set WHERE 1!=1" + % { "inner": ", ".join( "1 AS _in_%s" % idx - for idx, type_ in enumerate(element_types)), + for idx, type_ in enumerate(element_types) + ), "outer": ", ".join( - "_in_%s" % idx - for idx, type_ in enumerate(element_types)) + "_in_%s" % idx for idx, type_ in enumerate(element_types) + ), } ) @@ -1200,35 +1485,39 @@ class MySQLDDLCompiler(compiler.DDLCompiler): colspec = [ self.preparer.format_column(column), self.dialect.type_compiler.process( - column.type, type_expression=column) + column.type, type_expression=column + ), ] is_timestamp = isinstance(column.type, sqltypes.TIMESTAMP) if not column.nullable: - colspec.append('NOT NULL') + colspec.append("NOT NULL") # see: http://docs.sqlalchemy.org/en/latest/dialects/ # mysql.html#mysql_timestamp_null elif column.nullable and is_timestamp: - colspec.append('NULL') + colspec.append("NULL") default = self.get_column_default_string(column) if default is not None: - colspec.append('DEFAULT ' + default) + colspec.append("DEFAULT " + default) comment = column.comment if comment is not None: literal = self.sql_compiler.render_literal_value( - comment, sqltypes.String()) - colspec.append('COMMENT ' + literal) + comment, sqltypes.String() + ) + colspec.append("COMMENT " + literal) - if column.table is not None \ - and column is column.table._autoincrement_column and \ - column.server_default is None: - colspec.append('AUTO_INCREMENT') + if ( + column.table is not None + and column is column.table._autoincrement_column + and column.server_default is None + ): + colspec.append("AUTO_INCREMENT") - return ' '.join(colspec) + return " ".join(colspec) def post_create_table(self, table): """Build table-level CREATE options like ENGINE and COLLATE.""" @@ -1236,76 +1525,94 @@ class MySQLDDLCompiler(compiler.DDLCompiler): table_opts = [] opts = dict( - ( - k[len(self.dialect.name) + 1:].upper(), - v - ) + (k[len(self.dialect.name) + 1 :].upper(), v) for k, v in table.kwargs.items() - if k.startswith('%s_' % self.dialect.name) + if k.startswith("%s_" % self.dialect.name) ) if table.comment is not None: - opts['COMMENT'] = table.comment + opts["COMMENT"] = table.comment partition_options = [ - 'PARTITION_BY', 'PARTITIONS', 'SUBPARTITIONS', - 'SUBPARTITION_BY' + "PARTITION_BY", + "PARTITIONS", + "SUBPARTITIONS", + "SUBPARTITION_BY", ] nonpart_options = set(opts).difference(partition_options) part_options = set(opts).intersection(partition_options) - for opt in topological.sort([ - ('DEFAULT_CHARSET', 'COLLATE'), - ('DEFAULT_CHARACTER_SET', 'COLLATE'), - ], nonpart_options): + for opt in topological.sort( + [ + ("DEFAULT_CHARSET", "COLLATE"), + ("DEFAULT_CHARACTER_SET", "COLLATE"), + ], + nonpart_options, + ): arg = opts[opt] if opt in _reflection._options_of_type_string: arg = self.sql_compiler.render_literal_value( - arg, sqltypes.String()) - - if opt in ('DATA_DIRECTORY', 'INDEX_DIRECTORY', - 'DEFAULT_CHARACTER_SET', 'CHARACTER_SET', - 'DEFAULT_CHARSET', - 'DEFAULT_COLLATE'): - opt = opt.replace('_', ' ') + arg, sqltypes.String() + ) - joiner = '=' - if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET', - 'CHARACTER SET', 'COLLATE'): - joiner = ' ' + if opt in ( + "DATA_DIRECTORY", + "INDEX_DIRECTORY", + "DEFAULT_CHARACTER_SET", + "CHARACTER_SET", + "DEFAULT_CHARSET", + "DEFAULT_COLLATE", + ): + opt = opt.replace("_", " ") + + joiner = "=" + if opt in ( + "TABLESPACE", + "DEFAULT CHARACTER SET", + "CHARACTER SET", + "COLLATE", + ): + joiner = " " table_opts.append(joiner.join((opt, arg))) - for opt in topological.sort([ - ('PARTITION_BY', 'PARTITIONS'), - ('PARTITION_BY', 'SUBPARTITION_BY'), - ('PARTITION_BY', 'SUBPARTITIONS'), - ('PARTITIONS', 'SUBPARTITIONS'), - ('PARTITIONS', 'SUBPARTITION_BY'), - ('SUBPARTITION_BY', 'SUBPARTITIONS') - ], part_options): + for opt in topological.sort( + [ + ("PARTITION_BY", "PARTITIONS"), + ("PARTITION_BY", "SUBPARTITION_BY"), + ("PARTITION_BY", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITION_BY"), + ("SUBPARTITION_BY", "SUBPARTITIONS"), + ], + part_options, + ): arg = opts[opt] if opt in _reflection._options_of_type_string: arg = self.sql_compiler.render_literal_value( - arg, sqltypes.String()) + arg, sqltypes.String() + ) - opt = opt.replace('_', ' ') - joiner = ' ' + opt = opt.replace("_", " ") + joiner = " " table_opts.append(joiner.join((opt, arg))) - return ' '.join(table_opts) + return " ".join(table_opts) def visit_create_index(self, create, **kw): index = create.element self._verify_index_table(index) preparer = self.preparer table = preparer.format_table(index.table) - columns = [self.sql_compiler.process(expr, include_table=False, - literal_binds=True) - for expr in index.expressions] + columns = [ + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ] name = self._prepared_index_name(index) @@ -1313,53 +1620,54 @@ class MySQLDDLCompiler(compiler.DDLCompiler): if index.unique: text += "UNIQUE " - index_prefix = index.kwargs.get('mysql_prefix', None) + index_prefix = index.kwargs.get("mysql_prefix", None) if index_prefix: - text += index_prefix + ' ' + text += index_prefix + " " text += "INDEX %s ON %s " % (name, table) - length = index.dialect_options['mysql']['length'] + length = index.dialect_options["mysql"]["length"] if length is not None: if isinstance(length, dict): # length value can be a (column_name --> integer value) # mapping specifying the prefix length for each column of the # index - columns = ', '.join( - '%s(%d)' % (expr, length[col.name]) if col.name in length - else - ( - '%s(%d)' % (expr, length[expr]) if expr in length - else '%s' % expr + columns = ", ".join( + "%s(%d)" % (expr, length[col.name]) + if col.name in length + else ( + "%s(%d)" % (expr, length[expr]) + if expr in length + else "%s" % expr ) for col, expr in zip(index.expressions, columns) ) else: # or can be an integer value specifying the same # prefix length for all columns of the index - columns = ', '.join( - '%s(%d)' % (col, length) - for col in columns + columns = ", ".join( + "%s(%d)" % (col, length) for col in columns ) else: - columns = ', '.join(columns) - text += '(%s)' % columns + columns = ", ".join(columns) + text += "(%s)" % columns - parser = index.dialect_options['mysql']['with_parser'] + parser = index.dialect_options["mysql"]["with_parser"] if parser is not None: - text += " WITH PARSER %s" % (parser, ) + text += " WITH PARSER %s" % (parser,) - using = index.dialect_options['mysql']['using'] + using = index.dialect_options["mysql"]["using"] if using is not None: text += " USING %s" % (preparer.quote(using)) return text def visit_primary_key_constraint(self, constraint): - text = super(MySQLDDLCompiler, self).\ - visit_primary_key_constraint(constraint) - using = constraint.dialect_options['mysql']['using'] + text = super(MySQLDDLCompiler, self).visit_primary_key_constraint( + constraint + ) + using = constraint.dialect_options["mysql"]["using"] if using: text += " USING %s" % (self.preparer.quote(using)) return text @@ -1368,9 +1676,9 @@ class MySQLDDLCompiler(compiler.DDLCompiler): index = drop.element return "\nDROP INDEX %s ON %s" % ( - self._prepared_index_name(index, - include_schema=False), - self.preparer.format_table(index.table)) + self._prepared_index_name(index, include_schema=False), + self.preparer.format_table(index.table), + ) def visit_drop_constraint(self, drop): constraint = drop.element @@ -1386,29 +1694,33 @@ class MySQLDDLCompiler(compiler.DDLCompiler): else: qual = "" const = self.preparer.format_constraint(constraint) - return "ALTER TABLE %s DROP %s%s" % \ - (self.preparer.format_table(constraint.table), - qual, const) + return "ALTER TABLE %s DROP %s%s" % ( + self.preparer.format_table(constraint.table), + qual, + const, + ) def define_constraint_match(self, constraint): if constraint.match is not None: raise exc.CompileError( "MySQL ignores the 'MATCH' keyword while at the same time " - "causes ON UPDATE/ON DELETE clauses to be ignored.") + "causes ON UPDATE/ON DELETE clauses to be ignored." + ) return "" def visit_set_table_comment(self, create): return "ALTER TABLE %s COMMENT %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( - create.element.comment, sqltypes.String()) + create.element.comment, sqltypes.String() + ), ) def visit_set_column_comment(self, create): return "ALTER TABLE %s CHANGE %s %s" % ( self.preparer.format_table(create.element.table), self.preparer.format_column(create.element), - self.get_column_specification(create.element) + self.get_column_specification(create.element), ) @@ -1420,9 +1732,9 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return spec if type_.unsigned: - spec += ' UNSIGNED' + spec += " UNSIGNED" if type_.zerofill: - spec += ' ZEROFILL' + spec += " ZEROFILL" return spec def _extend_string(self, type_, defaults, spec): @@ -1434,28 +1746,30 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def attr(name): return getattr(type_, name, defaults.get(name)) - if attr('charset'): - charset = 'CHARACTER SET %s' % attr('charset') - elif attr('ascii'): - charset = 'ASCII' - elif attr('unicode'): - charset = 'UNICODE' + if attr("charset"): + charset = "CHARACTER SET %s" % attr("charset") + elif attr("ascii"): + charset = "ASCII" + elif attr("unicode"): + charset = "UNICODE" else: charset = None - if attr('collation'): - collation = 'COLLATE %s' % type_.collation - elif attr('binary'): - collation = 'BINARY' + if attr("collation"): + collation = "COLLATE %s" % type_.collation + elif attr("binary"): + collation = "BINARY" else: collation = None - if attr('national'): + if attr("national"): # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. - return ' '.join([c for c in ('NATIONAL', spec, collation) - if c is not None]) - return ' '.join([c for c in (spec, charset, collation) - if c is not None]) + return " ".join( + [c for c in ("NATIONAL", spec, collation) if c is not None] + ) + return " ".join( + [c for c in (spec, charset, collation) if c is not None] + ) def _mysql_type(self, type_): return isinstance(type_, (_StringType, _NumericType)) @@ -1464,95 +1778,113 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): if type_.precision is None: return self._extend_numeric(type_, "NUMERIC") elif type_.scale is None: - return self._extend_numeric(type_, - "NUMERIC(%(precision)s)" % - {'precision': type_.precision}) + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s)" % {"precision": type_.precision}, + ) else: - return self._extend_numeric(type_, - "NUMERIC(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return self._extend_numeric(type_, "DECIMAL") elif type_.scale is None: - return self._extend_numeric(type_, - "DECIMAL(%(precision)s)" % - {'precision': type_.precision}) + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s)" % {"precision": type_.precision}, + ) else: - return self._extend_numeric(type_, - "DECIMAL(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) def visit_DOUBLE(self, type_, **kw): if type_.precision is not None and type_.scale is not None: - return self._extend_numeric(type_, - "DOUBLE(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "DOUBLE(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) else: - return self._extend_numeric(type_, 'DOUBLE') + return self._extend_numeric(type_, "DOUBLE") def visit_REAL(self, type_, **kw): if type_.precision is not None and type_.scale is not None: - return self._extend_numeric(type_, - "REAL(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "REAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) else: - return self._extend_numeric(type_, 'REAL') + return self._extend_numeric(type_, "REAL") def visit_FLOAT(self, type_, **kw): - if self._mysql_type(type_) and \ - type_.scale is not None and \ - type_.precision is not None: + if ( + self._mysql_type(type_) + and type_.scale is not None + and type_.precision is not None + ): return self._extend_numeric( - type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale)) + type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale) + ) elif type_.precision is not None: - return self._extend_numeric(type_, - "FLOAT(%s)" % (type_.precision,)) + return self._extend_numeric( + type_, "FLOAT(%s)" % (type_.precision,) + ) else: return self._extend_numeric(type_, "FLOAT") def visit_INTEGER(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( - type_, "INTEGER(%(display_width)s)" % - {'display_width': type_.display_width}) + type_, + "INTEGER(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "INTEGER") def visit_BIGINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( - type_, "BIGINT(%(display_width)s)" % - {'display_width': type_.display_width}) + type_, + "BIGINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "BIGINT") def visit_MEDIUMINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( - type_, "MEDIUMINT(%(display_width)s)" % - {'display_width': type_.display_width}) + type_, + "MEDIUMINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "MEDIUMINT") def visit_TINYINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: - return self._extend_numeric(type_, - "TINYINT(%s)" % type_.display_width) + return self._extend_numeric( + type_, "TINYINT(%s)" % type_.display_width + ) else: return self._extend_numeric(type_, "TINYINT") def visit_SMALLINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: - return self._extend_numeric(type_, - "SMALLINT(%(display_width)s)" % - {'display_width': type_.display_width} - ) + return self._extend_numeric( + type_, + "SMALLINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "SMALLINT") @@ -1563,7 +1895,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return "BIT" def visit_DATETIME(self, type_, **kw): - if getattr(type_, 'fsp', None): + if getattr(type_, "fsp", None): return "DATETIME(%d)" % type_.fsp else: return "DATETIME" @@ -1572,13 +1904,13 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return "DATE" def visit_TIME(self, type_, **kw): - if getattr(type_, 'fsp', None): + if getattr(type_, "fsp", None): return "TIME(%d)" % type_.fsp else: return "TIME" def visit_TIMESTAMP(self, type_, **kw): - if getattr(type_, 'fsp', None): + if getattr(type_, "fsp", None): return "TIMESTAMP(%d)" % type_.fsp else: return "TIMESTAMP" @@ -1606,17 +1938,17 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def visit_VARCHAR(self, type_, **kw): if type_.length: - return self._extend_string( - type_, {}, "VARCHAR(%d)" % type_.length) + return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) else: raise exc.CompileError( - "VARCHAR requires a length on dialect %s" % - self.dialect.name) + "VARCHAR requires a length on dialect %s" % self.dialect.name + ) def visit_CHAR(self, type_, **kw): if type_.length: - return self._extend_string(type_, {}, "CHAR(%(length)s)" % - {'length': type_.length}) + return self._extend_string( + type_, {}, "CHAR(%(length)s)" % {"length": type_.length} + ) else: return self._extend_string(type_, {}, "CHAR") @@ -1625,22 +1957,26 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): # of "NVARCHAR". if type_.length: return self._extend_string( - type_, {'national': True}, - "VARCHAR(%(length)s)" % {'length': type_.length}) + type_, + {"national": True}, + "VARCHAR(%(length)s)" % {"length": type_.length}, + ) else: raise exc.CompileError( - "NVARCHAR requires a length on dialect %s" % - self.dialect.name) + "NVARCHAR requires a length on dialect %s" % self.dialect.name + ) def visit_NCHAR(self, type_, **kw): # We'll actually generate the equiv. # "NATIONAL CHAR" instead of "NCHAR". if type_.length: return self._extend_string( - type_, {'national': True}, - "CHAR(%(length)s)" % {'length': type_.length}) + type_, + {"national": True}, + "CHAR(%(length)s)" % {"length": type_.length}, + ) else: - return self._extend_string(type_, {'national': True}, "CHAR") + return self._extend_string(type_, {"national": True}, "CHAR") def visit_VARBINARY(self, type_, **kw): return "VARBINARY(%d)" % type_.length @@ -1676,17 +2012,19 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): quoted_enums = [] for e in enumerated_values: quoted_enums.append("'%s'" % e.replace("'", "''")) - return self._extend_string(type_, {}, "%s(%s)" % ( - name, ",".join(quoted_enums)) + return self._extend_string( + type_, {}, "%s(%s)" % (name, ",".join(quoted_enums)) ) def visit_ENUM(self, type_, **kw): - return self._visit_enumerated_values("ENUM", type_, - type_._enumerated_values) + return self._visit_enumerated_values( + "ENUM", type_, type_._enumerated_values + ) def visit_SET(self, type_, **kw): - return self._visit_enumerated_values("SET", type_, - type_._enumerated_values) + return self._visit_enumerated_values( + "SET", type_, type_._enumerated_values + ) def visit_BOOLEAN(self, type, **kw): return "BOOL" @@ -1703,9 +2041,8 @@ class MySQLIdentifierPreparer(compiler.IdentifierPreparer): quote = '"' super(MySQLIdentifierPreparer, self).__init__( - dialect, - initial_quote=quote, - escape_quote=quote) + dialect, initial_quote=quote, escape_quote=quote + ) def _quote_free_identifiers(self, *ids): """Unilaterally identifier-quote any number of strings.""" @@ -1719,7 +2056,7 @@ class MySQLDialect(default.DefaultDialect): Not used directly in application code. """ - name = 'mysql' + name = "mysql" supports_alter = True # MySQL has no true "boolean" type; we @@ -1738,7 +2075,7 @@ class MySQLDialect(default.DefaultDialect): supports_comments = True inline_comments = True - default_paramstyle = 'format' + default_paramstyle = "format" colspecs = colspecs cte_follows_insert = True @@ -1756,26 +2093,28 @@ class MySQLDialect(default.DefaultDialect): _server_ansiquotes = False construct_arguments = [ - (sa_schema.Table, { - "*": None - }), - (sql.Update, { - "limit": None - }), - (sa_schema.PrimaryKeyConstraint, { - "using": None - }), - (sa_schema.Index, { - "using": None, - "length": None, - "prefix": None, - "with_parser": None - }) + (sa_schema.Table, {"*": None}), + (sql.Update, {"limit": None}), + (sa_schema.PrimaryKeyConstraint, {"using": None}), + ( + sa_schema.Index, + { + "using": None, + "length": None, + "prefix": None, + "with_parser": None, + }, + ), ] - def __init__(self, isolation_level=None, json_serializer=None, - json_deserializer=None, **kwargs): - kwargs.pop('use_ansiquotes', None) # legacy + def __init__( + self, + isolation_level=None, + json_serializer=None, + json_deserializer=None, + **kwargs + ): + kwargs.pop("use_ansiquotes", None) # legacy default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level self._json_serializer = json_serializer @@ -1783,22 +2122,30 @@ class MySQLDialect(default.DefaultDialect): def on_connect(self): if self.isolation_level is not None: + def connect(conn): self.set_isolation_level(conn, self.isolation_level) + return connect else: return None - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + ] + ) def set_isolation_level(self, connection, level): - level = level.replace('_', ' ') + level = level.replace("_", " ") # adjust for ConnectionFairy being present # allows attribute set e.g. "connection.autocommit = True" # to work properly - if hasattr(connection, 'connection'): + if hasattr(connection, "connection"): connection = connection.connection self._set_isolation_level(connection, level) @@ -1807,8 +2154,8 @@ class MySQLDialect(default.DefaultDialect): if level not in self._isolation_lookup: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) cursor = connection.cursor() cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level) @@ -1818,9 +2165,9 @@ class MySQLDialect(default.DefaultDialect): def get_isolation_level(self, connection): cursor = connection.cursor() if self._is_mysql and self.server_version_info >= (5, 7, 20): - cursor.execute('SELECT @@transaction_isolation') + cursor.execute("SELECT @@transaction_isolation") else: - cursor.execute('SELECT @@tx_isolation') + cursor.execute("SELECT @@tx_isolation") val = cursor.fetchone()[0] cursor.close() if util.py3k and isinstance(val, bytes): @@ -1840,7 +2187,7 @@ class MySQLDialect(default.DefaultDialect): val = val.decode() version = [] - r = re.compile(r'[.\-]') + r = re.compile(r"[.\-]") for n in r.split(val): try: version.append(int(n)) @@ -1885,29 +2232,38 @@ class MySQLDialect(default.DefaultDialect): connection.execute(sql.text("XA END :xid"), xid=xid) connection.execute(sql.text("XA PREPARE :xid"), xid=xid) - def do_rollback_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: connection.execute(sql.text("XA END :xid"), xid=xid) connection.execute(sql.text("XA ROLLBACK :xid"), xid=xid) - def do_commit_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: self.do_prepare_twophase(connection, xid) connection.execute(sql.text("XA COMMIT :xid"), xid=xid) def do_recover_twophase(self, connection): resultset = connection.execute("XA RECOVER") - return [row['data'][0:row['gtrid_length']] for row in resultset] + return [row["data"][0 : row["gtrid_length"]] for row in resultset] def is_disconnect(self, e, connection, cursor): - if isinstance(e, (self.dbapi.OperationalError, - self.dbapi.ProgrammingError)): - return self._extract_error_code(e) in \ - (2006, 2013, 2014, 2045, 2055) + if isinstance( + e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError) + ): + return self._extract_error_code(e) in ( + 2006, + 2013, + 2014, + 2045, + 2055, + ) elif isinstance( - e, (self.dbapi.InterfaceError, self.dbapi.InternalError)): + e, (self.dbapi.InterfaceError, self.dbapi.InternalError) + ): # if underlying connection is closed, # this is the error you get return "(0, '')" in str(e) @@ -1944,7 +2300,7 @@ class MySQLDialect(default.DefaultDialect): raise NotImplementedError() def _get_default_schema_name(self, connection): - return connection.execute('SELECT DATABASE()').scalar() + return connection.execute("SELECT DATABASE()").scalar() def has_table(self, connection, table_name, schema=None): # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly @@ -1957,15 +2313,19 @@ class MySQLDialect(default.DefaultDialect): # full_name = self.identifier_preparer.format_table(table, # use_schema=True) - full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( - schema, table_name)) + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers( + schema, table_name + ) + ) st = "DESCRIBE %s" % full_name rs = None try: try: rs = connection.execution_options( - skip_user_error_events=True).execute(st) + skip_user_error_events=True + ).execute(st) have = rs.fetchone() is not None rs.close() return have @@ -1986,12 +2346,13 @@ class MySQLDialect(default.DefaultDialect): # if ansiquotes == True, build a new IdentifierPreparer # with the new setting self.identifier_preparer = self.preparer( - self, server_ansiquotes=self._server_ansiquotes) + self, server_ansiquotes=self._server_ansiquotes + ) default.DefaultDialect.initialize(self, connection) self._needs_correct_for_88718 = ( - not self._is_mariadb and self.server_version_info >= (8, ) + not self._is_mariadb and self.server_version_info >= (8,) ) self._warn_for_known_db_issues() @@ -2007,20 +2368,23 @@ class MySQLDialect(default.DefaultDialect): "additional issue prevents proper migrations of columns " "with CHECK constraints (MDEV-11114). Please upgrade to " "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 " - "series, to avoid these issues." % (mdb_version, )) + "series, to avoid these issues." % (mdb_version,) + ) @property def _is_mariadb(self): - return 'MariaDB' in self.server_version_info + return "MariaDB" in self.server_version_info @property def _is_mysql(self): - return 'MariaDB' not in self.server_version_info + return "MariaDB" not in self.server_version_info @property def _is_mariadb_102(self): - return self._is_mariadb and \ - self._mariadb_normalized_version_info > (10, 2) + return self._is_mariadb and self._mariadb_normalized_version_info > ( + 10, + 2, + ) @property def _mariadb_normalized_version_info(self): @@ -2028,15 +2392,17 @@ class MySQLDialect(default.DefaultDialect): # the string "5.5"; now that we use @@version we no longer see this. if self._is_mariadb: - idx = self.server_version_info.index('MariaDB') - return self.server_version_info[idx - 3: idx] + idx = self.server_version_info.index("MariaDB") + return self.server_version_info[idx - 3 : idx] else: return self.server_version_info @property def _supports_cast(self): - return self.server_version_info is None or \ - self.server_version_info >= (4, 0, 2) + return ( + self.server_version_info is None + or self.server_version_info >= (4, 0, 2) + ) @reflection.cache def get_schema_names(self, connection, **kw): @@ -2054,18 +2420,23 @@ class MySQLDialect(default.DefaultDialect): charset = self._connection_charset if self.server_version_info < (5, 0, 2): rp = connection.execute( - "SHOW TABLES FROM %s" % - self.identifier_preparer.quote_identifier(current_schema)) - return [row[0] for - row in self._compat_fetchall(rp, charset=charset)] + "SHOW TABLES FROM %s" + % self.identifier_preparer.quote_identifier(current_schema) + ) + return [ + row[0] for row in self._compat_fetchall(rp, charset=charset) + ] else: rp = connection.execute( - "SHOW FULL TABLES FROM %s" % - self.identifier_preparer.quote_identifier(current_schema)) + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(current_schema) + ) - return [row[0] - for row in self._compat_fetchall(rp, charset=charset) - if row[1] == 'BASE TABLE'] + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] == "BASE TABLE" + ] @reflection.cache def get_view_names(self, connection, schema=None, **kw): @@ -2077,72 +2448,77 @@ class MySQLDialect(default.DefaultDialect): return self.get_table_names(connection, schema) charset = self._connection_charset rp = connection.execute( - "SHOW FULL TABLES FROM %s" % - self.identifier_preparer.quote_identifier(schema)) - return [row[0] - for row in self._compat_fetchall(rp, charset=charset) - if row[1] in ('VIEW', 'SYSTEM VIEW')] + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(schema) + ) + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] in ("VIEW", "SYSTEM VIEW") + ] @reflection.cache def get_table_options(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return parsed_state.table_options @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return parsed_state.columns @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) for key in parsed_state.keys: - if key['type'] == 'PRIMARY': + if key["type"] == "PRIMARY": # There can be only one. - cols = [s[0] for s in key['columns']] - return {'constrained_columns': cols, 'name': None} - return {'constrained_columns': [], 'name': None} + cols = [s[0] for s in key["columns"]] + return {"constrained_columns": cols, "name": None} + return {"constrained_columns": [], "name": None} @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) default_schema = None fkeys = [] for spec in parsed_state.fk_constraints: - ref_name = spec['table'][-1] - ref_schema = len(spec['table']) > 1 and \ - spec['table'][-2] or schema + ref_name = spec["table"][-1] + ref_schema = len(spec["table"]) > 1 and spec["table"][-2] or schema if not ref_schema: if default_schema is None: - default_schema = \ - connection.dialect.default_schema_name + default_schema = connection.dialect.default_schema_name if schema == default_schema: ref_schema = schema - loc_names = spec['local'] - ref_names = spec['foreign'] + loc_names = spec["local"] + ref_names = spec["foreign"] con_kw = {} - for opt in ('onupdate', 'ondelete'): + for opt in ("onupdate", "ondelete"): if spec.get(opt, False): con_kw[opt] = spec[opt] fkey_d = { - 'name': spec['name'], - 'constrained_columns': loc_names, - 'referred_schema': ref_schema, - 'referred_table': ref_name, - 'referred_columns': ref_names, - 'options': con_kw + "name": spec["name"], + "constrained_columns": loc_names, + "referred_schema": ref_schema, + "referred_table": ref_name, + "referred_columns": ref_names, + "options": con_kw, } fkeys.append(fkey_d) @@ -2172,25 +2548,26 @@ class MySQLDialect(default.DefaultDialect): default_schema_name = connection.dialect.default_schema_name col_tuples = [ ( - lower(rec['referred_schema'] or default_schema_name), - lower(rec['referred_table']), - col_name + lower(rec["referred_schema"] or default_schema_name), + lower(rec["referred_table"]), + col_name, ) for rec in fkeys - for col_name in rec['referred_columns'] + for col_name in rec["referred_columns"] ] if col_tuples: correct_for_wrong_fk_case = connection.execute( - sql.text(""" + sql.text( + """ select table_schema, table_name, column_name from information_schema.columns where (table_schema, table_name, lower(column_name)) in :table_data; - """).bindparams( - sql.bindparam("table_data", expanding=True) - ), table_data=col_tuples + """ + ).bindparams(sql.bindparam("table_data", expanding=True)), + table_data=col_tuples, ) # in casing=0, table name and schema name come back in their @@ -2208,109 +2585,117 @@ class MySQLDialect(default.DefaultDialect): d[(lower(schema), lower(tname))][cname.lower()] = cname for fkey in fkeys: - fkey['referred_columns'] = [ + fkey["referred_columns"] = [ d[ ( lower( - fkey['referred_schema'] or - default_schema_name), - lower(fkey['referred_table']) + fkey["referred_schema"] or default_schema_name + ), + lower(fkey["referred_table"]), ) ][col.lower()] - for col in fkey['referred_columns'] + for col in fkey["referred_columns"] ] @reflection.cache - def get_check_constraints( - self, connection, table_name, schema=None, **kw): + def get_check_constraints(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return [ - {"name": spec['name'], "sqltext": spec['sqltext']} + {"name": spec["name"], "sqltext": spec["sqltext"]} for spec in parsed_state.ck_constraints ] @reflection.cache def get_table_comment(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) - return {"text": parsed_state.table_options.get('mysql_comment', None)} + connection, table_name, schema, **kw + ) + return {"text": parsed_state.table_options.get("mysql_comment", None)} @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) indexes = [] for spec in parsed_state.keys: dialect_options = {} unique = False - flavor = spec['type'] - if flavor == 'PRIMARY': + flavor = spec["type"] + if flavor == "PRIMARY": continue - if flavor == 'UNIQUE': + if flavor == "UNIQUE": unique = True - elif flavor in ('FULLTEXT', 'SPATIAL'): + elif flavor in ("FULLTEXT", "SPATIAL"): dialect_options["mysql_prefix"] = flavor elif flavor is None: pass else: self.logger.info( - "Converting unknown KEY type %s to a plain KEY", flavor) + "Converting unknown KEY type %s to a plain KEY", flavor + ) pass - if spec['parser']: - dialect_options['mysql_with_parser'] = spec['parser'] + if spec["parser"]: + dialect_options["mysql_with_parser"] = spec["parser"] index_d = {} if dialect_options: index_d["dialect_options"] = dialect_options - index_d['name'] = spec['name'] - index_d['column_names'] = [s[0] for s in spec['columns']] - index_d['unique'] = unique + index_d["name"] = spec["name"] + index_d["column_names"] = [s[0] for s in spec["columns"]] + index_d["unique"] = unique if flavor: - index_d['type'] = flavor + index_d["type"] = flavor indexes.append(index_d) return indexes @reflection.cache - def get_unique_constraints(self, connection, table_name, - schema=None, **kw): + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return [ { - 'name': key['name'], - 'column_names': [col[0] for col in key['columns']], - 'duplicates_index': key['name'], + "name": key["name"], + "column_names": [col[0] for col in key["columns"]], + "duplicates_index": key["name"], } for key in parsed_state.keys - if key['type'] == 'UNIQUE' + if key["type"] == "UNIQUE" ] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): charset = self._connection_charset - full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( - schema, view_name)) - sql = self._show_create_table(connection, None, charset, - full_name=full_name) + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers(schema, view_name) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) return sql - def _parsed_state_or_create(self, connection, table_name, - schema=None, **kw): + def _parsed_state_or_create( + self, connection, table_name, schema=None, **kw + ): return self._setup_parser( connection, table_name, schema, - info_cache=kw.get('info_cache', None) + info_cache=kw.get("info_cache", None), ) @util.memoized_property @@ -2321,7 +2706,7 @@ class MySQLDialect(default.DefaultDialect): retrieved server version information first. """ - if (self.server_version_info < (4, 1) and self._server_ansiquotes): + if self.server_version_info < (4, 1) and self._server_ansiquotes: # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 preparer = self.preparer(self, server_ansiquotes=False) else: @@ -2332,14 +2717,19 @@ class MySQLDialect(default.DefaultDialect): def _setup_parser(self, connection, table_name, schema=None, **kw): charset = self._connection_charset parser = self._tabledef_parser - full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( - schema, table_name)) - sql = self._show_create_table(connection, None, charset, - full_name=full_name) - if re.match(r'^CREATE (?:ALGORITHM)?.* VIEW', sql): + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers( + schema, table_name + ) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) + if re.match(r"^CREATE (?:ALGORITHM)?.* VIEW", sql): # Adapt views to something table-like. - columns = self._describe_table(connection, None, charset, - full_name=full_name) + columns = self._describe_table( + connection, None, charset, full_name=full_name + ) sql = parser._describe_to_create(table_name, columns) return parser.parse(sql, charset) @@ -2356,17 +2746,18 @@ class MySQLDialect(default.DefaultDialect): # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html charset = self._connection_charset - row = self._compat_first(connection.execute( - "SHOW VARIABLES LIKE 'lower_case_table_names'"), - charset=charset) + row = self._compat_first( + connection.execute("SHOW VARIABLES LIKE 'lower_case_table_names'"), + charset=charset, + ) if not row: cs = 0 else: # 4.0.15 returns OFF or ON according to [ticket:489] # 3.23 doesn't, 4.0.27 doesn't.. - if row[1] == 'OFF': + if row[1] == "OFF": cs = 0 - elif row[1] == 'ON': + elif row[1] == "ON": cs = 1 else: cs = int(row[1]) @@ -2384,7 +2775,7 @@ class MySQLDialect(default.DefaultDialect): pass else: charset = self._connection_charset - rs = connection.execute('SHOW COLLATION') + rs = connection.execute("SHOW COLLATION") for row in self._compat_fetchall(rs, charset): collations[row[0]] = row[1] return collations @@ -2392,33 +2783,36 @@ class MySQLDialect(default.DefaultDialect): def _detect_sql_mode(self, connection): row = self._compat_first( connection.execute("SHOW VARIABLES LIKE 'sql_mode'"), - charset=self._connection_charset) + charset=self._connection_charset, + ) if not row: util.warn( "Could not retrieve SQL_MODE; please ensure the " - "MySQL user has permissions to SHOW VARIABLES") - self._sql_mode = '' + "MySQL user has permissions to SHOW VARIABLES" + ) + self._sql_mode = "" else: - self._sql_mode = row[1] or '' + self._sql_mode = row[1] or "" def _detect_ansiquotes(self, connection): """Detect and adjust for the ANSI_QUOTES sql mode.""" mode = self._sql_mode if not mode: - mode = '' + mode = "" elif mode.isdigit(): mode_no = int(mode) - mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or '' + mode = (mode_no | 4 == mode_no) and "ANSI_QUOTES" or "" - self._server_ansiquotes = 'ANSI_QUOTES' in mode + self._server_ansiquotes = "ANSI_QUOTES" in mode # as of MySQL 5.0.1 - self._backslash_escapes = 'NO_BACKSLASH_ESCAPES' not in mode + self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode - def _show_create_table(self, connection, table, charset=None, - full_name=None): + def _show_create_table( + self, connection, table, charset=None, full_name=None + ): """Run SHOW CREATE TABLE for a ``Table``.""" if full_name is None: @@ -2428,7 +2822,8 @@ class MySQLDialect(default.DefaultDialect): rp = None try: rp = connection.execution_options( - skip_user_error_events=True).execute(st) + skip_user_error_events=True + ).execute(st) except exc.DBAPIError as e: if self._extract_error_code(e.orig) == 1146: raise exc.NoSuchTableError(full_name) @@ -2441,8 +2836,7 @@ class MySQLDialect(default.DefaultDialect): return sql - def _describe_table(self, connection, table, charset=None, - full_name=None): + def _describe_table(self, connection, table, charset=None, full_name=None): """Run DESCRIBE for a ``Table`` and return processed rows.""" if full_name is None: @@ -2453,7 +2847,8 @@ class MySQLDialect(default.DefaultDialect): try: try: rp = connection.execution_options( - skip_user_error_events=True).execute(st) + skip_user_error_events=True + ).execute(st) except exc.DBAPIError as e: code = self._extract_error_code(e.orig) if code == 1146: @@ -2486,11 +2881,11 @@ class _DecodingRowProxy(object): # seem to come up in DDL queries. _encoding_compat = { - 'koi8r': 'koi8_r', - 'koi8u': 'koi8_u', - 'utf16': 'utf-16-be', # MySQL's uft16 is always bigendian - 'utf8mb4': 'utf8', # real utf8 - 'eucjpms': 'ujis', + "koi8r": "koi8_r", + "koi8u": "koi8_u", + "utf16": "utf-16-be", # MySQL's uft16 is always bigendian + "utf8mb4": "utf8", # real utf8 + "eucjpms": "ujis", } def __init__(self, rowproxy, charset): |