summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/mysql/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/mysql/base.py')
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py1397
1 files changed, 896 insertions, 501 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 673d4b9ff..7b0d0618c 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -746,85 +746,340 @@ from ...engine import reflection
from ...engine import default
from ... import types as sqltypes
from ...util import topological
-from ...types import DATE, BOOLEAN, \
- BLOB, BINARY, VARBINARY
+from ...types import DATE, BOOLEAN, BLOB, BINARY, VARBINARY
from . import reflection as _reflection
-from .types import BIGINT, BIT, CHAR, DECIMAL, DATETIME, \
- DOUBLE, FLOAT, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, MEDIUMINT, \
- MEDIUMTEXT, NCHAR, NUMERIC, NVARCHAR, REAL, SMALLINT, TEXT, TIME, \
- TIMESTAMP, TINYBLOB, TINYINT, TINYTEXT, VARCHAR, YEAR
-from .types import _StringType, _IntegerType, _NumericType, \
- _FloatType, _MatchType
+from .types import (
+ BIGINT,
+ BIT,
+ CHAR,
+ DECIMAL,
+ DATETIME,
+ DOUBLE,
+ FLOAT,
+ INTEGER,
+ LONGBLOB,
+ LONGTEXT,
+ MEDIUMBLOB,
+ MEDIUMINT,
+ MEDIUMTEXT,
+ NCHAR,
+ NUMERIC,
+ NVARCHAR,
+ REAL,
+ SMALLINT,
+ TEXT,
+ TIME,
+ TIMESTAMP,
+ TINYBLOB,
+ TINYINT,
+ TINYTEXT,
+ VARCHAR,
+ YEAR,
+)
+from .types import (
+ _StringType,
+ _IntegerType,
+ _NumericType,
+ _FloatType,
+ _MatchType,
+)
from .enumerated import ENUM, SET
from .json import JSON, JSONIndexType, JSONPathType
RESERVED_WORDS = set(
- ['accessible', 'add', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
- 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both',
- 'by', 'call', 'cascade', 'case', 'change', 'char', 'character', 'check',
- 'collate', 'column', 'condition', 'constraint', 'continue', 'convert',
- 'create', 'cross', 'current_date', 'current_time', 'current_timestamp',
- 'current_user', 'cursor', 'database', 'databases', 'day_hour',
- 'day_microsecond', 'day_minute', 'day_second', 'dec', 'decimal',
- 'declare', 'default', 'delayed', 'delete', 'desc', 'describe',
- 'deterministic', 'distinct', 'distinctrow', 'div', 'double', 'drop',
- 'dual', 'each', 'else', 'elseif', 'enclosed', 'escaped', 'exists',
- 'exit', 'explain', 'false', 'fetch', 'float', 'float4', 'float8',
- 'for', 'force', 'foreign', 'from', 'fulltext', 'grant', 'group',
- 'having', 'high_priority', 'hour_microsecond', 'hour_minute',
- 'hour_second', 'if', 'ignore', 'in', 'index', 'infile', 'inner', 'inout',
- 'insensitive', 'insert', 'int', 'int1', 'int2', 'int3', 'int4', 'int8',
- 'integer', 'interval', 'into', 'is', 'iterate', 'join', 'key', 'keys',
- 'kill', 'leading', 'leave', 'left', 'like', 'limit', 'linear', 'lines',
- 'load', 'localtime', 'localtimestamp', 'lock', 'long', 'longblob',
- 'longtext', 'loop', 'low_priority', 'master_ssl_verify_server_cert',
- 'match', 'mediumblob', 'mediumint', 'mediumtext', 'middleint',
- 'minute_microsecond', 'minute_second', 'mod', 'modifies', 'natural',
- 'not', 'no_write_to_binlog', 'null', 'numeric', 'on', 'optimize',
- 'option', 'optionally', 'or', 'order', 'out', 'outer', 'outfile',
- 'precision', 'primary', 'procedure', 'purge', 'range', 'read', 'reads',
- 'read_only', 'read_write', 'real', 'references', 'regexp', 'release',
- 'rename', 'repeat', 'replace', 'require', 'restrict', 'return',
- 'revoke', 'right', 'rlike', 'schema', 'schemas', 'second_microsecond',
- 'select', 'sensitive', 'separator', 'set', 'show', 'smallint', 'spatial',
- 'specific', 'sql', 'sqlexception', 'sqlstate', 'sqlwarning',
- 'sql_big_result', 'sql_calc_found_rows', 'sql_small_result', 'ssl',
- 'starting', 'straight_join', 'table', 'terminated', 'then', 'tinyblob',
- 'tinyint', 'tinytext', 'to', 'trailing', 'trigger', 'true', 'undo',
- 'union', 'unique', 'unlock', 'unsigned', 'update', 'usage', 'use',
- 'using', 'utc_date', 'utc_time', 'utc_timestamp', 'values', 'varbinary',
- 'varchar', 'varcharacter', 'varying', 'when', 'where', 'while', 'with',
-
- 'write', 'x509', 'xor', 'year_month', 'zerofill', # 5.0
-
- 'columns', 'fields', 'privileges', 'soname', 'tables', # 4.1
-
- 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range',
- 'read_only', 'read_write', # 5.1
-
- 'general', 'ignore_server_ids', 'master_heartbeat_period', 'maxvalue',
- 'resignal', 'signal', 'slow', # 5.5
-
- 'get', 'io_after_gtids', 'io_before_gtids', 'master_bind', 'one_shot',
- 'partition', 'sql_after_gtids', 'sql_before_gtids', # 5.6
-
- 'generated', 'optimizer_costs', 'stored', 'virtual', # 5.7
-
- 'admin', 'cume_dist', 'empty', 'except', 'first_value', 'grouping',
- 'function', 'groups', 'json_table', 'last_value', 'nth_value',
- 'ntile', 'of', 'over', 'percent_rank', 'persist', 'persist_only',
- 'rank', 'recursive', 'role', 'row', 'rows', 'row_number', 'system',
- 'window', # 8.0
- ])
+ [
+ "accessible",
+ "add",
+ "all",
+ "alter",
+ "analyze",
+ "and",
+ "as",
+ "asc",
+ "asensitive",
+ "before",
+ "between",
+ "bigint",
+ "binary",
+ "blob",
+ "both",
+ "by",
+ "call",
+ "cascade",
+ "case",
+ "change",
+ "char",
+ "character",
+ "check",
+ "collate",
+ "column",
+ "condition",
+ "constraint",
+ "continue",
+ "convert",
+ "create",
+ "cross",
+ "current_date",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "cursor",
+ "database",
+ "databases",
+ "day_hour",
+ "day_microsecond",
+ "day_minute",
+ "day_second",
+ "dec",
+ "decimal",
+ "declare",
+ "default",
+ "delayed",
+ "delete",
+ "desc",
+ "describe",
+ "deterministic",
+ "distinct",
+ "distinctrow",
+ "div",
+ "double",
+ "drop",
+ "dual",
+ "each",
+ "else",
+ "elseif",
+ "enclosed",
+ "escaped",
+ "exists",
+ "exit",
+ "explain",
+ "false",
+ "fetch",
+ "float",
+ "float4",
+ "float8",
+ "for",
+ "force",
+ "foreign",
+ "from",
+ "fulltext",
+ "grant",
+ "group",
+ "having",
+ "high_priority",
+ "hour_microsecond",
+ "hour_minute",
+ "hour_second",
+ "if",
+ "ignore",
+ "in",
+ "index",
+ "infile",
+ "inner",
+ "inout",
+ "insensitive",
+ "insert",
+ "int",
+ "int1",
+ "int2",
+ "int3",
+ "int4",
+ "int8",
+ "integer",
+ "interval",
+ "into",
+ "is",
+ "iterate",
+ "join",
+ "key",
+ "keys",
+ "kill",
+ "leading",
+ "leave",
+ "left",
+ "like",
+ "limit",
+ "linear",
+ "lines",
+ "load",
+ "localtime",
+ "localtimestamp",
+ "lock",
+ "long",
+ "longblob",
+ "longtext",
+ "loop",
+ "low_priority",
+ "master_ssl_verify_server_cert",
+ "match",
+ "mediumblob",
+ "mediumint",
+ "mediumtext",
+ "middleint",
+ "minute_microsecond",
+ "minute_second",
+ "mod",
+ "modifies",
+ "natural",
+ "not",
+ "no_write_to_binlog",
+ "null",
+ "numeric",
+ "on",
+ "optimize",
+ "option",
+ "optionally",
+ "or",
+ "order",
+ "out",
+ "outer",
+ "outfile",
+ "precision",
+ "primary",
+ "procedure",
+ "purge",
+ "range",
+ "read",
+ "reads",
+ "read_only",
+ "read_write",
+ "real",
+ "references",
+ "regexp",
+ "release",
+ "rename",
+ "repeat",
+ "replace",
+ "require",
+ "restrict",
+ "return",
+ "revoke",
+ "right",
+ "rlike",
+ "schema",
+ "schemas",
+ "second_microsecond",
+ "select",
+ "sensitive",
+ "separator",
+ "set",
+ "show",
+ "smallint",
+ "spatial",
+ "specific",
+ "sql",
+ "sqlexception",
+ "sqlstate",
+ "sqlwarning",
+ "sql_big_result",
+ "sql_calc_found_rows",
+ "sql_small_result",
+ "ssl",
+ "starting",
+ "straight_join",
+ "table",
+ "terminated",
+ "then",
+ "tinyblob",
+ "tinyint",
+ "tinytext",
+ "to",
+ "trailing",
+ "trigger",
+ "true",
+ "undo",
+ "union",
+ "unique",
+ "unlock",
+ "unsigned",
+ "update",
+ "usage",
+ "use",
+ "using",
+ "utc_date",
+ "utc_time",
+ "utc_timestamp",
+ "values",
+ "varbinary",
+ "varchar",
+ "varcharacter",
+ "varying",
+ "when",
+ "where",
+ "while",
+ "with",
+ "write",
+ "x509",
+ "xor",
+ "year_month",
+ "zerofill", # 5.0
+ "columns",
+ "fields",
+ "privileges",
+ "soname",
+ "tables", # 4.1
+ "accessible",
+ "linear",
+ "master_ssl_verify_server_cert",
+ "range",
+ "read_only",
+ "read_write", # 5.1
+ "general",
+ "ignore_server_ids",
+ "master_heartbeat_period",
+ "maxvalue",
+ "resignal",
+ "signal",
+ "slow", # 5.5
+ "get",
+ "io_after_gtids",
+ "io_before_gtids",
+ "master_bind",
+ "one_shot",
+ "partition",
+ "sql_after_gtids",
+ "sql_before_gtids", # 5.6
+ "generated",
+ "optimizer_costs",
+ "stored",
+ "virtual", # 5.7
+ "admin",
+ "cume_dist",
+ "empty",
+ "except",
+ "first_value",
+ "grouping",
+ "function",
+ "groups",
+ "json_table",
+ "last_value",
+ "nth_value",
+ "ntile",
+ "of",
+ "over",
+ "percent_rank",
+ "persist",
+ "persist_only",
+ "rank",
+ "recursive",
+ "role",
+ "row",
+ "rows",
+ "row_number",
+ "system",
+ "window", # 8.0
+ ]
+)
AUTOCOMMIT_RE = re.compile(
- r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)',
- re.I | re.UNICODE)
+ r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)",
+ re.I | re.UNICODE,
+)
SET_RE = re.compile(
- r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w',
- re.I | re.UNICODE)
+ r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE
+)
# old names
@@ -870,52 +1125,50 @@ colspecs = {
sqltypes.MatchType: _MatchType,
sqltypes.JSON: JSON,
sqltypes.JSON.JSONIndexType: JSONIndexType,
- sqltypes.JSON.JSONPathType: JSONPathType
-
+ sqltypes.JSON.JSONPathType: JSONPathType,
}
# Everything 3.23 through 5.1 excepting OpenGIS types.
ischema_names = {
- 'bigint': BIGINT,
- 'binary': BINARY,
- 'bit': BIT,
- 'blob': BLOB,
- 'boolean': BOOLEAN,
- 'char': CHAR,
- 'date': DATE,
- 'datetime': DATETIME,
- 'decimal': DECIMAL,
- 'double': DOUBLE,
- 'enum': ENUM,
- 'fixed': DECIMAL,
- 'float': FLOAT,
- 'int': INTEGER,
- 'integer': INTEGER,
- 'json': JSON,
- 'longblob': LONGBLOB,
- 'longtext': LONGTEXT,
- 'mediumblob': MEDIUMBLOB,
- 'mediumint': MEDIUMINT,
- 'mediumtext': MEDIUMTEXT,
- 'nchar': NCHAR,
- 'nvarchar': NVARCHAR,
- 'numeric': NUMERIC,
- 'set': SET,
- 'smallint': SMALLINT,
- 'text': TEXT,
- 'time': TIME,
- 'timestamp': TIMESTAMP,
- 'tinyblob': TINYBLOB,
- 'tinyint': TINYINT,
- 'tinytext': TINYTEXT,
- 'varbinary': VARBINARY,
- 'varchar': VARCHAR,
- 'year': YEAR,
+ "bigint": BIGINT,
+ "binary": BINARY,
+ "bit": BIT,
+ "blob": BLOB,
+ "boolean": BOOLEAN,
+ "char": CHAR,
+ "date": DATE,
+ "datetime": DATETIME,
+ "decimal": DECIMAL,
+ "double": DOUBLE,
+ "enum": ENUM,
+ "fixed": DECIMAL,
+ "float": FLOAT,
+ "int": INTEGER,
+ "integer": INTEGER,
+ "json": JSON,
+ "longblob": LONGBLOB,
+ "longtext": LONGTEXT,
+ "mediumblob": MEDIUMBLOB,
+ "mediumint": MEDIUMINT,
+ "mediumtext": MEDIUMTEXT,
+ "nchar": NCHAR,
+ "nvarchar": NVARCHAR,
+ "numeric": NUMERIC,
+ "set": SET,
+ "smallint": SMALLINT,
+ "text": TEXT,
+ "time": TIME,
+ "timestamp": TIMESTAMP,
+ "tinyblob": TINYBLOB,
+ "tinyint": TINYINT,
+ "tinytext": TINYTEXT,
+ "varbinary": VARBINARY,
+ "varchar": VARCHAR,
+ "year": YEAR,
}
class MySQLExecutionContext(default.DefaultExecutionContext):
-
def should_autocommit_text(self, statement):
return AUTOCOMMIT_RE.match(statement)
@@ -932,7 +1185,7 @@ class MySQLCompiler(compiler.SQLCompiler):
"""Overridden from base SQLCompiler value"""
extract_map = compiler.SQLCompiler.extract_map.copy()
- extract_map.update({'milliseconds': 'millisecond'})
+ extract_map.update({"milliseconds": "millisecond"})
def visit_random_func(self, fn, **kw):
return "rand%s" % self.function_argspec(fn)
@@ -943,12 +1196,14 @@ class MySQLCompiler(compiler.SQLCompiler):
def visit_json_getitem_op_binary(self, binary, operator, **kw):
return "JSON_EXTRACT(%s, %s)" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ self.process(binary.right, **kw),
+ )
def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
return "JSON_EXTRACT(%s, %s)" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ self.process(binary.right, **kw),
+ )
def visit_on_duplicate_key_update(self, on_duplicate, **kw):
if on_duplicate._parameter_ordering:
@@ -958,7 +1213,8 @@ class MySQLCompiler(compiler.SQLCompiler):
]
ordered_keys = set(parameter_ordering)
cols = [
- self.statement.table.c[key] for key in parameter_ordering
+ self.statement.table.c[key]
+ for key in parameter_ordering
if key in self.statement.table.c
] + [
c for c in self.statement.table.c if c.key not in ordered_keys
@@ -979,9 +1235,11 @@ class MySQLCompiler(compiler.SQLCompiler):
val = val._clone()
val.type = column.type
value_text = self.process(val.self_group(), use_schema=False)
- elif isinstance(val, elements.ColumnClause) \
- and val.table is on_duplicate.inserted_alias:
- value_text = 'VALUES(' + self.preparer.quote(column.name) + ')'
+ elif (
+ isinstance(val, elements.ColumnClause)
+ and val.table is on_duplicate.inserted_alias
+ ):
+ value_text = "VALUES(" + self.preparer.quote(column.name) + ")"
else:
value_text = self.process(val.self_group(), use_schema=False)
name_text = self.preparer.quote(column.name)
@@ -990,22 +1248,27 @@ class MySQLCompiler(compiler.SQLCompiler):
non_matching = set(on_duplicate.update) - set(c.key for c in cols)
if non_matching:
util.warn(
- 'Additional column names not matching '
- "any column keys in table '%s': %s" % (
+ "Additional column names not matching "
+ "any column keys in table '%s': %s"
+ % (
self.statement.table.name,
- (', '.join("'%s'" % c for c in non_matching))
+ (", ".join("'%s'" % c for c in non_matching)),
)
)
- return 'ON DUPLICATE KEY UPDATE ' + ', '.join(clauses)
+ return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses)
def visit_concat_op_binary(self, binary, operator, **kw):
- return "concat(%s, %s)" % (self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ return "concat(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
def visit_match_op_binary(self, binary, operator, **kw):
- return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % \
- (self.process(binary.left, **kw), self.process(binary.right, **kw))
+ return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
def get_from_hint_text(self, table, text):
return text
@@ -1016,26 +1279,35 @@ class MySQLCompiler(compiler.SQLCompiler):
if isinstance(type_, sqltypes.TypeDecorator):
return self.visit_typeclause(typeclause, type_.impl, **kw)
elif isinstance(type_, sqltypes.Integer):
- if getattr(type_, 'unsigned', False):
- return 'UNSIGNED INTEGER'
+ if getattr(type_, "unsigned", False):
+ return "UNSIGNED INTEGER"
else:
- return 'SIGNED INTEGER'
+ return "SIGNED INTEGER"
elif isinstance(type_, sqltypes.TIMESTAMP):
- return 'DATETIME'
- elif isinstance(type_, (sqltypes.DECIMAL, sqltypes.DateTime,
- sqltypes.Date, sqltypes.Time)):
+ return "DATETIME"
+ elif isinstance(
+ type_,
+ (
+ sqltypes.DECIMAL,
+ sqltypes.DateTime,
+ sqltypes.Date,
+ sqltypes.Time,
+ ),
+ ):
return self.dialect.type_compiler.process(type_)
- elif isinstance(type_, sqltypes.String) \
- and not isinstance(type_, (ENUM, SET)):
+ elif isinstance(type_, sqltypes.String) and not isinstance(
+ type_, (ENUM, SET)
+ ):
adapted = CHAR._adapt_string_for_cast(type_)
return self.dialect.type_compiler.process(adapted)
elif isinstance(type_, sqltypes._Binary):
- return 'BINARY'
+ return "BINARY"
elif isinstance(type_, sqltypes.JSON):
return "JSON"
elif isinstance(type_, sqltypes.NUMERIC):
- return self.dialect.type_compiler.process(
- type_).replace('NUMERIC', 'DECIMAL')
+ return self.dialect.type_compiler.process(type_).replace(
+ "NUMERIC", "DECIMAL"
+ )
else:
return None
@@ -1044,23 +1316,25 @@ class MySQLCompiler(compiler.SQLCompiler):
if not self.dialect._supports_cast:
util.warn(
"Current MySQL version does not support "
- "CAST; the CAST will be skipped.")
+ "CAST; the CAST will be skipped."
+ )
return self.process(cast.clause.self_group(), **kw)
type_ = self.process(cast.typeclause)
if type_ is None:
util.warn(
"Datatype %s does not support CAST on MySQL; "
- "the CAST will be skipped." %
- self.dialect.type_compiler.process(cast.typeclause.type))
+ "the CAST will be skipped."
+ % self.dialect.type_compiler.process(cast.typeclause.type)
+ )
return self.process(cast.clause.self_group(), **kw)
- return 'CAST(%s AS %s)' % (self.process(cast.clause, **kw), type_)
+ return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_)
def render_literal_value(self, value, type_):
value = super(MySQLCompiler, self).render_literal_value(value, type_)
if self.dialect._backslash_escapes:
- value = value.replace('\\', '\\\\')
+ value = value.replace("\\", "\\\\")
return value
# override native_boolean=False behavior here, as
@@ -1096,12 +1370,15 @@ class MySQLCompiler(compiler.SQLCompiler):
else:
join_type = " INNER JOIN "
- return ''.join(
- (self.process(join.left, asfrom=True, **kwargs),
- join_type,
- self.process(join.right, asfrom=True, **kwargs),
- " ON ",
- self.process(join.onclause, **kwargs)))
+ return "".join(
+ (
+ self.process(join.left, asfrom=True, **kwargs),
+ join_type,
+ self.process(join.right, asfrom=True, **kwargs),
+ " ON ",
+ self.process(join.onclause, **kwargs),
+ )
+ )
def for_update_clause(self, select, **kw):
if select._for_update_arg.read:
@@ -1118,11 +1395,13 @@ class MySQLCompiler(compiler.SQLCompiler):
# The latter is more readable for offsets but we're stuck with the
# former until we can refine dialects by server revision.
- limit_clause, offset_clause = select._limit_clause, \
- select._offset_clause
+ limit_clause, offset_clause = (
+ select._limit_clause,
+ select._offset_clause,
+ )
if limit_clause is None and offset_clause is None:
- return ''
+ return ""
elif offset_clause is not None:
# As suggested by the MySQL docs, need to apply an
# artificial limit if one wasn't provided
@@ -1134,35 +1413,38 @@ class MySQLCompiler(compiler.SQLCompiler):
# but also is consistent with the usage of the upper
# bound as part of MySQL's "syntax" for OFFSET with
# no LIMIT
- return ' \n LIMIT %s, %s' % (
+ return " \n LIMIT %s, %s" % (
self.process(offset_clause, **kw),
- "18446744073709551615")
+ "18446744073709551615",
+ )
else:
- return ' \n LIMIT %s, %s' % (
+ return " \n LIMIT %s, %s" % (
self.process(offset_clause, **kw),
- self.process(limit_clause, **kw))
+ self.process(limit_clause, **kw),
+ )
else:
# No offset provided, so just use the limit
- return ' \n LIMIT %s' % (self.process(limit_clause, **kw),)
+ return " \n LIMIT %s" % (self.process(limit_clause, **kw),)
def update_limit_clause(self, update_stmt):
- limit = update_stmt.kwargs.get('%s_limit' % self.dialect.name, None)
+ limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None)
if limit:
return "LIMIT %s" % limit
else:
return None
- def update_tables_clause(self, update_stmt, from_table,
- extra_froms, **kw):
- return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw)
- for t in [from_table] + list(extra_froms))
+ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+ return ", ".join(
+ t._compiler_dispatch(self, asfrom=True, **kw)
+ for t in [from_table] + list(extra_froms)
+ )
- def update_from_clause(self, update_stmt, from_table,
- extra_froms, from_hints, **kw):
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
return None
- def delete_table_clause(self, delete_stmt, from_table,
- extra_froms):
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
"""If we have extra froms make sure we render any alias as hint."""
ashint = False
if extra_froms:
@@ -1171,24 +1453,27 @@ class MySQLCompiler(compiler.SQLCompiler):
self, asfrom=True, iscrud=True, ashint=ashint
)
- def delete_extra_from_clause(self, delete_stmt, from_table,
- extra_froms, from_hints, **kw):
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Render the DELETE .. USING clause specific to MySQL."""
- return "USING " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in [from_table] + extra_froms)
+ return "USING " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
def visit_empty_set_expr(self, element_types):
return (
"SELECT %(outer)s FROM (SELECT %(inner)s) "
- "as _empty_set WHERE 1!=1" % {
+ "as _empty_set WHERE 1!=1"
+ % {
"inner": ", ".join(
"1 AS _in_%s" % idx
- for idx, type_ in enumerate(element_types)),
+ for idx, type_ in enumerate(element_types)
+ ),
"outer": ", ".join(
- "_in_%s" % idx
- for idx, type_ in enumerate(element_types))
+ "_in_%s" % idx for idx, type_ in enumerate(element_types)
+ ),
}
)
@@ -1200,35 +1485,39 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
colspec = [
self.preparer.format_column(column),
self.dialect.type_compiler.process(
- column.type, type_expression=column)
+ column.type, type_expression=column
+ ),
]
is_timestamp = isinstance(column.type, sqltypes.TIMESTAMP)
if not column.nullable:
- colspec.append('NOT NULL')
+ colspec.append("NOT NULL")
# see: http://docs.sqlalchemy.org/en/latest/dialects/
# mysql.html#mysql_timestamp_null
elif column.nullable and is_timestamp:
- colspec.append('NULL')
+ colspec.append("NULL")
default = self.get_column_default_string(column)
if default is not None:
- colspec.append('DEFAULT ' + default)
+ colspec.append("DEFAULT " + default)
comment = column.comment
if comment is not None:
literal = self.sql_compiler.render_literal_value(
- comment, sqltypes.String())
- colspec.append('COMMENT ' + literal)
+ comment, sqltypes.String()
+ )
+ colspec.append("COMMENT " + literal)
- if column.table is not None \
- and column is column.table._autoincrement_column and \
- column.server_default is None:
- colspec.append('AUTO_INCREMENT')
+ if (
+ column.table is not None
+ and column is column.table._autoincrement_column
+ and column.server_default is None
+ ):
+ colspec.append("AUTO_INCREMENT")
- return ' '.join(colspec)
+ return " ".join(colspec)
def post_create_table(self, table):
"""Build table-level CREATE options like ENGINE and COLLATE."""
@@ -1236,76 +1525,94 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
table_opts = []
opts = dict(
- (
- k[len(self.dialect.name) + 1:].upper(),
- v
- )
+ (k[len(self.dialect.name) + 1 :].upper(), v)
for k, v in table.kwargs.items()
- if k.startswith('%s_' % self.dialect.name)
+ if k.startswith("%s_" % self.dialect.name)
)
if table.comment is not None:
- opts['COMMENT'] = table.comment
+ opts["COMMENT"] = table.comment
partition_options = [
- 'PARTITION_BY', 'PARTITIONS', 'SUBPARTITIONS',
- 'SUBPARTITION_BY'
+ "PARTITION_BY",
+ "PARTITIONS",
+ "SUBPARTITIONS",
+ "SUBPARTITION_BY",
]
nonpart_options = set(opts).difference(partition_options)
part_options = set(opts).intersection(partition_options)
- for opt in topological.sort([
- ('DEFAULT_CHARSET', 'COLLATE'),
- ('DEFAULT_CHARACTER_SET', 'COLLATE'),
- ], nonpart_options):
+ for opt in topological.sort(
+ [
+ ("DEFAULT_CHARSET", "COLLATE"),
+ ("DEFAULT_CHARACTER_SET", "COLLATE"),
+ ],
+ nonpart_options,
+ ):
arg = opts[opt]
if opt in _reflection._options_of_type_string:
arg = self.sql_compiler.render_literal_value(
- arg, sqltypes.String())
-
- if opt in ('DATA_DIRECTORY', 'INDEX_DIRECTORY',
- 'DEFAULT_CHARACTER_SET', 'CHARACTER_SET',
- 'DEFAULT_CHARSET',
- 'DEFAULT_COLLATE'):
- opt = opt.replace('_', ' ')
+ arg, sqltypes.String()
+ )
- joiner = '='
- if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET',
- 'CHARACTER SET', 'COLLATE'):
- joiner = ' '
+ if opt in (
+ "DATA_DIRECTORY",
+ "INDEX_DIRECTORY",
+ "DEFAULT_CHARACTER_SET",
+ "CHARACTER_SET",
+ "DEFAULT_CHARSET",
+ "DEFAULT_COLLATE",
+ ):
+ opt = opt.replace("_", " ")
+
+ joiner = "="
+ if opt in (
+ "TABLESPACE",
+ "DEFAULT CHARACTER SET",
+ "CHARACTER SET",
+ "COLLATE",
+ ):
+ joiner = " "
table_opts.append(joiner.join((opt, arg)))
- for opt in topological.sort([
- ('PARTITION_BY', 'PARTITIONS'),
- ('PARTITION_BY', 'SUBPARTITION_BY'),
- ('PARTITION_BY', 'SUBPARTITIONS'),
- ('PARTITIONS', 'SUBPARTITIONS'),
- ('PARTITIONS', 'SUBPARTITION_BY'),
- ('SUBPARTITION_BY', 'SUBPARTITIONS')
- ], part_options):
+ for opt in topological.sort(
+ [
+ ("PARTITION_BY", "PARTITIONS"),
+ ("PARTITION_BY", "SUBPARTITION_BY"),
+ ("PARTITION_BY", "SUBPARTITIONS"),
+ ("PARTITIONS", "SUBPARTITIONS"),
+ ("PARTITIONS", "SUBPARTITION_BY"),
+ ("SUBPARTITION_BY", "SUBPARTITIONS"),
+ ],
+ part_options,
+ ):
arg = opts[opt]
if opt in _reflection._options_of_type_string:
arg = self.sql_compiler.render_literal_value(
- arg, sqltypes.String())
+ arg, sqltypes.String()
+ )
- opt = opt.replace('_', ' ')
- joiner = ' '
+ opt = opt.replace("_", " ")
+ joiner = " "
table_opts.append(joiner.join((opt, arg)))
- return ' '.join(table_opts)
+ return " ".join(table_opts)
def visit_create_index(self, create, **kw):
index = create.element
self._verify_index_table(index)
preparer = self.preparer
table = preparer.format_table(index.table)
- columns = [self.sql_compiler.process(expr, include_table=False,
- literal_binds=True)
- for expr in index.expressions]
+ columns = [
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ]
name = self._prepared_index_name(index)
@@ -1313,53 +1620,54 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
if index.unique:
text += "UNIQUE "
- index_prefix = index.kwargs.get('mysql_prefix', None)
+ index_prefix = index.kwargs.get("mysql_prefix", None)
if index_prefix:
- text += index_prefix + ' '
+ text += index_prefix + " "
text += "INDEX %s ON %s " % (name, table)
- length = index.dialect_options['mysql']['length']
+ length = index.dialect_options["mysql"]["length"]
if length is not None:
if isinstance(length, dict):
# length value can be a (column_name --> integer value)
# mapping specifying the prefix length for each column of the
# index
- columns = ', '.join(
- '%s(%d)' % (expr, length[col.name]) if col.name in length
- else
- (
- '%s(%d)' % (expr, length[expr]) if expr in length
- else '%s' % expr
+ columns = ", ".join(
+ "%s(%d)" % (expr, length[col.name])
+ if col.name in length
+ else (
+ "%s(%d)" % (expr, length[expr])
+ if expr in length
+ else "%s" % expr
)
for col, expr in zip(index.expressions, columns)
)
else:
# or can be an integer value specifying the same
# prefix length for all columns of the index
- columns = ', '.join(
- '%s(%d)' % (col, length)
- for col in columns
+ columns = ", ".join(
+ "%s(%d)" % (col, length) for col in columns
)
else:
- columns = ', '.join(columns)
- text += '(%s)' % columns
+ columns = ", ".join(columns)
+ text += "(%s)" % columns
- parser = index.dialect_options['mysql']['with_parser']
+ parser = index.dialect_options["mysql"]["with_parser"]
if parser is not None:
- text += " WITH PARSER %s" % (parser, )
+ text += " WITH PARSER %s" % (parser,)
- using = index.dialect_options['mysql']['using']
+ using = index.dialect_options["mysql"]["using"]
if using is not None:
text += " USING %s" % (preparer.quote(using))
return text
def visit_primary_key_constraint(self, constraint):
- text = super(MySQLDDLCompiler, self).\
- visit_primary_key_constraint(constraint)
- using = constraint.dialect_options['mysql']['using']
+ text = super(MySQLDDLCompiler, self).visit_primary_key_constraint(
+ constraint
+ )
+ using = constraint.dialect_options["mysql"]["using"]
if using:
text += " USING %s" % (self.preparer.quote(using))
return text
@@ -1368,9 +1676,9 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
index = drop.element
return "\nDROP INDEX %s ON %s" % (
- self._prepared_index_name(index,
- include_schema=False),
- self.preparer.format_table(index.table))
+ self._prepared_index_name(index, include_schema=False),
+ self.preparer.format_table(index.table),
+ )
def visit_drop_constraint(self, drop):
constraint = drop.element
@@ -1386,29 +1694,33 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
else:
qual = ""
const = self.preparer.format_constraint(constraint)
- return "ALTER TABLE %s DROP %s%s" % \
- (self.preparer.format_table(constraint.table),
- qual, const)
+ return "ALTER TABLE %s DROP %s%s" % (
+ self.preparer.format_table(constraint.table),
+ qual,
+ const,
+ )
def define_constraint_match(self, constraint):
if constraint.match is not None:
raise exc.CompileError(
"MySQL ignores the 'MATCH' keyword while at the same time "
- "causes ON UPDATE/ON DELETE clauses to be ignored.")
+ "causes ON UPDATE/ON DELETE clauses to be ignored."
+ )
return ""
def visit_set_table_comment(self, create):
return "ALTER TABLE %s COMMENT %s" % (
self.preparer.format_table(create.element),
self.sql_compiler.render_literal_value(
- create.element.comment, sqltypes.String())
+ create.element.comment, sqltypes.String()
+ ),
)
def visit_set_column_comment(self, create):
return "ALTER TABLE %s CHANGE %s %s" % (
self.preparer.format_table(create.element.table),
self.preparer.format_column(create.element),
- self.get_column_specification(create.element)
+ self.get_column_specification(create.element),
)
@@ -1420,9 +1732,9 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
return spec
if type_.unsigned:
- spec += ' UNSIGNED'
+ spec += " UNSIGNED"
if type_.zerofill:
- spec += ' ZEROFILL'
+ spec += " ZEROFILL"
return spec
def _extend_string(self, type_, defaults, spec):
@@ -1434,28 +1746,30 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
def attr(name):
return getattr(type_, name, defaults.get(name))
- if attr('charset'):
- charset = 'CHARACTER SET %s' % attr('charset')
- elif attr('ascii'):
- charset = 'ASCII'
- elif attr('unicode'):
- charset = 'UNICODE'
+ if attr("charset"):
+ charset = "CHARACTER SET %s" % attr("charset")
+ elif attr("ascii"):
+ charset = "ASCII"
+ elif attr("unicode"):
+ charset = "UNICODE"
else:
charset = None
- if attr('collation'):
- collation = 'COLLATE %s' % type_.collation
- elif attr('binary'):
- collation = 'BINARY'
+ if attr("collation"):
+ collation = "COLLATE %s" % type_.collation
+ elif attr("binary"):
+ collation = "BINARY"
else:
collation = None
- if attr('national'):
+ if attr("national"):
# NATIONAL (aka NCHAR/NVARCHAR) trumps charsets.
- return ' '.join([c for c in ('NATIONAL', spec, collation)
- if c is not None])
- return ' '.join([c for c in (spec, charset, collation)
- if c is not None])
+ return " ".join(
+ [c for c in ("NATIONAL", spec, collation) if c is not None]
+ )
+ return " ".join(
+ [c for c in (spec, charset, collation) if c is not None]
+ )
def _mysql_type(self, type_):
return isinstance(type_, (_StringType, _NumericType))
@@ -1464,95 +1778,113 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
if type_.precision is None:
return self._extend_numeric(type_, "NUMERIC")
elif type_.scale is None:
- return self._extend_numeric(type_,
- "NUMERIC(%(precision)s)" %
- {'precision': type_.precision})
+ return self._extend_numeric(
+ type_,
+ "NUMERIC(%(precision)s)" % {"precision": type_.precision},
+ )
else:
- return self._extend_numeric(type_,
- "NUMERIC(%(precision)s, %(scale)s)" %
- {'precision': type_.precision,
- 'scale': type_.scale})
+ return self._extend_numeric(
+ type_,
+ "NUMERIC(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
def visit_DECIMAL(self, type_, **kw):
if type_.precision is None:
return self._extend_numeric(type_, "DECIMAL")
elif type_.scale is None:
- return self._extend_numeric(type_,
- "DECIMAL(%(precision)s)" %
- {'precision': type_.precision})
+ return self._extend_numeric(
+ type_,
+ "DECIMAL(%(precision)s)" % {"precision": type_.precision},
+ )
else:
- return self._extend_numeric(type_,
- "DECIMAL(%(precision)s, %(scale)s)" %
- {'precision': type_.precision,
- 'scale': type_.scale})
+ return self._extend_numeric(
+ type_,
+ "DECIMAL(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
def visit_DOUBLE(self, type_, **kw):
if type_.precision is not None and type_.scale is not None:
- return self._extend_numeric(type_,
- "DOUBLE(%(precision)s, %(scale)s)" %
- {'precision': type_.precision,
- 'scale': type_.scale})
+ return self._extend_numeric(
+ type_,
+ "DOUBLE(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
else:
- return self._extend_numeric(type_, 'DOUBLE')
+ return self._extend_numeric(type_, "DOUBLE")
def visit_REAL(self, type_, **kw):
if type_.precision is not None and type_.scale is not None:
- return self._extend_numeric(type_,
- "REAL(%(precision)s, %(scale)s)" %
- {'precision': type_.precision,
- 'scale': type_.scale})
+ return self._extend_numeric(
+ type_,
+ "REAL(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
else:
- return self._extend_numeric(type_, 'REAL')
+ return self._extend_numeric(type_, "REAL")
def visit_FLOAT(self, type_, **kw):
- if self._mysql_type(type_) and \
- type_.scale is not None and \
- type_.precision is not None:
+ if (
+ self._mysql_type(type_)
+ and type_.scale is not None
+ and type_.precision is not None
+ ):
return self._extend_numeric(
- type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale))
+ type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale)
+ )
elif type_.precision is not None:
- return self._extend_numeric(type_,
- "FLOAT(%s)" % (type_.precision,))
+ return self._extend_numeric(
+ type_, "FLOAT(%s)" % (type_.precision,)
+ )
else:
return self._extend_numeric(type_, "FLOAT")
def visit_INTEGER(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
- type_, "INTEGER(%(display_width)s)" %
- {'display_width': type_.display_width})
+ type_,
+ "INTEGER(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
else:
return self._extend_numeric(type_, "INTEGER")
def visit_BIGINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
- type_, "BIGINT(%(display_width)s)" %
- {'display_width': type_.display_width})
+ type_,
+ "BIGINT(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
else:
return self._extend_numeric(type_, "BIGINT")
def visit_MEDIUMINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
- type_, "MEDIUMINT(%(display_width)s)" %
- {'display_width': type_.display_width})
+ type_,
+ "MEDIUMINT(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
else:
return self._extend_numeric(type_, "MEDIUMINT")
def visit_TINYINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
- return self._extend_numeric(type_,
- "TINYINT(%s)" % type_.display_width)
+ return self._extend_numeric(
+ type_, "TINYINT(%s)" % type_.display_width
+ )
else:
return self._extend_numeric(type_, "TINYINT")
def visit_SMALLINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
- return self._extend_numeric(type_,
- "SMALLINT(%(display_width)s)" %
- {'display_width': type_.display_width}
- )
+ return self._extend_numeric(
+ type_,
+ "SMALLINT(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
else:
return self._extend_numeric(type_, "SMALLINT")
@@ -1563,7 +1895,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
return "BIT"
def visit_DATETIME(self, type_, **kw):
- if getattr(type_, 'fsp', None):
+ if getattr(type_, "fsp", None):
return "DATETIME(%d)" % type_.fsp
else:
return "DATETIME"
@@ -1572,13 +1904,13 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
return "DATE"
def visit_TIME(self, type_, **kw):
- if getattr(type_, 'fsp', None):
+ if getattr(type_, "fsp", None):
return "TIME(%d)" % type_.fsp
else:
return "TIME"
def visit_TIMESTAMP(self, type_, **kw):
- if getattr(type_, 'fsp', None):
+ if getattr(type_, "fsp", None):
return "TIMESTAMP(%d)" % type_.fsp
else:
return "TIMESTAMP"
@@ -1606,17 +1938,17 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
def visit_VARCHAR(self, type_, **kw):
if type_.length:
- return self._extend_string(
- type_, {}, "VARCHAR(%d)" % type_.length)
+ return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length)
else:
raise exc.CompileError(
- "VARCHAR requires a length on dialect %s" %
- self.dialect.name)
+ "VARCHAR requires a length on dialect %s" % self.dialect.name
+ )
def visit_CHAR(self, type_, **kw):
if type_.length:
- return self._extend_string(type_, {}, "CHAR(%(length)s)" %
- {'length': type_.length})
+ return self._extend_string(
+ type_, {}, "CHAR(%(length)s)" % {"length": type_.length}
+ )
else:
return self._extend_string(type_, {}, "CHAR")
@@ -1625,22 +1957,26 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
# of "NVARCHAR".
if type_.length:
return self._extend_string(
- type_, {'national': True},
- "VARCHAR(%(length)s)" % {'length': type_.length})
+ type_,
+ {"national": True},
+ "VARCHAR(%(length)s)" % {"length": type_.length},
+ )
else:
raise exc.CompileError(
- "NVARCHAR requires a length on dialect %s" %
- self.dialect.name)
+ "NVARCHAR requires a length on dialect %s" % self.dialect.name
+ )
def visit_NCHAR(self, type_, **kw):
# We'll actually generate the equiv.
# "NATIONAL CHAR" instead of "NCHAR".
if type_.length:
return self._extend_string(
- type_, {'national': True},
- "CHAR(%(length)s)" % {'length': type_.length})
+ type_,
+ {"national": True},
+ "CHAR(%(length)s)" % {"length": type_.length},
+ )
else:
- return self._extend_string(type_, {'national': True}, "CHAR")
+ return self._extend_string(type_, {"national": True}, "CHAR")
def visit_VARBINARY(self, type_, **kw):
return "VARBINARY(%d)" % type_.length
@@ -1676,17 +2012,19 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
quoted_enums = []
for e in enumerated_values:
quoted_enums.append("'%s'" % e.replace("'", "''"))
- return self._extend_string(type_, {}, "%s(%s)" % (
- name, ",".join(quoted_enums))
+ return self._extend_string(
+ type_, {}, "%s(%s)" % (name, ",".join(quoted_enums))
)
def visit_ENUM(self, type_, **kw):
- return self._visit_enumerated_values("ENUM", type_,
- type_._enumerated_values)
+ return self._visit_enumerated_values(
+ "ENUM", type_, type_._enumerated_values
+ )
def visit_SET(self, type_, **kw):
- return self._visit_enumerated_values("SET", type_,
- type_._enumerated_values)
+ return self._visit_enumerated_values(
+ "SET", type_, type_._enumerated_values
+ )
def visit_BOOLEAN(self, type, **kw):
return "BOOL"
@@ -1703,9 +2041,8 @@ class MySQLIdentifierPreparer(compiler.IdentifierPreparer):
quote = '"'
super(MySQLIdentifierPreparer, self).__init__(
- dialect,
- initial_quote=quote,
- escape_quote=quote)
+ dialect, initial_quote=quote, escape_quote=quote
+ )
def _quote_free_identifiers(self, *ids):
"""Unilaterally identifier-quote any number of strings."""
@@ -1719,7 +2056,7 @@ class MySQLDialect(default.DefaultDialect):
Not used directly in application code.
"""
- name = 'mysql'
+ name = "mysql"
supports_alter = True
# MySQL has no true "boolean" type; we
@@ -1738,7 +2075,7 @@ class MySQLDialect(default.DefaultDialect):
supports_comments = True
inline_comments = True
- default_paramstyle = 'format'
+ default_paramstyle = "format"
colspecs = colspecs
cte_follows_insert = True
@@ -1756,26 +2093,28 @@ class MySQLDialect(default.DefaultDialect):
_server_ansiquotes = False
construct_arguments = [
- (sa_schema.Table, {
- "*": None
- }),
- (sql.Update, {
- "limit": None
- }),
- (sa_schema.PrimaryKeyConstraint, {
- "using": None
- }),
- (sa_schema.Index, {
- "using": None,
- "length": None,
- "prefix": None,
- "with_parser": None
- })
+ (sa_schema.Table, {"*": None}),
+ (sql.Update, {"limit": None}),
+ (sa_schema.PrimaryKeyConstraint, {"using": None}),
+ (
+ sa_schema.Index,
+ {
+ "using": None,
+ "length": None,
+ "prefix": None,
+ "with_parser": None,
+ },
+ ),
]
- def __init__(self, isolation_level=None, json_serializer=None,
- json_deserializer=None, **kwargs):
- kwargs.pop('use_ansiquotes', None) # legacy
+ def __init__(
+ self,
+ isolation_level=None,
+ json_serializer=None,
+ json_deserializer=None,
+ **kwargs
+ ):
+ kwargs.pop("use_ansiquotes", None) # legacy
default.DefaultDialect.__init__(self, **kwargs)
self.isolation_level = isolation_level
self._json_serializer = json_serializer
@@ -1783,22 +2122,30 @@ class MySQLDialect(default.DefaultDialect):
def on_connect(self):
if self.isolation_level is not None:
+
def connect(conn):
self.set_isolation_level(conn, self.isolation_level)
+
return connect
else:
return None
- _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED',
- 'READ COMMITTED', 'REPEATABLE READ'])
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ ]
+ )
def set_isolation_level(self, connection, level):
- level = level.replace('_', ' ')
+ level = level.replace("_", " ")
# adjust for ConnectionFairy being present
# allows attribute set e.g. "connection.autocommit = True"
# to work properly
- if hasattr(connection, 'connection'):
+ if hasattr(connection, "connection"):
connection = connection.connection
self._set_isolation_level(connection, level)
@@ -1807,8 +2154,8 @@ class MySQLDialect(default.DefaultDialect):
if level not in self._isolation_lookup:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
- "Valid isolation levels for %s are %s" %
- (level, self.name, ", ".join(self._isolation_lookup))
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
)
cursor = connection.cursor()
cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level)
@@ -1818,9 +2165,9 @@ class MySQLDialect(default.DefaultDialect):
def get_isolation_level(self, connection):
cursor = connection.cursor()
if self._is_mysql and self.server_version_info >= (5, 7, 20):
- cursor.execute('SELECT @@transaction_isolation')
+ cursor.execute("SELECT @@transaction_isolation")
else:
- cursor.execute('SELECT @@tx_isolation')
+ cursor.execute("SELECT @@tx_isolation")
val = cursor.fetchone()[0]
cursor.close()
if util.py3k and isinstance(val, bytes):
@@ -1840,7 +2187,7 @@ class MySQLDialect(default.DefaultDialect):
val = val.decode()
version = []
- r = re.compile(r'[.\-]')
+ r = re.compile(r"[.\-]")
for n in r.split(val):
try:
version.append(int(n))
@@ -1885,29 +2232,38 @@ class MySQLDialect(default.DefaultDialect):
connection.execute(sql.text("XA END :xid"), xid=xid)
connection.execute(sql.text("XA PREPARE :xid"), xid=xid)
- def do_rollback_twophase(self, connection, xid, is_prepared=True,
- recover=False):
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
if not is_prepared:
connection.execute(sql.text("XA END :xid"), xid=xid)
connection.execute(sql.text("XA ROLLBACK :xid"), xid=xid)
- def do_commit_twophase(self, connection, xid, is_prepared=True,
- recover=False):
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
connection.execute(sql.text("XA COMMIT :xid"), xid=xid)
def do_recover_twophase(self, connection):
resultset = connection.execute("XA RECOVER")
- return [row['data'][0:row['gtrid_length']] for row in resultset]
+ return [row["data"][0 : row["gtrid_length"]] for row in resultset]
def is_disconnect(self, e, connection, cursor):
- if isinstance(e, (self.dbapi.OperationalError,
- self.dbapi.ProgrammingError)):
- return self._extract_error_code(e) in \
- (2006, 2013, 2014, 2045, 2055)
+ if isinstance(
+ e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)
+ ):
+ return self._extract_error_code(e) in (
+ 2006,
+ 2013,
+ 2014,
+ 2045,
+ 2055,
+ )
elif isinstance(
- e, (self.dbapi.InterfaceError, self.dbapi.InternalError)):
+ e, (self.dbapi.InterfaceError, self.dbapi.InternalError)
+ ):
# if underlying connection is closed,
# this is the error you get
return "(0, '')" in str(e)
@@ -1944,7 +2300,7 @@ class MySQLDialect(default.DefaultDialect):
raise NotImplementedError()
def _get_default_schema_name(self, connection):
- return connection.execute('SELECT DATABASE()').scalar()
+ return connection.execute("SELECT DATABASE()").scalar()
def has_table(self, connection, table_name, schema=None):
# SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly
@@ -1957,15 +2313,19 @@ class MySQLDialect(default.DefaultDialect):
# full_name = self.identifier_preparer.format_table(table,
# use_schema=True)
- full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
- schema, table_name))
+ full_name = ".".join(
+ self.identifier_preparer._quote_free_identifiers(
+ schema, table_name
+ )
+ )
st = "DESCRIBE %s" % full_name
rs = None
try:
try:
rs = connection.execution_options(
- skip_user_error_events=True).execute(st)
+ skip_user_error_events=True
+ ).execute(st)
have = rs.fetchone() is not None
rs.close()
return have
@@ -1986,12 +2346,13 @@ class MySQLDialect(default.DefaultDialect):
# if ansiquotes == True, build a new IdentifierPreparer
# with the new setting
self.identifier_preparer = self.preparer(
- self, server_ansiquotes=self._server_ansiquotes)
+ self, server_ansiquotes=self._server_ansiquotes
+ )
default.DefaultDialect.initialize(self, connection)
self._needs_correct_for_88718 = (
- not self._is_mariadb and self.server_version_info >= (8, )
+ not self._is_mariadb and self.server_version_info >= (8,)
)
self._warn_for_known_db_issues()
@@ -2007,20 +2368,23 @@ class MySQLDialect(default.DefaultDialect):
"additional issue prevents proper migrations of columns "
"with CHECK constraints (MDEV-11114). Please upgrade to "
"MariaDB 10.2.9 or greater, or use the MariaDB 10.1 "
- "series, to avoid these issues." % (mdb_version, ))
+ "series, to avoid these issues." % (mdb_version,)
+ )
@property
def _is_mariadb(self):
- return 'MariaDB' in self.server_version_info
+ return "MariaDB" in self.server_version_info
@property
def _is_mysql(self):
- return 'MariaDB' not in self.server_version_info
+ return "MariaDB" not in self.server_version_info
@property
def _is_mariadb_102(self):
- return self._is_mariadb and \
- self._mariadb_normalized_version_info > (10, 2)
+ return self._is_mariadb and self._mariadb_normalized_version_info > (
+ 10,
+ 2,
+ )
@property
def _mariadb_normalized_version_info(self):
@@ -2028,15 +2392,17 @@ class MySQLDialect(default.DefaultDialect):
# the string "5.5"; now that we use @@version we no longer see this.
if self._is_mariadb:
- idx = self.server_version_info.index('MariaDB')
- return self.server_version_info[idx - 3: idx]
+ idx = self.server_version_info.index("MariaDB")
+ return self.server_version_info[idx - 3 : idx]
else:
return self.server_version_info
@property
def _supports_cast(self):
- return self.server_version_info is None or \
- self.server_version_info >= (4, 0, 2)
+ return (
+ self.server_version_info is None
+ or self.server_version_info >= (4, 0, 2)
+ )
@reflection.cache
def get_schema_names(self, connection, **kw):
@@ -2054,18 +2420,23 @@ class MySQLDialect(default.DefaultDialect):
charset = self._connection_charset
if self.server_version_info < (5, 0, 2):
rp = connection.execute(
- "SHOW TABLES FROM %s" %
- self.identifier_preparer.quote_identifier(current_schema))
- return [row[0] for
- row in self._compat_fetchall(rp, charset=charset)]
+ "SHOW TABLES FROM %s"
+ % self.identifier_preparer.quote_identifier(current_schema)
+ )
+ return [
+ row[0] for row in self._compat_fetchall(rp, charset=charset)
+ ]
else:
rp = connection.execute(
- "SHOW FULL TABLES FROM %s" %
- self.identifier_preparer.quote_identifier(current_schema))
+ "SHOW FULL TABLES FROM %s"
+ % self.identifier_preparer.quote_identifier(current_schema)
+ )
- return [row[0]
- for row in self._compat_fetchall(rp, charset=charset)
- if row[1] == 'BASE TABLE']
+ return [
+ row[0]
+ for row in self._compat_fetchall(rp, charset=charset)
+ if row[1] == "BASE TABLE"
+ ]
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
@@ -2077,72 +2448,77 @@ class MySQLDialect(default.DefaultDialect):
return self.get_table_names(connection, schema)
charset = self._connection_charset
rp = connection.execute(
- "SHOW FULL TABLES FROM %s" %
- self.identifier_preparer.quote_identifier(schema))
- return [row[0]
- for row in self._compat_fetchall(rp, charset=charset)
- if row[1] in ('VIEW', 'SYSTEM VIEW')]
+ "SHOW FULL TABLES FROM %s"
+ % self.identifier_preparer.quote_identifier(schema)
+ )
+ return [
+ row[0]
+ for row in self._compat_fetchall(rp, charset=charset)
+ if row[1] in ("VIEW", "SYSTEM VIEW")
+ ]
@reflection.cache
def get_table_options(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
return parsed_state.table_options
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
return parsed_state.columns
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
for key in parsed_state.keys:
- if key['type'] == 'PRIMARY':
+ if key["type"] == "PRIMARY":
# There can be only one.
- cols = [s[0] for s in key['columns']]
- return {'constrained_columns': cols, 'name': None}
- return {'constrained_columns': [], 'name': None}
+ cols = [s[0] for s in key["columns"]]
+ return {"constrained_columns": cols, "name": None}
+ return {"constrained_columns": [], "name": None}
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
default_schema = None
fkeys = []
for spec in parsed_state.fk_constraints:
- ref_name = spec['table'][-1]
- ref_schema = len(spec['table']) > 1 and \
- spec['table'][-2] or schema
+ ref_name = spec["table"][-1]
+ ref_schema = len(spec["table"]) > 1 and spec["table"][-2] or schema
if not ref_schema:
if default_schema is None:
- default_schema = \
- connection.dialect.default_schema_name
+ default_schema = connection.dialect.default_schema_name
if schema == default_schema:
ref_schema = schema
- loc_names = spec['local']
- ref_names = spec['foreign']
+ loc_names = spec["local"]
+ ref_names = spec["foreign"]
con_kw = {}
- for opt in ('onupdate', 'ondelete'):
+ for opt in ("onupdate", "ondelete"):
if spec.get(opt, False):
con_kw[opt] = spec[opt]
fkey_d = {
- 'name': spec['name'],
- 'constrained_columns': loc_names,
- 'referred_schema': ref_schema,
- 'referred_table': ref_name,
- 'referred_columns': ref_names,
- 'options': con_kw
+ "name": spec["name"],
+ "constrained_columns": loc_names,
+ "referred_schema": ref_schema,
+ "referred_table": ref_name,
+ "referred_columns": ref_names,
+ "options": con_kw,
}
fkeys.append(fkey_d)
@@ -2172,25 +2548,26 @@ class MySQLDialect(default.DefaultDialect):
default_schema_name = connection.dialect.default_schema_name
col_tuples = [
(
- lower(rec['referred_schema'] or default_schema_name),
- lower(rec['referred_table']),
- col_name
+ lower(rec["referred_schema"] or default_schema_name),
+ lower(rec["referred_table"]),
+ col_name,
)
for rec in fkeys
- for col_name in rec['referred_columns']
+ for col_name in rec["referred_columns"]
]
if col_tuples:
correct_for_wrong_fk_case = connection.execute(
- sql.text("""
+ sql.text(
+ """
select table_schema, table_name, column_name
from information_schema.columns
where (table_schema, table_name, lower(column_name)) in
:table_data;
- """).bindparams(
- sql.bindparam("table_data", expanding=True)
- ), table_data=col_tuples
+ """
+ ).bindparams(sql.bindparam("table_data", expanding=True)),
+ table_data=col_tuples,
)
# in casing=0, table name and schema name come back in their
@@ -2208,109 +2585,117 @@ class MySQLDialect(default.DefaultDialect):
d[(lower(schema), lower(tname))][cname.lower()] = cname
for fkey in fkeys:
- fkey['referred_columns'] = [
+ fkey["referred_columns"] = [
d[
(
lower(
- fkey['referred_schema'] or
- default_schema_name),
- lower(fkey['referred_table'])
+ fkey["referred_schema"] or default_schema_name
+ ),
+ lower(fkey["referred_table"]),
)
][col.lower()]
- for col in fkey['referred_columns']
+ for col in fkey["referred_columns"]
]
@reflection.cache
- def get_check_constraints(
- self, connection, table_name, schema=None, **kw):
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
return [
- {"name": spec['name'], "sqltext": spec['sqltext']}
+ {"name": spec["name"], "sqltext": spec["sqltext"]}
for spec in parsed_state.ck_constraints
]
@reflection.cache
def get_table_comment(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
- return {"text": parsed_state.table_options.get('mysql_comment', None)}
+ connection, table_name, schema, **kw
+ )
+ return {"text": parsed_state.table_options.get("mysql_comment", None)}
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
indexes = []
for spec in parsed_state.keys:
dialect_options = {}
unique = False
- flavor = spec['type']
- if flavor == 'PRIMARY':
+ flavor = spec["type"]
+ if flavor == "PRIMARY":
continue
- if flavor == 'UNIQUE':
+ if flavor == "UNIQUE":
unique = True
- elif flavor in ('FULLTEXT', 'SPATIAL'):
+ elif flavor in ("FULLTEXT", "SPATIAL"):
dialect_options["mysql_prefix"] = flavor
elif flavor is None:
pass
else:
self.logger.info(
- "Converting unknown KEY type %s to a plain KEY", flavor)
+ "Converting unknown KEY type %s to a plain KEY", flavor
+ )
pass
- if spec['parser']:
- dialect_options['mysql_with_parser'] = spec['parser']
+ if spec["parser"]:
+ dialect_options["mysql_with_parser"] = spec["parser"]
index_d = {}
if dialect_options:
index_d["dialect_options"] = dialect_options
- index_d['name'] = spec['name']
- index_d['column_names'] = [s[0] for s in spec['columns']]
- index_d['unique'] = unique
+ index_d["name"] = spec["name"]
+ index_d["column_names"] = [s[0] for s in spec["columns"]]
+ index_d["unique"] = unique
if flavor:
- index_d['type'] = flavor
+ index_d["type"] = flavor
indexes.append(index_d)
return indexes
@reflection.cache
- def get_unique_constraints(self, connection, table_name,
- schema=None, **kw):
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
return [
{
- 'name': key['name'],
- 'column_names': [col[0] for col in key['columns']],
- 'duplicates_index': key['name'],
+ "name": key["name"],
+ "column_names": [col[0] for col in key["columns"]],
+ "duplicates_index": key["name"],
}
for key in parsed_state.keys
- if key['type'] == 'UNIQUE'
+ if key["type"] == "UNIQUE"
]
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
charset = self._connection_charset
- full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
- schema, view_name))
- sql = self._show_create_table(connection, None, charset,
- full_name=full_name)
+ full_name = ".".join(
+ self.identifier_preparer._quote_free_identifiers(schema, view_name)
+ )
+ sql = self._show_create_table(
+ connection, None, charset, full_name=full_name
+ )
return sql
- def _parsed_state_or_create(self, connection, table_name,
- schema=None, **kw):
+ def _parsed_state_or_create(
+ self, connection, table_name, schema=None, **kw
+ ):
return self._setup_parser(
connection,
table_name,
schema,
- info_cache=kw.get('info_cache', None)
+ info_cache=kw.get("info_cache", None),
)
@util.memoized_property
@@ -2321,7 +2706,7 @@ class MySQLDialect(default.DefaultDialect):
retrieved server version information first.
"""
- if (self.server_version_info < (4, 1) and self._server_ansiquotes):
+ if self.server_version_info < (4, 1) and self._server_ansiquotes:
# ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1
preparer = self.preparer(self, server_ansiquotes=False)
else:
@@ -2332,14 +2717,19 @@ class MySQLDialect(default.DefaultDialect):
def _setup_parser(self, connection, table_name, schema=None, **kw):
charset = self._connection_charset
parser = self._tabledef_parser
- full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
- schema, table_name))
- sql = self._show_create_table(connection, None, charset,
- full_name=full_name)
- if re.match(r'^CREATE (?:ALGORITHM)?.* VIEW', sql):
+ full_name = ".".join(
+ self.identifier_preparer._quote_free_identifiers(
+ schema, table_name
+ )
+ )
+ sql = self._show_create_table(
+ connection, None, charset, full_name=full_name
+ )
+ if re.match(r"^CREATE (?:ALGORITHM)?.* VIEW", sql):
# Adapt views to something table-like.
- columns = self._describe_table(connection, None, charset,
- full_name=full_name)
+ columns = self._describe_table(
+ connection, None, charset, full_name=full_name
+ )
sql = parser._describe_to_create(table_name, columns)
return parser.parse(sql, charset)
@@ -2356,17 +2746,18 @@ class MySQLDialect(default.DefaultDialect):
# http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
charset = self._connection_charset
- row = self._compat_first(connection.execute(
- "SHOW VARIABLES LIKE 'lower_case_table_names'"),
- charset=charset)
+ row = self._compat_first(
+ connection.execute("SHOW VARIABLES LIKE 'lower_case_table_names'"),
+ charset=charset,
+ )
if not row:
cs = 0
else:
# 4.0.15 returns OFF or ON according to [ticket:489]
# 3.23 doesn't, 4.0.27 doesn't..
- if row[1] == 'OFF':
+ if row[1] == "OFF":
cs = 0
- elif row[1] == 'ON':
+ elif row[1] == "ON":
cs = 1
else:
cs = int(row[1])
@@ -2384,7 +2775,7 @@ class MySQLDialect(default.DefaultDialect):
pass
else:
charset = self._connection_charset
- rs = connection.execute('SHOW COLLATION')
+ rs = connection.execute("SHOW COLLATION")
for row in self._compat_fetchall(rs, charset):
collations[row[0]] = row[1]
return collations
@@ -2392,33 +2783,36 @@ class MySQLDialect(default.DefaultDialect):
def _detect_sql_mode(self, connection):
row = self._compat_first(
connection.execute("SHOW VARIABLES LIKE 'sql_mode'"),
- charset=self._connection_charset)
+ charset=self._connection_charset,
+ )
if not row:
util.warn(
"Could not retrieve SQL_MODE; please ensure the "
- "MySQL user has permissions to SHOW VARIABLES")
- self._sql_mode = ''
+ "MySQL user has permissions to SHOW VARIABLES"
+ )
+ self._sql_mode = ""
else:
- self._sql_mode = row[1] or ''
+ self._sql_mode = row[1] or ""
def _detect_ansiquotes(self, connection):
"""Detect and adjust for the ANSI_QUOTES sql mode."""
mode = self._sql_mode
if not mode:
- mode = ''
+ mode = ""
elif mode.isdigit():
mode_no = int(mode)
- mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or ''
+ mode = (mode_no | 4 == mode_no) and "ANSI_QUOTES" or ""
- self._server_ansiquotes = 'ANSI_QUOTES' in mode
+ self._server_ansiquotes = "ANSI_QUOTES" in mode
# as of MySQL 5.0.1
- self._backslash_escapes = 'NO_BACKSLASH_ESCAPES' not in mode
+ self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode
- def _show_create_table(self, connection, table, charset=None,
- full_name=None):
+ def _show_create_table(
+ self, connection, table, charset=None, full_name=None
+ ):
"""Run SHOW CREATE TABLE for a ``Table``."""
if full_name is None:
@@ -2428,7 +2822,8 @@ class MySQLDialect(default.DefaultDialect):
rp = None
try:
rp = connection.execution_options(
- skip_user_error_events=True).execute(st)
+ skip_user_error_events=True
+ ).execute(st)
except exc.DBAPIError as e:
if self._extract_error_code(e.orig) == 1146:
raise exc.NoSuchTableError(full_name)
@@ -2441,8 +2836,7 @@ class MySQLDialect(default.DefaultDialect):
return sql
- def _describe_table(self, connection, table, charset=None,
- full_name=None):
+ def _describe_table(self, connection, table, charset=None, full_name=None):
"""Run DESCRIBE for a ``Table`` and return processed rows."""
if full_name is None:
@@ -2453,7 +2847,8 @@ class MySQLDialect(default.DefaultDialect):
try:
try:
rp = connection.execution_options(
- skip_user_error_events=True).execute(st)
+ skip_user_error_events=True
+ ).execute(st)
except exc.DBAPIError as e:
code = self._extract_error_code(e.orig)
if code == 1146:
@@ -2486,11 +2881,11 @@ class _DecodingRowProxy(object):
# seem to come up in DDL queries.
_encoding_compat = {
- 'koi8r': 'koi8_r',
- 'koi8u': 'koi8_u',
- 'utf16': 'utf-16-be', # MySQL's uft16 is always bigendian
- 'utf8mb4': 'utf8', # real utf8
- 'eucjpms': 'ujis',
+ "koi8r": "koi8_r",
+ "koi8u": "koi8_u",
+        "utf16": "utf-16-be",  # MySQL's utf16 is always big-endian
+ "utf8mb4": "utf8", # real utf8
+ "eucjpms": "ujis",
}
def __init__(self, rowproxy, charset):