diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/mysql/mysqldb.py')
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/mysqldb.py | 119 |
1 files changed, 68 insertions, 51 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index edac816fe..6d42f5c04 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -45,8 +45,12 @@ The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`. """ -from .base import (MySQLDialect, MySQLExecutionContext, - MySQLCompiler, MySQLIdentifierPreparer) +from .base import ( + MySQLDialect, + MySQLExecutionContext, + MySQLCompiler, + MySQLIdentifierPreparer, +) from .base import TEXT from ... import sql from ... import util @@ -54,10 +58,9 @@ import re class MySQLExecutionContext_mysqldb(MySQLExecutionContext): - @property def rowcount(self): - if hasattr(self, '_rowcount'): + if hasattr(self, "_rowcount"): return self._rowcount else: return self.cursor.rowcount @@ -72,14 +75,14 @@ class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer): class MySQLDialect_mysqldb(MySQLDialect): - driver = 'mysqldb' + driver = "mysqldb" supports_unicode_statements = True supports_sane_rowcount = True supports_sane_multi_rowcount = True supports_native_decimal = True - default_paramstyle = 'format' + default_paramstyle = "format" execution_ctx_cls = MySQLExecutionContext_mysqldb statement_compiler = MySQLCompiler_mysqldb preparer = MySQLIdentifierPreparer_mysqldb @@ -87,24 +90,23 @@ class MySQLDialect_mysqldb(MySQLDialect): def __init__(self, server_side_cursors=False, **kwargs): super(MySQLDialect_mysqldb, self).__init__(**kwargs) self.server_side_cursors = server_side_cursors - self._mysql_dbapi_version = self._parse_dbapi_version( - self.dbapi.__version__) if self.dbapi is not None \ - and hasattr(self.dbapi, '__version__') else (0, 0, 0) + self._mysql_dbapi_version = ( + self._parse_dbapi_version(self.dbapi.__version__) + if self.dbapi is not None and hasattr(self.dbapi, "__version__") + else (0, 0, 0) + ) def _parse_dbapi_version(self, version): - m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version) + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version) if m: - return tuple( - int(x) - for x in m.group(1, 2, 3) - if x is not None) + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) else: return (0, 0, 0) @util.langhelpers.memoized_property def supports_server_side_cursors(self): try: - cursors = __import__('MySQLdb.cursors').cursors + cursors = __import__("MySQLdb.cursors").cursors self._sscursor = cursors.SSCursor return True except (ImportError, AttributeError): @@ -112,7 +114,7 @@ class MySQLDialect_mysqldb(MySQLDialect): @classmethod def dbapi(cls): - return __import__('MySQLdb') + return __import__("MySQLdb") def do_ping(self, dbapi_connection): try: @@ -135,67 +137,74 @@ class MySQLDialect_mysqldb(MySQLDialect): # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8 # specific issue w/ the utf8mb4_bin collation and unicode returns - has_utf8mb4_bin = self.server_version_info > (5, ) and \ - connection.scalar( - "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'" - % ( - self.identifier_preparer.quote("Charset"), - self.identifier_preparer.quote("Collation") - )) + has_utf8mb4_bin = self.server_version_info > ( + 5, + ) and connection.scalar( + "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'" + % ( + self.identifier_preparer.quote("Charset"), + self.identifier_preparer.quote("Collation"), + ) + ) if has_utf8mb4_bin: additional_tests = [ - sql.collate(sql.cast( - sql.literal_column( - "'test collated returns'"), - TEXT(charset='utf8mb4')), "utf8mb4_bin") + sql.collate( + sql.cast( + sql.literal_column("'test collated returns'"), + TEXT(charset="utf8mb4"), + ), + "utf8mb4_bin", + ) ] else: additional_tests = [] return super(MySQLDialect_mysqldb, self)._check_unicode_returns( - connection, additional_tests) + connection, additional_tests + ) def create_connect_args(self, url): - opts = url.translate_connect_args(database='db', username='user', - password='passwd') + opts = url.translate_connect_args( + database="db", username="user", password="passwd" + ) opts.update(url.query) - util.coerce_kw_type(opts, 'compress', bool) - util.coerce_kw_type(opts, 'connect_timeout', int) - util.coerce_kw_type(opts, 'read_timeout', int) - util.coerce_kw_type(opts, 'write_timeout', int) - util.coerce_kw_type(opts, 'client_flag', int) - util.coerce_kw_type(opts, 'local_infile', int) + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "connect_timeout", int) + util.coerce_kw_type(opts, "read_timeout", int) + util.coerce_kw_type(opts, "write_timeout", int) + util.coerce_kw_type(opts, "client_flag", int) + util.coerce_kw_type(opts, "local_infile", int) # Note: using either of the below will cause all strings to be # returned as Unicode, both in raw SQL operations and with column # types like String and MSString. - util.coerce_kw_type(opts, 'use_unicode', bool) - util.coerce_kw_type(opts, 'charset', str) + util.coerce_kw_type(opts, "use_unicode", bool) + util.coerce_kw_type(opts, "charset", str) # Rich values 'cursorclass' and 'conv' are not supported via # query string. ssl = {} - keys = ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher'] + keys = ["ssl_ca", "ssl_key", "ssl_cert", "ssl_capath", "ssl_cipher"] for key in keys: if key in opts: ssl[key[4:]] = opts[key] util.coerce_kw_type(ssl, key[4:], str) del opts[key] if ssl: - opts['ssl'] = ssl + opts["ssl"] = ssl # FOUND_ROWS must be set in CLIENT_FLAGS to enable # supports_sane_rowcount. - client_flag = opts.get('client_flag', 0) + client_flag = opts.get("client_flag", 0) if self.dbapi is not None: try: CLIENT_FLAGS = __import__( - self.dbapi.__name__ + '.constants.CLIENT' + self.dbapi.__name__ + ".constants.CLIENT" ).constants.CLIENT client_flag |= CLIENT_FLAGS.FOUND_ROWS except (AttributeError, ImportError): self.supports_sane_rowcount = False - opts['client_flag'] = client_flag + opts["client_flag"] = client_flag return [[], opts] def _extract_error_code(self, exception): @@ -213,22 +222,30 @@ class MySQLDialect_mysqldb(MySQLDialect): "No 'character_set_name' can be detected with " "this MySQL-Python version; " "please upgrade to a recent version of MySQL-Python. " - "Assuming latin1.") - return 'latin1' + "Assuming latin1." + ) + return "latin1" else: return cset_name() - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ', - 'AUTOCOMMIT']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + ] + ) def _set_isolation_level(self, connection, level): - if level == 'AUTOCOMMIT': + if level == "AUTOCOMMIT": connection.autocommit(True) else: connection.autocommit(False) - super(MySQLDialect_mysqldb, self)._set_isolation_level(connection, - level) + super(MySQLDialect_mysqldb, self)._set_isolation_level( + connection, level + ) dialect = MySQLDialect_mysqldb |