diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/aiomysql.py | 38 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/asyncpg.py | 117 |
2 files changed, 92 insertions, 63 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index 6c968a1e7..cab6df499 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -35,6 +35,7 @@ handling. from .pymysql import MySQLDialect_pymysql from ... import pool from ... import util +from ...util.concurrency import asyncio from ...util.concurrency import await_fallback from ...util.concurrency import await_only @@ -84,24 +85,32 @@ class AsyncAdapt_aiomysql_cursor: self._rows[:] = [] def execute(self, operation, parameters=None): - if parameters is None: - result = self.await_(self._cursor.execute(operation)) - else: - result = self.await_(self._cursor.execute(operation, parameters)) - - if not self.server_side: - # aiomysql has a "fake" async result, so we have to pull it out - # of that here since our default result is not async. - # we could just as easily grab "_rows" here and be done with it - # but this is safer. - self._rows = list(self.await_(self._cursor.fetchall())) - return result + return self.await_(self._execute_async(operation, parameters)) def executemany(self, operation, seq_of_parameters): return self.await_( - self._cursor.executemany(operation, seq_of_parameters) + self._executemany_async(operation, seq_of_parameters) ) + async def _execute_async(self, operation, parameters): + async with self._adapt_connection._execute_mutex: + if parameters is None: + result = await self._cursor.execute(operation) + else: + result = await self._cursor.execute(operation, parameters) + + if not self.server_side: + # aiomysql has a "fake" async result, so we have to pull it out + # of that here since our default result is not async. + # we could just as easily grab "_rows" here and be done with it + # but this is safer. + self._rows = list(await self._cursor.fetchall()) + return result + + async def _executemany_async(self, operation, seq_of_parameters): + async with self._adapt_connection._execute_mutex: + return await self._cursor.executemany(operation, seq_of_parameters) + def setinputsizes(self, *inputsizes): pass @@ -161,11 +170,12 @@ class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor): class AsyncAdapt_aiomysql_connection: await_ = staticmethod(await_only) - __slots__ = ("dbapi", "_connection") + __slots__ = ("dbapi", "_connection", "_execute_mutex") def __init__(self, dbapi, connection): self.dbapi = dbapi self._connection = connection + self._execute_mutex = asyncio.Lock() def ping(self, reconnect): return self.await_(self._connection.ping(reconnect)) diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 7ef5e441c..4580421f6 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -122,6 +122,7 @@ from ... import pool from ... import processors from ... import util from ...sql import sqltypes +from ...util.concurrency import asyncio from ...util.concurrency import await_fallback from ...util.concurrency import await_only @@ -369,74 +370,90 @@ class AsyncAdapt_asyncpg_cursor: ) async def _prepare_and_execute(self, operation, parameters): + adapt_connection = self._adapt_connection - if not self._adapt_connection._started: - await self._adapt_connection._start_transaction() - - if parameters is not None: - operation = operation % self._parameter_placeholders(parameters) - else: - parameters = () + async with adapt_connection._execute_mutex: - try: - prepared_stmt, attributes = await self._adapt_connection._prepare( - operation, self._invalidate_schema_cache_asof - ) + if not adapt_connection._started: + await adapt_connection._start_transaction() - if attributes: - self.description = [ - (attr.name, attr.type.oid, None, None, None, None, None) - for attr in attributes - ] + if parameters is not None: + operation = operation % self._parameter_placeholders( + parameters + ) else: - self.description = None + parameters = () - if self.server_side: - self._cursor = await prepared_stmt.cursor(*parameters) - self.rowcount = -1 - else: - self._rows = await prepared_stmt.fetch(*parameters) - status = prepared_stmt.get_statusmsg() + try: + prepared_stmt, attributes = await adapt_connection._prepare( + operation, self._invalidate_schema_cache_asof + ) - reg = re.match(r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status) - if reg: - self.rowcount = int(reg.group(1)) + if attributes: + self.description = [ + ( + attr.name, + attr.type.oid, + None, + None, + None, + None, + None, + ) + for attr in attributes + ] else: + self.description = None + + if self.server_side: + self._cursor = await prepared_stmt.cursor(*parameters) self.rowcount = -1 + else: + self._rows = await prepared_stmt.fetch(*parameters) + status = prepared_stmt.get_statusmsg() - except Exception as error: - self._handle_exception(error) + reg = re.match( + r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status + ) + if reg: + self.rowcount = int(reg.group(1)) + else: + self.rowcount = -1 - def execute(self, operation, parameters=None): - try: - self._adapt_connection.await_( - self._prepare_and_execute(operation, parameters) - ) - except Exception as error: - self._handle_exception(error) + except Exception as error: + self._handle_exception(error) - def executemany(self, operation, seq_of_parameters): + async def _executemany(self, operation, seq_of_parameters): adapt_connection = self._adapt_connection - adapt_connection.await_( - adapt_connection._check_type_cache_invalidation( + async with adapt_connection._execute_mutex: + await adapt_connection._check_type_cache_invalidation( self._invalidate_schema_cache_asof ) - ) - if not adapt_connection._started: - adapt_connection.await_(adapt_connection._start_transaction()) + if not adapt_connection._started: + await adapt_connection._start_transaction() - operation = operation % self._parameter_placeholders( - seq_of_parameters[0] + operation = operation % self._parameter_placeholders( + seq_of_parameters[0] + ) + + try: + return await self._connection.executemany( + operation, seq_of_parameters + ) + except Exception as error: + self._handle_exception(error) + + def execute(self, operation, parameters=None): + self._adapt_connection.await_( + self._prepare_and_execute(operation, parameters) ) - try: - return adapt_connection.await_( - self._connection.executemany(operation, seq_of_parameters) - ) - except Exception as error: - self._handle_exception(error) + def executemany(self, operation, seq_of_parameters): + return self._adapt_connection.await_( + self._executemany(operation, seq_of_parameters) + ) def setinputsizes(self, *inputsizes): self._inputsizes = inputsizes @@ -561,6 +578,7 @@ class AsyncAdapt_asyncpg_connection: "_started", "_prepared_statement_cache", "_invalidate_schema_cache_asof", + "_execute_mutex", ) await_ = staticmethod(await_only) @@ -574,6 +592,7 @@ class AsyncAdapt_asyncpg_connection: self._transaction = None self._started = False self._invalidate_schema_cache_asof = time.time() + self._execute_mutex = asyncio.Lock() if prepared_statement_cache_size: self._prepared_statement_cache = util.LRUCache( |