summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/mysql/__init__.py4
-rw-r--r--lib/sqlalchemy/dialects/mysql/aiomysql.py278
-rw-r--r--lib/sqlalchemy/dialects/mysql/mysqldb.py17
-rw-r--r--lib/sqlalchemy/ext/asyncio/result.py13
-rw-r--r--lib/sqlalchemy/testing/suite/test_results.py7
-rw-r--r--lib/sqlalchemy/testing/warnings.py5
6 files changed, 314 insertions, 10 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py
index 9fdc96f6f..c6781c168 100644
--- a/lib/sqlalchemy/dialects/mysql/__init__.py
+++ b/lib/sqlalchemy/dialects/mysql/__init__.py
@@ -49,6 +49,10 @@ from .base import VARCHAR
from .base import YEAR
from .dml import Insert
from .dml import insert
+from ...util import compat
+
+if compat.py3k:
+ from . import aiomysql # noqa
# default dialect
diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py
new file mode 100644
index 000000000..f560ece33
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py
@@ -0,0 +1,278 @@
+# mysql/aiomysql.py
+# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors <see AUTHORS
+# file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: mysql+aiomysql
+ :name: aiomysql
+ :dbapi: aiomysql
+ :connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://github.com/aio-libs/aiomysql
+
+The aiomysql dialect is SQLAlchemy's second Python asyncio dialect.
+
+Using a special asyncio mediation layer, the aiomysql dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname")
+
+Unicode
+-------
+
+Please see :ref:`mysql_unicode` for current recommendations on unicode
+handling.
+
+
+""" # noqa
+
+from .pymysql import MySQLDialect_pymysql
+from ... import pool
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+class AsyncAdapt_aiomysql_cursor:
+ server_side = False
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ cursor = self._connection.cursor()
+
+ # see https://github.com/aio-libs/aiomysql/issues/543
+ self._cursor = self.await_(cursor.__aenter__())
+ self._rows = []
+
+ @property
+ def description(self):
+ return self._cursor.description
+
+ @property
+ def rowcount(self):
+ return self._cursor.rowcount
+
+ @property
+ def arraysize(self):
+ return self._cursor.arraysize
+
+ @arraysize.setter
+ def arraysize(self, value):
+ self._cursor.arraysize = value
+
+ @property
+ def lastrowid(self):
+ return self._cursor.lastrowid
+
+ def close(self):
+ self._rows[:] = []
+
+ def execute(self, operation, parameters=None):
+ if parameters is None:
+ result = self.await_(self._cursor.execute(operation))
+ else:
+ result = self.await_(self._cursor.execute(operation, parameters))
+
+ if not self.server_side:
+ # aiomysql has a "fake" async result, so we have to pull it out
+ # of that here since our default result is not async.
+ # we could just as easily grab "_rows" here and be done with it
+ # but this is safer.
+ self._rows = list(self.await_(self._cursor.fetchall()))
+ return result
+
+ def executemany(self, operation, seq_of_parameters):
+ return self.await_(
+ self._cursor.executemany(operation, seq_of_parameters)
+ )
+
+ def setinputsizes(self, *inputsizes):
+ pass
+
+ def __iter__(self):
+ while self._rows:
+ yield self._rows.pop(0)
+
+ def fetchone(self):
+ if self._rows:
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+
+ retval = self._rows[0:size]
+ self._rows[:] = self._rows[size:]
+ return retval
+
+ def fetchall(self):
+ retval = self._rows[:]
+ self._rows[:] = []
+ return retval
+
+
+class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
+
+ server_side = True
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ cursor = self._connection.cursor(
+ adapt_connection.dbapi.aiomysql.SSCursor
+ )
+
+ self._cursor = self.await_(cursor.__aenter__())
+
+ def close(self):
+ if self._cursor is not None:
+ self.await_(self._cursor.close())
+ self._cursor = None
+
+ def fetchone(self):
+ return self.await_(self._cursor.fetchone())
+
+ def fetchmany(self, size=None):
+ return self.await_(self._cursor.fetchmany(size=size))
+
+ def fetchall(self):
+ return self.await_(self._cursor.fetchall())
+
+
+class AsyncAdapt_aiomysql_connection:
+ await_ = staticmethod(await_only)
+
+ def __init__(self, dbapi, connection):
+ self.dbapi = dbapi
+ self._connection = connection
+
+ def ping(self, reconnect):
+ return self.await_(self._connection.ping(reconnect))
+
+ def character_set_name(self):
+ return self._connection.character_set_name()
+
+ def autocommit(self, value):
+ self.await_(self._connection.autocommit(value))
+
+ def cursor(self, server_side=False):
+ if server_side:
+ return AsyncAdapt_aiomysql_ss_cursor(self)
+ else:
+ return AsyncAdapt_aiomysql_cursor(self)
+
+ def rollback(self):
+ self.await_(self._connection.rollback())
+
+ def commit(self):
+ self.await_(self._connection.commit())
+
+ def close(self):
+ # it's not awaitable.
+ self._connection.close()
+
+
+class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
+ __slots__ = ()
+
+ await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_aiomysql_dbapi:
+ def __init__(self, aiomysql, pymysql):
+ self.aiomysql = aiomysql
+ self.pymysql = pymysql
+ self.paramstyle = "format"
+ self._init_dbapi_attributes()
+
+ def _init_dbapi_attributes(self):
+ for name in (
+ "Warning",
+ "Error",
+ "InterfaceError",
+ "DataError",
+ "DatabaseError",
+ "OperationalError",
+ "InterfaceError",
+ "IntegrityError",
+ "ProgrammingError",
+ "InternalError",
+ "NotSupportedError",
+ ):
+ setattr(self, name, getattr(self.aiomysql, name))
+
+ for name in (
+ "NUMBER",
+ "STRING",
+ "DATETIME",
+ "BINARY",
+ "TIMESTAMP",
+ "Binary",
+ ):
+ setattr(self, name, getattr(self.pymysql, name))
+
+ def connect(self, *arg, **kw):
+ async_fallback = kw.pop("async_fallback", False)
+
+ if async_fallback:
+ return AsyncAdaptFallback_aiomysql_connection(
+ self,
+ await_fallback(self.aiomysql.connect(*arg, **kw)),
+ )
+ else:
+ return AsyncAdapt_aiomysql_connection(
+ self,
+ await_only(self.aiomysql.connect(*arg, **kw)),
+ )
+
+
+class MySQLDialect_aiomysql(MySQLDialect_pymysql):
+ driver = "aiomysql"
+
+ supports_server_side_cursors = True
+ _sscursor = AsyncAdapt_aiomysql_ss_cursor
+
+ @classmethod
+ def dbapi(cls):
+ return AsyncAdapt_aiomysql_dbapi(
+ __import__("aiomysql"), __import__("pymysql")
+ )
+
+ @classmethod
+ def get_pool_class(self, url):
+ return pool.AsyncAdaptedQueuePool
+
+ def create_connect_args(self, url):
+ args, kw = super(MySQLDialect_aiomysql, self).create_connect_args(url)
+ if "passwd" in kw:
+ kw["password"] = kw.pop("passwd")
+ return args, kw
+
+ def is_disconnect(self, e, connection, cursor):
+ if super(MySQLDialect_aiomysql, self).is_disconnect(
+ e, connection, cursor
+ ):
+ return True
+ else:
+ str_e = str(e).lower()
+ return "not connected" in str_e
+
+ def _found_rows_client_flag(self):
+ from pymysql.constants import CLIENT
+
+ return CLIENT.FOUND_ROWS
+
+
+dialect = MySQLDialect_aiomysql
diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py
index b20e061fb..605407f46 100644
--- a/lib/sqlalchemy/dialects/mysql/mysqldb.py
+++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py
@@ -211,16 +211,25 @@ class MySQLDialect_mysqldb(MySQLDialect):
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
client_flag = opts.get("client_flag", 0)
+
+ client_flag_found_rows = self._found_rows_client_flag()
+ if client_flag_found_rows is not None:
+ client_flag |= client_flag_found_rows
+ opts["client_flag"] = client_flag
+ return [[], opts]
+
+ def _found_rows_client_flag(self):
if self.dbapi is not None:
try:
CLIENT_FLAGS = __import__(
self.dbapi.__name__ + ".constants.CLIENT"
).constants.CLIENT
- client_flag |= CLIENT_FLAGS.FOUND_ROWS
except (AttributeError, ImportError):
- self.supports_sane_rowcount = False
- opts["client_flag"] = client_flag
- return [[], opts]
+ return None
+ else:
+ return CLIENT_FLAGS.FOUND_ROWS
+ else:
+ return None
def _extract_error_code(self, exception):
return exception.args[0]
diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py
index 7f8a707d5..9c7e0420f 100644
--- a/lib/sqlalchemy/ext/asyncio/result.py
+++ b/lib/sqlalchemy/ext/asyncio/result.py
@@ -17,7 +17,14 @@ if util.TYPE_CHECKING:
from ...engine.result import Row
-class AsyncResult(FilterResult):
+class AsyncCommon(FilterResult):
+ async def close(self):
+ """Close this result."""
+
+ await greenlet_spawn(self._real_result.close)
+
+
+class AsyncResult(AsyncCommon):
"""An asyncio wrapper around a :class:`_result.Result` object.
The :class:`_asyncio.AsyncResult` only applies to statement executions that
@@ -370,7 +377,7 @@ class AsyncResult(FilterResult):
return AsyncMappingResult(self._real_result)
-class AsyncScalarResult(FilterResult):
+class AsyncScalarResult(AsyncCommon):
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
rather than :class:`_row.Row` values.
@@ -500,7 +507,7 @@ class AsyncScalarResult(FilterResult):
return await greenlet_spawn(self._only_one_row, True, True, False)
-class AsyncMappingResult(FilterResult):
+class AsyncMappingResult(AsyncCommon):
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary values
rather than :class:`_engine.Row` values.
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
index 9484d41d0..f31c7c137 100644
--- a/lib/sqlalchemy/testing/suite/test_results.py
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -114,9 +114,8 @@ class RowFetchTest(fixtures.TablesTest):
class PercentSchemaNamesTest(fixtures.TablesTest):
"""tests using percent signs, spaces in table and column names.
- This is a very fringe use case, doesn't work for MySQL
- or PostgreSQL. the requirement, "percent_schema_names",
- is marked "skip" by default.
+ This didn't work for PostgreSQL / MySQL drivers for a long time
+ but is now supported.
"""
@@ -233,6 +232,8 @@ class ServerSideCursorsTest(
elif self.engine.dialect.driver == "pymysql":
sscursor = __import__("pymysql.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
+ elif self.engine.dialect.driver == "aiomysql":
+ return cursor.server_side
elif self.engine.dialect.driver == "mysqldb":
sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py
index b230bad6f..34a968aff 100644
--- a/lib/sqlalchemy/testing/warnings.py
+++ b/lib/sqlalchemy/testing/warnings.py
@@ -30,6 +30,11 @@ def setup_filters():
warnings.filterwarnings(
"ignore", category=DeprecationWarning, message=".*inspect.get.*argspec"
)
+ warnings.filterwarnings(
+ "ignore",
+ category=DeprecationWarning,
+ message="The loop argument is deprecated",
+ )
# ignore things that are deprecated *as of* 2.0 :)
warnings.filterwarnings(