summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/mysql/asyncmy.py
diff options
context:
space:
mode:
authorlong2ice <long2ice@gmail.com>2021-09-16 11:08:25 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2021-09-17 11:20:19 -0400
commit11eecfacb7b36c209c1ad726f5e5b7525860977b (patch)
tree6bc6e32e3defe9099e217fcb2ba5205339dca200 /lib/sqlalchemy/dialects/mysql/asyncmy.py
parentdb847ca4e52de0e70d4993d1b7ac4de1c947b864 (diff)
downloadsqlalchemy-11eecfacb7b36c209c1ad726f5e5b7525860977b.tar.gz
Add `asyncmy` support
Added initial support for the ``asyncmy`` asyncio database driver for MySQL and MariaDB. This driver is very new, however appears to be the only current alternative to the ``aiomysql`` driver which currently appears to be unmaintained and is not working with current Python versions. Much thanks to long2ice for the pull request for this dialect. Fixes: #6993 Closes: #7000 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7000 Pull-request-sha: f7d6c811fc72324a83c8af635bbca8b268b0098e Change-Id: I4ef54b43334feff7e3a710fc4de6821437f3bb68
Diffstat (limited to 'lib/sqlalchemy/dialects/mysql/asyncmy.py')
-rw-r--r--lib/sqlalchemy/dialects/mysql/asyncmy.py340
1 files changed, 340 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py
new file mode 100644
index 000000000..f312cf79b
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py
@@ -0,0 +1,340 @@
+# mysql/asyncmy.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors <see AUTHORS
+# file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: mysql+asyncmy
+ :name: asyncmy
+ :dbapi: asyncmy
+ :connectstring: mysql+asyncmy://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://github.com/long2ice/asyncmy
+
+.. note:: The asyncmy dialect as of September, 2021 was added to provide
+ MySQL/MariaDB asyncio compatibility given that the :ref:`aiomysql` database
+ driver has become unmaintained, however asyncmy is itself very new.
+
+Using a special asyncio mediation layer, the asyncmy dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4")
+
+
+""" # noqa
+
+import contextlib
+
+from .pymysql import MySQLDialect_pymysql
+from ... import pool
+from ... import util
+from ...util.concurrency import asyncio
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+class AsyncAdapt_asyncmy_cursor:
+ server_side = False
+ __slots__ = (
+ "_adapt_connection",
+ "_connection",
+ "await_",
+ "_cursor",
+ "_rows",
+ )
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ cursor = self._connection.cursor()
+
+ self._cursor = self.await_(cursor.__aenter__())
+ self._rows = []
+
+ @property
+ def description(self):
+ return self._cursor.description
+
+ @property
+ def rowcount(self):
+ return self._cursor.rowcount
+
+ @property
+ def arraysize(self):
+ return self._cursor.arraysize
+
+ @arraysize.setter
+ def arraysize(self, value):
+ self._cursor.arraysize = value
+
+ @property
+ def lastrowid(self):
+ return self._cursor.lastrowid
+
+ def close(self):
+ # note we aren't actually closing the cursor here,
+ # we are just letting GC do it. to allow this to be async
+ # we would need the Result to change how it does "Safe close cursor".
+ # MySQL "cursors" don't actually have state to be "closed" besides
+ # exhausting rows, which we already have done for sync cursor.
+ # another option would be to emulate aiosqlite dialect and assign
+ # cursor only if we are doing server side cursor operation.
+ self._rows[:] = []
+
+ def execute(self, operation, parameters=None):
+ return self.await_(self._execute_async(operation, parameters))
+
+ def executemany(self, operation, seq_of_parameters):
+ return self.await_(
+ self._executemany_async(operation, seq_of_parameters)
+ )
+
+ async def _execute_async(self, operation, parameters):
+ async with self._adapt_connection._mutex_and_adapt_errors():
+ if parameters is None:
+ result = await self._cursor.execute(operation)
+ else:
+ result = await self._cursor.execute(operation, parameters)
+
+ if not self.server_side:
+ # asyncmy has a "fake" async result, so we have to pull it out
+ # of that here since our default result is not async.
+ # we could just as easily grab "_rows" here and be done with it
+ # but this is safer.
+ self._rows = list(await self._cursor.fetchall())
+ return result
+
+ async def _executemany_async(self, operation, seq_of_parameters):
+ async with self._adapt_connection._mutex_and_adapt_errors():
+ return await self._cursor.executemany(operation, seq_of_parameters)
+
+ def setinputsizes(self, *inputsizes):
+ pass
+
+ def __iter__(self):
+ while self._rows:
+ yield self._rows.pop(0)
+
+ def fetchone(self):
+ if self._rows:
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+
+ retval = self._rows[0:size]
+ self._rows[:] = self._rows[size:]
+ return retval
+
+ def fetchall(self):
+ retval = self._rows[:]
+ self._rows[:] = []
+ return retval
+
+
+class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
+ __slots__ = ()
+ server_side = True
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ adapt_connection._ss_cursors.add(self)
+
+ cursor = self._connection.cursor(
+ adapt_connection.dbapi.asyncmy.cursors.SSCursor
+ )
+
+ self._cursor = self.await_(cursor.__aenter__())
+
+ def close(self):
+ try:
+ if self._cursor is not None:
+ self.await_(self._cursor.fetchall())
+ self.await_(self._cursor.close())
+ self._cursor = None
+ finally:
+ self._adapt_connection._ss_cursors.discard(self)
+
+ def fetchone(self):
+ return self.await_(self._cursor.fetchone())
+
+ def fetchmany(self, size=None):
+ return self.await_(self._cursor.fetchmany(size=size))
+
+ def fetchall(self):
+ return self.await_(self._cursor.fetchall())
+
+
+class AsyncAdapt_asyncmy_connection:
+ await_ = staticmethod(await_only)
+ __slots__ = ("dbapi", "_connection", "_execute_mutex", "_ss_cursors")
+
+ def __init__(self, dbapi, connection):
+ self.dbapi = dbapi
+ self._connection = connection
+ self._execute_mutex = asyncio.Lock()
+ self._ss_cursors = set()
+
+ @contextlib.asynccontextmanager
+ async def _mutex_and_adapt_errors(self):
+ async with self._execute_mutex:
+ try:
+ yield
+ except AttributeError:
+ raise self.dbapi.InternalError(
+ "network operation failed due to asyncmy attribute error"
+ )
+
+ def ping(self, reconnect):
+ assert not reconnect
+ return self.await_(self._do_ping())
+
+ async def _do_ping(self):
+ async with self._mutex_and_adapt_errors():
+ return await self._connection.ping(False)
+
+ def character_set_name(self):
+ return self._connection.character_set_name()
+
+ def autocommit(self, value):
+ self.await_(self._connection.autocommit(value))
+
+ def cursor(self, server_side=False):
+ if server_side:
+ return AsyncAdapt_asyncmy_ss_cursor(self)
+ else:
+ return AsyncAdapt_asyncmy_cursor(self)
+
+ def _shutdown_ss_cursors(self):
+ for curs in list(self._ss_cursors):
+ curs.close()
+
+ def rollback(self):
+ self._shutdown_ss_cursors()
+ self.await_(self._connection.rollback())
+
+ def commit(self):
+ self._shutdown_ss_cursors()
+ self.await_(self._connection.commit())
+
+ def close(self):
+ self._shutdown_ss_cursors()
+ # it's not awaitable.
+ self._connection.close()
+
+
+class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
+ __slots__ = ()
+
+ await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_asyncmy_dbapi:
+ def __init__(self, asyncmy, pymysql):
+ self.asyncmy = asyncmy
+ self.pymysql = pymysql
+ self.paramstyle = "format"
+ self._init_dbapi_attributes()
+
+ def _init_dbapi_attributes(self):
+ for name in (
+ "Warning",
+ "Error",
+ "InterfaceError",
+ "DataError",
+ "DatabaseError",
+ "OperationalError",
+ "InterfaceError",
+ "IntegrityError",
+ "ProgrammingError",
+ "InternalError",
+ "NotSupportedError",
+ ):
+ setattr(self, name, getattr(self.asyncmy.errors, name))
+
+ for name in (
+ "NUMBER",
+ "STRING",
+ "DATETIME",
+ "BINARY",
+ "TIMESTAMP",
+ "Binary",
+ ):
+ setattr(self, name, getattr(self.pymysql, name))
+
+ def connect(self, *arg, **kw):
+ async_fallback = kw.pop("async_fallback", False)
+
+ if util.asbool(async_fallback):
+ return AsyncAdaptFallback_asyncmy_connection(
+ self,
+ await_fallback(self.asyncmy.connect(*arg, **kw)),
+ )
+ else:
+ return AsyncAdapt_asyncmy_connection(
+ self,
+ await_only(self.asyncmy.connect(*arg, **kw)),
+ )
+
+
+class MySQLDialect_asyncmy(MySQLDialect_pymysql):
+ driver = "asyncmy"
+ supports_statement_cache = True
+
+ supports_server_side_cursors = True
+ _sscursor = AsyncAdapt_asyncmy_ss_cursor
+
+ is_async = True
+
+ @classmethod
+ def dbapi(cls):
+ return AsyncAdapt_asyncmy_dbapi(
+ __import__("asyncmy"), __import__("pymysql")
+ )
+
+ @classmethod
+ def get_pool_class(cls, url):
+
+ async_fallback = url.query.get("async_fallback", False)
+
+ if util.asbool(async_fallback):
+ return pool.FallbackAsyncAdaptedQueuePool
+ else:
+ return pool.AsyncAdaptedQueuePool
+
+ def create_connect_args(self, url):
+ return super(MySQLDialect_asyncmy, self).create_connect_args(
+ url, _translate_args=dict(username="user", database="db")
+ )
+
+ def is_disconnect(self, e, connection, cursor):
+ if super(MySQLDialect_asyncmy, self).is_disconnect(
+ e, connection, cursor
+ ):
+ return True
+ else:
+ str_e = str(e).lower()
+ return (
+ "not connected" in str_e or "network operation failed" in str_e
+ )
+
+ def _found_rows_client_flag(self):
+ from pymysql.constants import CLIENT
+
+ return CLIENT.FOUND_ROWS
+
+
+dialect = MySQLDialect_asyncmy