diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-12-04 16:26:44 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-12-10 17:11:46 -0500 |
commit | 5f333762db4b72c604e44f20f86d171dc249b741 (patch) | |
tree | 9d70447db467bce735a1aaeadae89d290c60dffe /lib/sqlalchemy/dialects | |
parent | c736eef8b35841af89ec19469aa496585efd3865 (diff) | |
download | sqlalchemy-5f333762db4b72c604e44f20f86d171dc249b741.tar.gz |
add aiomysql support
This is a re-gerrit of the original gerrit
merged in Ia8ad3efe3b50ce75a3bed1e020e1b82acb5f2eda
Reverted due to ongoing issues.
Fixes: #5747
Change-Id: I2b57e76b817eed8f89457a2146b523a1cab656a8
Diffstat (limited to 'lib/sqlalchemy/dialects')
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/__init__.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/aiomysql.py | 278 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/mysqldb.py | 17 |
3 files changed, 295 insertions, 4 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py index 9fdc96f6f..c6781c168 100644 --- a/lib/sqlalchemy/dialects/mysql/__init__.py +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -49,6 +49,10 @@ from .base import VARCHAR from .base import YEAR from .dml import Insert from .dml import insert +from ...util import compat + +if compat.py3k: + from . import aiomysql # noqa # default dialect diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py new file mode 100644 index 000000000..f560ece33 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -0,0 +1,278 @@ +# mysql/aiomysql.py +# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors <see AUTHORS +# file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +r""" +.. dialect:: mysql+aiomysql + :name: aiomysql + :dbapi: aiomysql + :connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...] + :url: https://github.com/aio-libs/aiomysql + +The aiomysql dialect is SQLAlchemy's second Python asyncio dialect. + +Using a special asyncio mediation layer, the aiomysql dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname") + +Unicode +------- + +Please see :ref:`mysql_unicode` for current recommendations on unicode +handling. + + +""" # noqa + +from .pymysql import MySQLDialect_pymysql +from ... import pool +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +class AsyncAdapt_aiomysql_cursor: + server_side = False + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor() + + # see https://github.com/aio-libs/aiomysql/issues/543 + self._cursor = self.await_(cursor.__aenter__()) + self._rows = [] + + @property + def description(self): + return self._cursor.description + + @property + def rowcount(self): + return self._cursor.rowcount + + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + @property + def lastrowid(self): + return self._cursor.lastrowid + + def close(self): + self._rows[:] = [] + + def execute(self, operation, parameters=None): + if parameters is None: + result = self.await_(self._cursor.execute(operation)) + else: + result = self.await_(self._cursor.execute(operation, parameters)) + + if not self.server_side: + # aiomysql has a "fake" async result, so we have to pull it out + # of that here since our default result is not async. + # we could just as easily grab "_rows" here and be done with it + # but this is safer. + self._rows = list(self.await_(self._cursor.fetchall())) + return result + + def executemany(self, operation, seq_of_parameters): + return self.await_( + self._cursor.executemany(operation, seq_of_parameters) + ) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor): + + server_side = True + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor( + adapt_connection.dbapi.aiomysql.SSCursor + ) + + self._cursor = self.await_(cursor.__aenter__()) + + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_aiomysql_connection: + await_ = staticmethod(await_only) + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + + def ping(self, reconnect): + return self.await_(self._connection.ping(reconnect)) + + def character_set_name(self): + return self._connection.character_set_name() + + def autocommit(self, value): + self.await_(self._connection.autocommit(value)) + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_aiomysql_ss_cursor(self) + else: + return AsyncAdapt_aiomysql_cursor(self) + + def rollback(self): + self.await_(self._connection.rollback()) + + def commit(self): + self.await_(self._connection.commit()) + + def close(self): + # it's not awaitable. + self._connection.close() + + +class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +class AsyncAdapt_aiomysql_dbapi: + def __init__(self, aiomysql, pymysql): + self.aiomysql = aiomysql + self.pymysql = pymysql + self.paramstyle = "format" + self._init_dbapi_attributes() + + def _init_dbapi_attributes(self): + for name in ( + "Warning", + "Error", + "InterfaceError", + "DataError", + "DatabaseError", + "OperationalError", + "InterfaceError", + "IntegrityError", + "ProgrammingError", + "InternalError", + "NotSupportedError", + ): + setattr(self, name, getattr(self.aiomysql, name)) + + for name in ( + "NUMBER", + "STRING", + "DATETIME", + "BINARY", + "TIMESTAMP", + "Binary", + ): + setattr(self, name, getattr(self.pymysql, name)) + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + + if async_fallback: + return AsyncAdaptFallback_aiomysql_connection( + self, + await_fallback(self.aiomysql.connect(*arg, **kw)), + ) + else: + return AsyncAdapt_aiomysql_connection( + self, + await_only(self.aiomysql.connect(*arg, **kw)), + ) + + +class MySQLDialect_aiomysql(MySQLDialect_pymysql): + driver = "aiomysql" + + supports_server_side_cursors = True + _sscursor = AsyncAdapt_aiomysql_ss_cursor + + @classmethod + def dbapi(cls): + return AsyncAdapt_aiomysql_dbapi( + __import__("aiomysql"), __import__("pymysql") + ) + + @classmethod + def get_pool_class(self, url): + return pool.AsyncAdaptedQueuePool + + def create_connect_args(self, url): + args, kw = super(MySQLDialect_aiomysql, self).create_connect_args(url) + if "passwd" in kw: + kw["password"] = kw.pop("passwd") + return args, kw + + def is_disconnect(self, e, connection, cursor): + if super(MySQLDialect_aiomysql, self).is_disconnect( + e, connection, cursor + ): + return True + else: + str_e = str(e).lower() + return "not connected" in str_e + + def _found_rows_client_flag(self): + from pymysql.constants import CLIENT + + return CLIENT.FOUND_ROWS + + +dialect = MySQLDialect_aiomysql diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index b20e061fb..605407f46 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -211,16 +211,25 @@ class MySQLDialect_mysqldb(MySQLDialect): # FOUND_ROWS must be set in CLIENT_FLAGS to enable # supports_sane_rowcount. client_flag = opts.get("client_flag", 0) + + client_flag_found_rows = self._found_rows_client_flag() + if client_flag_found_rows is not None: + client_flag |= client_flag_found_rows + opts["client_flag"] = client_flag + return [[], opts] + + def _found_rows_client_flag(self): if self.dbapi is not None: try: CLIENT_FLAGS = __import__( self.dbapi.__name__ + ".constants.CLIENT" ).constants.CLIENT - client_flag |= CLIENT_FLAGS.FOUND_ROWS except (AttributeError, ImportError): - self.supports_sane_rowcount = False - opts["client_flag"] = client_flag - return [[], opts] + return None + else: + return CLIENT_FLAGS.FOUND_ROWS + else: + return None def _extract_error_code(self, exception): return exception.args[0] |