diff options
Diffstat (limited to 'lib/sqlalchemy/engine/base.py')
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 445 |
1 files changed, 253 insertions, 192 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 4a057ee59..75d03b744 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -61,10 +61,16 @@ class Connection(Connectable): """ - def __init__(self, engine, connection=None, close_with_result=False, - _branch_from=None, _execution_options=None, - _dispatch=None, - _has_events=None): + def __init__( + self, + engine, + connection=None, + close_with_result=False, + _branch_from=None, + _execution_options=None, + _dispatch=None, + _has_events=None, + ): """Construct a new Connection. The constructor here is not public and is only called only by an @@ -86,8 +92,11 @@ class Connection(Connectable): self._has_events = _branch_from._has_events self.schema_for_object = _branch_from.schema_for_object else: - self.__connection = connection \ - if connection is not None else engine.raw_connection() + self.__connection = ( + connection + if connection is not None + else engine.raw_connection() + ) self.__transaction = None self.__savepoint_seq = 0 self.should_close_with_result = close_with_result @@ -101,7 +110,8 @@ class Connection(Connectable): # want to handle any of the engine's events in that case. self.dispatch = self.dispatch._join(engine.dispatch) self._has_events = _has_events or ( - _has_events is None and engine._has_events) + _has_events is None and engine._has_events + ) assert not _execution_options self._execution_options = engine._execution_options @@ -134,7 +144,8 @@ class Connection(Connectable): _branch_from=self, _execution_options=self._execution_options, _has_events=self._has_events, - _dispatch=self.dispatch) + _dispatch=self.dispatch, + ) @property def _root(self): @@ -322,8 +333,10 @@ class Connection(Connectable): def closed(self): """Return True if this connection is closed.""" - return '_Connection__connection' not in self.__dict__ \ + return ( + "_Connection__connection" not in self.__dict__ and not self.__can_reconnect + ) @property def invalidated(self): @@ -425,7 +438,8 @@ class Connection(Connectable): if self.__transaction is not None: raise exc.InvalidRequestError( "Can't reconnect until invalid " - "transaction is rolled back") + "transaction is rolled back" + ) self.__connection = self.engine.raw_connection(_connection=self) self.__invalid = False return self.__connection @@ -437,14 +451,15 @@ class Connection(Connectable): # dialect initializer, where the connection is not wrapped in # _ConnectionFairy - return getattr(self.__connection, 'is_valid', False) + return getattr(self.__connection, "is_valid", False) @property def _still_open_and_connection_is_valid(self): - return \ - not self.closed and \ - not self.invalidated and \ - getattr(self.__connection, 'is_valid', False) + return ( + not self.closed + and not self.invalidated + and getattr(self.__connection, "is_valid", False) + ) @property def info(self): @@ -656,7 +671,8 @@ class Connection(Connectable): if self.__transaction is not None: raise exc.InvalidRequestError( "Cannot start a two phase transaction when a transaction " - "is already in progress.") + "is already in progress." + ) if xid is None: xid = self.engine.dialect.create_xid() self.__transaction = TwoPhaseTransaction(self, xid) @@ -705,8 +721,10 @@ class Connection(Connectable): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) finally: - if not self.__invalid and \ - self.connection._reset_agent is self.__transaction: + if ( + not self.__invalid + and self.connection._reset_agent is self.__transaction + ): self.connection._reset_agent = None self.__transaction = None else: @@ -725,8 +743,10 @@ class Connection(Connectable): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) finally: - if not self.__invalid and \ - self.connection._reset_agent is self.__transaction: + if ( + not self.__invalid + and self.connection._reset_agent is self.__transaction + ): self.connection._reset_agent = None self.__transaction = None @@ -738,7 +758,7 @@ class Connection(Connectable): if name is None: self.__savepoint_seq += 1 - name = 'sa_savepoint_%s' % self.__savepoint_seq + name = "sa_savepoint_%s" % self.__savepoint_seq if self._still_open_and_connection_is_valid: self.engine.dialect.do_savepoint(self, name) return name @@ -797,7 +817,8 @@ class Connection(Connectable): assert isinstance(self.__transaction, TwoPhaseTransaction) try: self.engine.dialect.do_rollback_twophase( - self, xid, is_prepared) + self, xid, is_prepared + ) finally: if self.connection._reset_agent is self.__transaction: self.connection._reset_agent = None @@ -950,16 +971,16 @@ class Connection(Connectable): def _execute_function(self, func, multiparams, params): """Execute a sql.FunctionElement object.""" - return self._execute_clauseelement(func.select(), - multiparams, params) + return self._execute_clauseelement(func.select(), multiparams, params) def _execute_default(self, default, multiparams, params): """Execute a schema.ColumnDefault object.""" if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - default, multiparams, params = \ - fn(self, default, multiparams, params) + default, multiparams, params = fn( + self, default, multiparams, params + ) try: try: @@ -972,8 +993,7 @@ class Connection(Connectable): conn = self._revalidate_connection() dialect = self.dialect - ctx = dialect.execution_ctx_cls._init_default( - dialect, self, conn) + ctx = dialect.execution_ctx_cls._init_default(dialect, self, conn) except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) @@ -982,8 +1002,9 @@ class Connection(Connectable): self.close() if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - default, multiparams, params, ret) + self.dispatch.after_execute( + self, default, multiparams, params, ret + ) return ret @@ -992,25 +1013,25 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - ddl, multiparams, params = \ - fn(self, ddl, multiparams, params) + ddl, multiparams, params = fn(self, ddl, multiparams, params) dialect = self.dialect compiled = ddl.compile( dialect=dialect, schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default else None) + if not self.schema_for_object.is_default + else None, + ) ret = self._execute_context( dialect, dialect.execution_ctx_cls._init_ddl, compiled, None, - compiled + compiled, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - ddl, multiparams, params, ret) + self.dispatch.after_execute(self, ddl, multiparams, params, ret) return ret def _execute_clauseelement(self, elem, multiparams, params): @@ -1018,8 +1039,7 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - elem, multiparams, params = \ - fn(self, elem, multiparams, params) + elem, multiparams, params = fn(self, elem, multiparams, params) distilled_params = _distill_params(multiparams, params) if distilled_params: @@ -1030,38 +1050,45 @@ class Connection(Connectable): keys = [] dialect = self.dialect - if 'compiled_cache' in self._execution_options: + if "compiled_cache" in self._execution_options: key = ( - dialect, elem, tuple(sorted(keys)), + dialect, + elem, + tuple(sorted(keys)), self.schema_for_object.hash_key, - len(distilled_params) > 1 + len(distilled_params) > 1, ) - compiled_sql = self._execution_options['compiled_cache'].get(key) + compiled_sql = self._execution_options["compiled_cache"].get(key) if compiled_sql is None: compiled_sql = elem.compile( - dialect=dialect, column_keys=keys, + dialect=dialect, + column_keys=keys, inline=len(distilled_params) > 1, schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default else None + if not self.schema_for_object.is_default + else None, ) - self._execution_options['compiled_cache'][key] = compiled_sql + self._execution_options["compiled_cache"][key] = compiled_sql else: compiled_sql = elem.compile( - dialect=dialect, column_keys=keys, + dialect=dialect, + column_keys=keys, inline=len(distilled_params) > 1, schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default else None) + if not self.schema_for_object.is_default + else None, + ) ret = self._execute_context( dialect, dialect.execution_ctx_cls._init_compiled, compiled_sql, distilled_params, - compiled_sql, distilled_params + compiled_sql, + distilled_params, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - elem, multiparams, params, ret) + self.dispatch.after_execute(self, elem, multiparams, params, ret) return ret def _execute_compiled(self, compiled, multiparams, params): @@ -1069,8 +1096,9 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - compiled, multiparams, params = \ - fn(self, compiled, multiparams, params) + compiled, multiparams, params = fn( + self, compiled, multiparams, params + ) dialect = self.dialect parameters = _distill_params(multiparams, params) @@ -1079,11 +1107,13 @@ class Connection(Connectable): dialect.execution_ctx_cls._init_compiled, compiled, parameters, - compiled, parameters + compiled, + parameters, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - compiled, multiparams, params, ret) + self.dispatch.after_execute( + self, compiled, multiparams, params, ret + ) return ret def _execute_text(self, statement, multiparams, params): @@ -1091,8 +1121,9 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - statement, multiparams, params = \ - fn(self, statement, multiparams, params) + statement, multiparams, params = fn( + self, statement, multiparams, params + ) dialect = self.dialect parameters = _distill_params(multiparams, params) @@ -1101,16 +1132,18 @@ class Connection(Connectable): dialect.execution_ctx_cls._init_statement, statement, parameters, - statement, parameters + statement, + parameters, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - statement, multiparams, params, ret) + self.dispatch.after_execute( + self, statement, multiparams, params, ret + ) return ret - def _execute_context(self, dialect, constructor, - statement, parameters, - *args): + def _execute_context( + self, dialect, constructor, statement, parameters, *args + ): """Create an :class:`.ExecutionContext` and execute, returning a :class:`.ResultProxy`.""" @@ -1127,31 +1160,36 @@ class Connection(Connectable): context = constructor(dialect, self, conn, *args) except BaseException as e: self._handle_dbapi_exception( - e, - util.text_type(statement), parameters, - None, None) + e, util.text_type(statement), parameters, None, None + ) if context.compiled: context.pre_exec() - cursor, statement, parameters = context.cursor, \ - context.statement, \ - context.parameters + cursor, statement, parameters = ( + context.cursor, + context.statement, + context.parameters, + ) if not context.executemany: parameters = parameters[0] if self._has_events or self.engine._has_events: for fn in self.dispatch.before_cursor_execute: - statement, parameters = \ - fn(self, cursor, statement, parameters, - context, context.executemany) + statement, parameters = fn( + self, + cursor, + statement, + parameters, + context, + context.executemany, + ) if self._echo: self.engine.logger.info(statement) self.engine.logger.info( - "%r", - sql_util._repr_params(parameters, batches=10) + "%r", sql_util._repr_params(parameters, batches=10) ) evt_handled = False @@ -1164,10 +1202,8 @@ class Connection(Connectable): break if not evt_handled: self.dialect.do_executemany( - cursor, - statement, - parameters, - context) + cursor, statement, parameters, context + ) elif not parameters and context.no_parameters: if self.dialect._has_events: for fn in self.dialect.dispatch.do_execute_no_params: @@ -1176,9 +1212,8 @@ class Connection(Connectable): break if not evt_handled: self.dialect.do_execute_no_params( - cursor, - statement, - context) + cursor, statement, context + ) else: if self.dialect._has_events: for fn in self.dialect.dispatch.do_execute: @@ -1187,24 +1222,22 @@ class Connection(Connectable): break if not evt_handled: self.dialect.do_execute( - cursor, - statement, - parameters, - context) + cursor, statement, parameters, context + ) except BaseException as e: self._handle_dbapi_exception( - e, - statement, - parameters, - cursor, - context) + e, statement, parameters, cursor, context + ) if self._has_events or self.engine._has_events: - self.dispatch.after_cursor_execute(self, cursor, - statement, - parameters, - context, - context.executemany) + self.dispatch.after_cursor_execute( + self, + cursor, + statement, + parameters, + context, + context.executemany, + ) if context.compiled: context.post_exec() @@ -1245,39 +1278,32 @@ class Connection(Connectable): """ if self._has_events or self.engine._has_events: for fn in self.dispatch.before_cursor_execute: - statement, parameters = \ - fn(self, cursor, statement, parameters, - context, - False) + statement, parameters = fn( + self, cursor, statement, parameters, context, False + ) if self._echo: self.engine.logger.info(statement) self.engine.logger.info("%r", parameters) try: - for fn in () if not self.dialect._has_events \ - else self.dialect.dispatch.do_execute: + for fn in ( + () + if not self.dialect._has_events + else self.dialect.dispatch.do_execute + ): if fn(cursor, statement, parameters, context): break else: - self.dialect.do_execute( - cursor, - statement, - parameters, - context) + self.dialect.do_execute(cursor, statement, parameters, context) except BaseException as e: self._handle_dbapi_exception( - e, - statement, - parameters, - cursor, - context) + e, statement, parameters, cursor, context + ) if self._has_events or self.engine._has_events: - self.dispatch.after_cursor_execute(self, cursor, - statement, - parameters, - context, - False) + self.dispatch.after_cursor_execute( + self, cursor, statement, parameters, context, False + ) def _safe_close_cursor(self, cursor): """Close the given cursor, catching exceptions @@ -1289,17 +1315,15 @@ class Connection(Connectable): except Exception: # log the error through the connection pool's logger. self.engine.pool.logger.error( - "Error closing cursor", exc_info=True) + "Error closing cursor", exc_info=True + ) _reentrant_error = False _is_disconnect = False - def _handle_dbapi_exception(self, - e, - statement, - parameters, - cursor, - context): + def _handle_dbapi_exception( + self, e, statement, parameters, cursor, context + ): exc_info = sys.exc_info() if context and context.exception is None: @@ -1309,15 +1333,14 @@ class Connection(Connectable): if not self._is_disconnect: self._is_disconnect = ( - isinstance(e, self.dialect.dbapi.Error) and - not self.closed and - self.dialect.is_disconnect( + isinstance(e, self.dialect.dbapi.Error) + and not self.closed + and self.dialect.is_disconnect( e, self.__connection if not self.invalidated else None, - cursor) - ) or ( - is_exit_exception and not self.closed - ) + cursor, + ) + ) or (is_exit_exception and not self.closed) if context: context.is_disconnect = self._is_disconnect @@ -1326,20 +1349,24 @@ class Connection(Connectable): if self._reentrant_error: util.raise_from_cause( - exc.DBAPIError.instance(statement, - parameters, - e, - self.dialect.dbapi.Error, - dialect=self.dialect), - exc_info + exc.DBAPIError.instance( + statement, + parameters, + e, + self.dialect.dbapi.Error, + dialect=self.dialect, + ), + exc_info, ) self._reentrant_error = True try: # non-DBAPI error - if we already got a context, # or there's no string statement, don't wrap it - should_wrap = isinstance(e, self.dialect.dbapi.Error) or \ - (statement is not None - and context is None and not is_exit_exception) + should_wrap = isinstance(e, self.dialect.dbapi.Error) or ( + statement is not None + and context is None + and not is_exit_exception + ) if should_wrap: sqlalchemy_exception = exc.DBAPIError.instance( @@ -1348,30 +1375,37 @@ class Connection(Connectable): e, self.dialect.dbapi.Error, connection_invalidated=self._is_disconnect, - dialect=self.dialect) + dialect=self.dialect, + ) else: sqlalchemy_exception = None newraise = None - if (self._has_events or self.engine._has_events) and \ - not self._execution_options.get( - 'skip_user_error_events', False): + if ( + self._has_events or self.engine._has_events + ) and not self._execution_options.get( + "skip_user_error_events", False + ): # legacy dbapi_error event if should_wrap and context: - self.dispatch.dbapi_error(self, - cursor, - statement, - parameters, - context, - e) + self.dispatch.dbapi_error( + self, cursor, statement, parameters, context, e + ) # new handle_error event ctx = ExceptionContextImpl( - e, sqlalchemy_exception, self.engine, - self, cursor, statement, - parameters, context, self._is_disconnect, - invalidate_pool_on_disconnect) + e, + sqlalchemy_exception, + self.engine, + self, + cursor, + statement, + parameters, + context, + self._is_disconnect, + invalidate_pool_on_disconnect, + ) for fn in self.dispatch.handle_error: try: @@ -1388,13 +1422,15 @@ class Connection(Connectable): if self._is_disconnect != ctx.is_disconnect: self._is_disconnect = ctx.is_disconnect if sqlalchemy_exception: - sqlalchemy_exception.connection_invalidated = \ + sqlalchemy_exception.connection_invalidated = ( ctx.is_disconnect + ) # set up potentially user-defined value for # invalidate pool. - invalidate_pool_on_disconnect = \ + invalidate_pool_on_disconnect = ( ctx.invalidate_pool_on_disconnect + ) if should_wrap and context: context.handle_dbapi_exception(e) @@ -1408,10 +1444,7 @@ class Connection(Connectable): if newraise: util.raise_from_cause(newraise, exc_info) elif should_wrap: - util.raise_from_cause( - sqlalchemy_exception, - exc_info - ) + util.raise_from_cause(sqlalchemy_exception, exc_info) else: util.reraise(*exc_info) @@ -1441,7 +1474,8 @@ class Connection(Connectable): None, e, dialect.dbapi.Error, - connection_invalidated=is_disconnect) + connection_invalidated=is_disconnect, + ) else: sqlalchemy_exception = None @@ -1449,8 +1483,17 @@ class Connection(Connectable): if engine._has_events: ctx = ExceptionContextImpl( - e, sqlalchemy_exception, engine, None, None, None, - None, None, is_disconnect, True) + e, + sqlalchemy_exception, + engine, + None, + None, + None, + None, + None, + is_disconnect, + True, + ) for fn in engine.dispatch.handle_error: try: # handler returns an exception; @@ -1463,18 +1506,15 @@ class Connection(Connectable): newraise = _raised break - if sqlalchemy_exception and \ - is_disconnect != ctx.is_disconnect: - sqlalchemy_exception.connection_invalidated = \ - is_disconnect = ctx.is_disconnect + if sqlalchemy_exception and is_disconnect != ctx.is_disconnect: + sqlalchemy_exception.connection_invalidated = ( + is_disconnect + ) = ctx.is_disconnect if newraise: util.raise_from_cause(newraise, exc_info) elif should_wrap: - util.raise_from_cause( - sqlalchemy_exception, - exc_info - ) + util.raise_from_cause(sqlalchemy_exception, exc_info) else: util.reraise(*exc_info) @@ -1545,16 +1585,25 @@ class Connection(Connectable): return callable_(self, *args, **kwargs) def _run_visitor(self, visitorcallable, element, **kwargs): - visitorcallable(self.dialect, self, - **kwargs).traverse_single(element) + visitorcallable(self.dialect, self, **kwargs).traverse_single(element) class ExceptionContextImpl(ExceptionContext): """Implement the :class:`.ExceptionContext` interface.""" - def __init__(self, exception, sqlalchemy_exception, - engine, connection, cursor, statement, parameters, - context, is_disconnect, invalidate_pool_on_disconnect): + def __init__( + self, + exception, + sqlalchemy_exception, + engine, + connection, + cursor, + statement, + parameters, + context, + is_disconnect, + invalidate_pool_on_disconnect, + ): self.engine = engine self.connection = connection self.sqlalchemy_exception = sqlalchemy_exception @@ -1691,12 +1740,14 @@ class NestedTransaction(Transaction): def _do_rollback(self): if self.is_active: self.connection._rollback_to_savepoint_impl( - self._savepoint, self._parent) + self._savepoint, self._parent + ) def _do_commit(self): if self.is_active: self.connection._release_savepoint_impl( - self._savepoint, self._parent) + self._savepoint, self._parent + ) class TwoPhaseTransaction(Transaction): @@ -1771,10 +1822,16 @@ class Engine(Connectable, log.Identified): """ - def __init__(self, pool, dialect, url, - logging_name=None, echo=None, proxy=None, - execution_options=None - ): + def __init__( + self, + pool, + dialect, + url, + logging_name=None, + echo=None, + proxy=None, + execution_options=None, + ): self.pool = pool self.url = url self.dialect = dialect @@ -1805,8 +1862,7 @@ class Engine(Connectable, log.Identified): :meth:`.Engine.execution_options` """ - self._execution_options = \ - self._execution_options.union(opt) + self._execution_options = self._execution_options.union(opt) self.dispatch.set_engine_execution_options(self, opt) self.dialect.set_engine_execution_options(self, opt) @@ -1894,7 +1950,7 @@ class Engine(Connectable, log.Identified): echo = log.echo_property() def __repr__(self): - return 'Engine(%r)' % self.url + return "Engine(%r)" % self.url def dispose(self): """Dispose of the connection pool used by this :class:`.Engine`. @@ -1934,8 +1990,9 @@ class Engine(Connectable, log.Identified): else: yield connection - def _run_visitor(self, visitorcallable, element, - connection=None, **kwargs): + def _run_visitor( + self, visitorcallable, element, connection=None, **kwargs + ): with self._optional_conn_ctx_manager(connection) as conn: conn._run_visitor(visitorcallable, element, **kwargs) @@ -2122,7 +2179,8 @@ class Engine(Connectable, log.Identified): self, self._wrap_pool_connect(self.pool.connect, None), close_with_result=close_with_result, - **kwargs) + **kwargs + ) def table_names(self, schema=None, connection=None): """Return a list of all table names available in the database. @@ -2159,7 +2217,8 @@ class Engine(Connectable, log.Identified): except dialect.dbapi.Error as e: if connection is None: Connection._handle_dbapi_exception_noconnection( - e, dialect, self) + e, dialect, self + ) else: util.reraise(*sys.exc_info()) @@ -2185,7 +2244,8 @@ class Engine(Connectable, log.Identified): """ return self._wrap_pool_connect( - self.pool.unique_connection, _connection) + self.pool.unique_connection, _connection + ) class OptionEngine(Engine): @@ -2225,10 +2285,11 @@ class OptionEngine(Engine): pool = property(_get_pool, _set_pool) def _get_has_events(self): - return self._proxied._has_events or \ - self.__dict__.get('_has_events', False) + return self._proxied._has_events or self.__dict__.get( + "_has_events", False + ) def _set_has_events(self, value): - self.__dict__['_has_events'] = value + self.__dict__["_has_events"] = value _has_events = property(_get_has_events, _set_has_events) |