author     Mike Bayer <mike_mp@zzzcomputing.com>  2020-10-08 15:20:48 -0400
committer  Mike Bayer <mike_mp@zzzcomputing.com>  2020-10-10 01:17:25 -0400
commit     2665a0c4cb3e94e6545d0b9bbcbcc39ccffebaba (patch)
tree       ed25383ce7e5899d7d643a11df0f8aee9f2ab959 /lib
parent     bcc17b1d6e2cac3b0e45c0b17a62cf2d5fc5c5ab (diff)
download   sqlalchemy-2665a0c4cb3e94e6545d0b9bbcbcc39ccffebaba.tar.gz
generalize scoped_session proxying and apply to asyncio elements
Reworked the proxy creation used by scoped_session() to be based on
fully copied code with augmented docstrings and moved it into
langhelpers.  The asyncio session, engine, and connection can now take
advantage of it so that all non-async methods are available.

Overall implementation of the most important accessors / methods on
AsyncConnection, etc., including awaitable versions of invalidate,
execution_options, etc.

In order to support an event dispatcher on the async classes while
still allowing them to hold __slots__, make some adjustments to the
event system to allow that to be present, at least rudimentarily.

Fixes: #5628
Change-Id: I5eb6929fc1e4fdac99e4b767dcfd49672d56e2b2
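For context, a short usage sketch of the surface this change produces (the database URL and query below are placeholders): attribute accessors on AsyncConnection are synchronous proxies generated by create_proxy_methods(), while methods that may perform I/O, such as execution_options() and invalidate(), are awaitable:

    import asyncio

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine


    async def main():
        # placeholder URL; any asyncio-enabled dialect behaves the same way
        engine = create_async_engine(
            "postgresql+asyncpg://scott:tiger@localhost/test"
        )

        async with engine.connect() as conn:
            # proxied, non-awaitable accessors
            print(conn.closed, conn.invalidated, conn.dialect.name)

            # awaitable execution_options(); returns this same AsyncConnection
            await conn.execution_options(isolation_level="AUTOCOMMIT")

            result = await conn.execute(text("select 1"))
            print(result.scalar())

            # awaitable invalidate(); discards the underlying DBAPI connection
            await conn.invalidate()

        await engine.dispose()


    asyncio.run(main())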
Diffstat (limited to 'lib')
-rw-r--r--  lib/sqlalchemy/engine/base.py                    30
-rw-r--r--  lib/sqlalchemy/event/base.py                     35
-rw-r--r--  lib/sqlalchemy/ext/asyncio/__init__.py            2
-rw-r--r--  lib/sqlalchemy/ext/asyncio/engine.py             153
-rw-r--r--  lib/sqlalchemy/ext/asyncio/events.py              29
-rw-r--r--  lib/sqlalchemy/ext/asyncio/session.py             97
-rw-r--r--  lib/sqlalchemy/orm/events.py                       3
-rw-r--r--  lib/sqlalchemy/orm/scoping.py                     99
-rw-r--r--  lib/sqlalchemy/orm/session.py                     38
-rw-r--r--  lib/sqlalchemy/pool/base.py                        1
-rw-r--r--  lib/sqlalchemy/testing/plugin/pytestplugin.py      4
-rw-r--r--  lib/sqlalchemy/util/__init__.py                    1
-rw-r--r--  lib/sqlalchemy/util/langhelpers.py               155
13 files changed, 497 insertions, 150 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 9a6bdd7f3..4fbdec145 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -184,14 +184,17 @@ class Connection(Connectable):
r""" Set non-SQL options for the connection which take effect
during execution.
- The method returns a copy of this :class:`_engine.Connection`
- which references
- the same underlying DBAPI connection, but also defines the given
- execution options which will take effect for a call to
- :meth:`execute`. As the new :class:`_engine.Connection`
- references the same
- underlying resource, it's usually a good idea to ensure that the copies
- will be discarded immediately, which is implicit if used as in::
+ For a "future" style connection, this method returns this same
+ :class:`_future.Connection` object with the new options added.
+
+ For a legacy connection, this method returns a copy of this
+ :class:`_engine.Connection` which references the same underlying DBAPI
+ connection, but also defines the given execution options which will
+ take effect for a call to
+ :meth:`execute`. As the new :class:`_engine.Connection` references the
+ same underlying resource, it's usually a good idea to ensure that
+ the copies will be discarded immediately, which is implicit if used
+ as in::
result = connection.execution_options(stream_results=True).\
execute(stmt)
@@ -549,9 +552,10 @@ class Connection(Connectable):
"""Invalidate the underlying DBAPI connection associated with
this :class:`_engine.Connection`.
- The underlying DBAPI connection is literally closed (if
- possible), and is discarded. Its source connection pool will
- typically lazily create a new connection to replace it.
+ An attempt will be made to close the underlying DBAPI connection
+ immediately; however if this operation fails, the error is logged
+ but not raised. The connection is then discarded whether or not
+ close() succeeded.
Upon the next use (where "use" typically means using the
:meth:`_engine.Connection.execute` method or similar),
@@ -580,6 +584,10 @@ class Connection(Connectable):
will at the connection pool level invoke the
:meth:`_events.PoolEvents.invalidate` event.
+    :param exception: an optional ``Exception`` instance that's the
+     reason for the invalidation.  It is passed along to event handlers
+     and logging functions.
+
.. seealso::
:ref:`pool_connection_invalidation`
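As a hedged illustration of the new ``:param exception:`` documentation above (``connection`` and ``stmt`` are placeholders), a caller might hand the triggering error through to invalidate():

    from sqlalchemy import exc

    try:
        result = connection.execute(stmt)
    except exc.DBAPIError as err:
        # pass the original error along so PoolEvents.invalidate listeners
        # and the engine logger can record why the connection was discarded
        connection.invalidate(exception=err)
        raise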
diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py
index daa6f9aea..1ba88f3d2 100644
--- a/lib/sqlalchemy/event/base.py
+++ b/lib/sqlalchemy/event/base.py
@@ -195,7 +195,14 @@ def _create_dispatcher_class(cls, classname, bases, dict_):
dispatch_cls._event_names.append(ls.name)
if getattr(cls, "_dispatch_target", None):
- cls._dispatch_target.dispatch = dispatcher(cls)
+ the_cls = cls._dispatch_target
+ if (
+ hasattr(the_cls, "__slots__")
+ and "_slots_dispatch" in the_cls.__slots__
+ ):
+ cls._dispatch_target.dispatch = slots_dispatcher(cls)
+ else:
+ cls._dispatch_target.dispatch = dispatcher(cls)
def _remove_dispatcher(cls):
@@ -304,5 +311,29 @@ class dispatcher(object):
def __get__(self, obj, cls):
if obj is None:
return self.dispatch
- obj.__dict__["dispatch"] = disp = self.dispatch._for_instance(obj)
+
+ disp = self.dispatch._for_instance(obj)
+ try:
+ obj.__dict__["dispatch"] = disp
+ except AttributeError as ae:
+ util.raise_(
+ TypeError(
+ "target %r doesn't have __dict__, should it be "
+ "defining _slots_dispatch?" % (obj,)
+ ),
+ replace_context=ae,
+ )
+ return disp
+
+
+class slots_dispatcher(dispatcher):
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self.dispatch
+
+ if hasattr(obj, "_slots_dispatch"):
+ return obj._slots_dispatch
+
+ disp = self.dispatch._for_instance(obj)
+ obj._slots_dispatch = disp
return disp
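The pattern in isolation, as a standalone sketch (toy classes, not SQLAlchemy's actual event machinery): a slots_dispatcher-style descriptor memoizes the per-instance dispatch in the declared _slots_dispatch slot, since a __slots__ class has no __dict__ to cache into:

    class _Dispatch:
        """Toy stand-in for the per-instance dispatch collection."""

        def __init__(self, owner):
            self.owner = owner


    class _slots_dispatcher_demo:
        """Descriptor that caches its value in a slot rather than __dict__."""

        def __get__(self, obj, cls):
            if obj is None:
                return self
            # reuse the dispatch already stored in the slot, if any
            if hasattr(obj, "_slots_dispatch"):
                return obj._slots_dispatch
            disp = _Dispatch(obj)
            obj._slots_dispatch = disp
            return disp


    class SlottedTarget:
        # declaring this slot is what opts a class into slots_dispatcher
        __slots__ = ("_slots_dispatch",)

        dispatch = _slots_dispatcher_demo()


    t = SlottedTarget()
    assert t.dispatch is t.dispatch  # cached in the slot, not in __dict__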
diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py
index fbbc958d4..9c7d6443c 100644
--- a/lib/sqlalchemy/ext/asyncio/__init__.py
+++ b/lib/sqlalchemy/ext/asyncio/__init__.py
@@ -2,6 +2,8 @@ from .engine import AsyncConnection # noqa
from .engine import AsyncEngine # noqa
from .engine import AsyncTransaction # noqa
from .engine import create_async_engine # noqa
+from .events import AsyncConnectionEvents # noqa
+from .events import AsyncSessionEvents # noqa
from .result import AsyncMappingResult # noqa
from .result import AsyncResult # noqa
from .result import AsyncScalarResult # noqa
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
index 4a92fb1f2..9e4851dfc 100644
--- a/lib/sqlalchemy/ext/asyncio/engine.py
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -8,12 +8,11 @@ from .base import StartableContext
from .result import AsyncResult
from ... import exc
from ... import util
-from ...engine import Connection
from ...engine import create_engine as _create_engine
-from ...engine import Engine
from ...engine import Result
from ...engine import Transaction
-from ...engine.base import OptionEngineMixin
+from ...future import Connection
+from ...future import Engine
from ...sql import Executable
from ...util.concurrency import greenlet_spawn
@@ -41,7 +40,24 @@ def create_async_engine(*arg, **kw):
return AsyncEngine(sync_engine)
-class AsyncConnection(StartableContext):
+class AsyncConnectable:
+ __slots__ = "_slots_dispatch"
+
+
+@util.create_proxy_methods(
+ Connection,
+ ":class:`_future.Connection`",
+ ":class:`_asyncio.AsyncConnection`",
+ classmethods=[],
+ methods=[],
+ attributes=[
+ "closed",
+ "invalidated",
+ "dialect",
+ "default_isolation_level",
+ ],
+)
+class AsyncConnection(StartableContext, AsyncConnectable):
"""An asyncio proxy for a :class:`_engine.Connection`.
:class:`_asyncio.AsyncConnection` is acquired using the
@@ -58,15 +74,23 @@ class AsyncConnection(StartableContext):
""" # noqa
+ # AsyncConnection is a thin proxy; no state should be added here
+ # that is not retrievable from the "sync" engine / connection, e.g.
+ # current transaction, info, etc. It should be possible to
+ # create a new AsyncConnection that matches this one given only the
+ # "sync" elements.
__slots__ = (
"sync_engine",
"sync_connection",
)
def __init__(
- self, sync_engine: Engine, sync_connection: Optional[Connection] = None
+ self,
+ async_engine: "AsyncEngine",
+ sync_connection: Optional[Connection] = None,
):
- self.sync_engine = sync_engine
+ self.engine = async_engine
+ self.sync_engine = async_engine.sync_engine
self.sync_connection = sync_connection
async def start(self):
@@ -79,6 +103,34 @@ class AsyncConnection(StartableContext):
self.sync_connection = await (greenlet_spawn(self.sync_engine.connect))
return self
+ @property
+ def connection(self):
+ """Not implemented for async; call
+ :meth:`_asyncio.AsyncConnection.get_raw_connection`.
+
+ """
+ raise exc.InvalidRequestError(
+ "AsyncConnection.connection accessor is not implemented as the "
+ "attribute may need to reconnect on an invalidated connection. "
+ "Use the get_raw_connection() method."
+ )
+
+ async def get_raw_connection(self):
+ """Return the pooled DBAPI-level connection in use by this
+ :class:`_asyncio.AsyncConnection`.
+
+ This is typically the SQLAlchemy connection-pool proxied connection
+ which then has an attribute .connection that refers to the actual
+ DBAPI-level connection.
+ """
+ conn = self._sync_connection()
+
+ return await greenlet_spawn(getattr, conn, "connection")
+
+ @property
+ def _proxied(self):
+ return self.sync_connection
+
def _sync_connection(self):
if not self.sync_connection:
self._raise_for_not_started()
@@ -94,6 +146,43 @@ class AsyncConnection(StartableContext):
self._sync_connection()
return AsyncTransaction(self, nested=True)
+ async def invalidate(self, exception=None):
+ """Invalidate the underlying DBAPI connection associated with
+ this :class:`_engine.Connection`.
+
+ See the method :meth:`_engine.Connection.invalidate` for full
+ detail on this method.
+
+ """
+
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.invalidate, exception=exception)
+
+ async def get_isolation_level(self):
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.get_isolation_level)
+
+ async def set_isolation_level(self):
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.get_isolation_level)
+
+ async def execution_options(self, **opt):
+ r"""Set non-SQL options for the connection which take effect
+ during execution.
+
+ This returns this :class:`_asyncio.AsyncConnection` object with
+ the new options added.
+
+ See :meth:`_future.Connection.execution_options` for full details
+ on this method.
+
+ """
+
+ conn = self._sync_connection()
+ c2 = await greenlet_spawn(conn.execution_options, **opt)
+ assert c2 is conn
+ return self
+
async def commit(self):
"""Commit the transaction that is currently in progress.
@@ -287,7 +376,19 @@ class AsyncConnection(StartableContext):
await self.close()
-class AsyncEngine:
+@util.create_proxy_methods(
+ Engine,
+ ":class:`_future.Engine`",
+ ":class:`_asyncio.AsyncEngine`",
+ classmethods=[],
+ methods=[
+ "clear_compiled_cache",
+ "update_execution_options",
+ "get_execution_options",
+ ],
+ attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
+)
+class AsyncEngine(AsyncConnectable):
"""An asyncio proxy for a :class:`_engine.Engine`.
:class:`_asyncio.AsyncEngine` is acquired using the
@@ -301,7 +402,12 @@ class AsyncEngine:
""" # noqa
- __slots__ = ("sync_engine",)
+ # AsyncEngine is a thin proxy; no state should be added here
+ # that is not retrievable from the "sync" engine / connection, e.g.
+ # current transaction, info, etc. It should be possible to
+ # create a new AsyncEngine that matches this one given only the
+ # "sync" elements.
+ __slots__ = ("sync_engine", "_proxied")
_connection_cls = AsyncConnection
@@ -327,7 +433,7 @@ class AsyncEngine:
await self.conn.close()
def __init__(self, sync_engine: Engine):
- self.sync_engine = sync_engine
+ self.sync_engine = self._proxied = sync_engine
def begin(self):
"""Return a context manager which when entered will deliver an
@@ -363,7 +469,7 @@ class AsyncEngine:
"""
- return self._connection_cls(self.sync_engine)
+ return self._connection_cls(self)
async def raw_connection(self) -> Any:
"""Return a "raw" DBAPI connection from the connection pool.
@@ -375,12 +481,33 @@ class AsyncEngine:
"""
return await greenlet_spawn(self.sync_engine.raw_connection)
+ def execution_options(self, **opt):
+ """Return a new :class:`_asyncio.AsyncEngine` that will provide
+ :class:`_asyncio.AsyncConnection` objects with the given execution
+ options.
+
+ Proxied from :meth:`_future.Engine.execution_options`. See that
+ method for details.
+
+ """
+
+ return AsyncEngine(self.sync_engine.execution_options(**opt))
-class AsyncOptionEngine(OptionEngineMixin, AsyncEngine):
- pass
+ async def dispose(self):
+ """Dispose of the connection pool used by this
+ :class:`_asyncio.AsyncEngine`.
+ This will close all connection pool connections that are
+ **currently checked in**. See the documentation for the underlying
+ :meth:`_future.Engine.dispose` method for further notes.
+
+ .. seealso::
+
+ :meth:`_future.Engine.dispose`
+
+ """
-AsyncEngine._option_cls = AsyncOptionEngine
+ return await greenlet_spawn(self.sync_engine.dispose)
class AsyncTransaction(StartableContext):
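On the engine side, a brief hedged sketch of the two additions above (placeholder URL): execution_options() returns a new AsyncEngine wrapping a new sync engine, and dispose() is awaitable because closing pooled connections may perform I/O:

    from sqlalchemy.ext.asyncio import create_async_engine


    async def engine_options_example():
        engine = create_async_engine(
            "postgresql+asyncpg://scott:tiger@localhost/test"  # placeholder
        )

        # a new AsyncEngine proxying a new sync engine with the options applied
        autocommit_engine = engine.execution_options(isolation_level="AUTOCOMMIT")
        assert autocommit_engine is not engine
        assert autocommit_engine.sync_engine is not engine.sync_engine

        # closes connections that are currently checked in to the pool
        await engine.dispose()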
diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py
new file mode 100644
index 000000000..a8daefc4b
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/events.py
@@ -0,0 +1,29 @@
+from .engine import AsyncConnectable
+from .session import AsyncSession
+from ...engine import events as engine_event
+from ...orm import events as orm_event
+
+
+class AsyncConnectionEvents(engine_event.ConnectionEvents):
+ _target_class_doc = "SomeEngine"
+ _dispatch_target = AsyncConnectable
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncEngine.sync_engine or "
+ "AsyncConnection.sync_connection attributes."
+ )
+
+
+class AsyncSessionEvents(orm_event.SessionEvents):
+ _target_class_doc = "SomeSession"
+ _dispatch_target = AsyncSession
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncSession.sync_session."
+ )
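A hedged sketch of the workaround both error messages point to (placeholder URL): register ordinary synchronous listeners against the proxied sync objects:

    from sqlalchemy import event
    from sqlalchemy.ext.asyncio import create_async_engine

    engine = create_async_engine(
        "postgresql+asyncpg://scott:tiger@localhost/test"  # placeholder
    )


    # listeners attach to the proxied sync Engine, not the AsyncEngine itself
    @event.listens_for(engine.sync_engine, "connect")
    def on_connect(dbapi_connection, connection_record):
        # runs synchronously, in the worker greenlet, for each new DBAPI connection
        print("new DBAPI connection:", dbapi_connection)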
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
index cb06aa26d..4ae1fb385 100644
--- a/lib/sqlalchemy/ext/asyncio/session.py
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -1,6 +1,5 @@
from typing import Any
from typing import Callable
-from typing import List
from typing import Mapping
from typing import Optional
@@ -15,6 +14,35 @@ from ...sql import Executable
from ...util.concurrency import greenlet_spawn
+@util.create_proxy_methods(
+ Session,
+ ":class:`_orm.Session`",
+ ":class:`_asyncio.AsyncSession`",
+ classmethods=["object_session", "identity_key"],
+ methods=[
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "delete",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "get_bind",
+ "is_modified",
+ ],
+ attributes=[
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ ],
+)
class AsyncSession:
"""Asyncio version of :class:`_orm.Session`.
@@ -23,6 +51,16 @@ class AsyncSession:
"""
+ __slots__ = (
+ "binds",
+ "bind",
+ "sync_session",
+ "_proxied",
+ "_slots_dispatch",
+ )
+
+ dispatch = None
+
def __init__(
self,
bind: AsyncEngine = None,
@@ -31,46 +69,18 @@ class AsyncSession:
):
kw["future"] = True
if bind:
+ self.bind = engine
bind = engine._get_sync_engine(bind)
if binds:
+ self.binds = binds
binds = {
key: engine._get_sync_engine(b) for key, b in binds.items()
}
- self.sync_session = Session(bind=bind, binds=binds, **kw)
-
- def add(self, instance: object) -> None:
- """Place an object in this :class:`_asyncio.AsyncSession`.
-
- .. seealso::
-
- :meth:`_orm.Session.add`
-
- """
- self.sync_session.add(instance)
-
- def add_all(self, instances: List[object]) -> None:
- """Add the given collection of instances to this
- :class:`_asyncio.AsyncSession`."""
-
- self.sync_session.add_all(instances)
-
- def expire_all(self):
- """Expires all persistent instances within this Session.
-
- See :meth:`_orm.Session.expire_all` for usage details.
-
- """
- self.sync_session.expire_all()
-
- def expire(self, instance, attribute_names=None):
- """Expire the attributes on an instance.
-
- See :meth:`._orm.Session.expire` for usage details.
-
- """
- self.sync_session.expire()
+ self.sync_session = self._proxied = Session(
+ bind=bind, binds=binds, **kw
+ )
async def refresh(
self, instance, attribute_names=None, with_for_update=None
@@ -178,8 +188,17 @@ class AsyncSession:
:class:`.Session` object's transactional state.
"""
+
+ # POSSIBLY TODO: here, we see that the sync engine / connection
+ # that are generated from AsyncEngine / AsyncConnection don't
+ # provide any backlink from those sync objects back out to the
+ # async ones. it's not *too* big a deal since AsyncEngine/Connection
+ # are just proxies and all the state is actually in the sync
+ # version of things. However! it has to stay that way :)
sync_connection = await greenlet_spawn(self.sync_session.connection)
- return engine.AsyncConnection(sync_connection.engine, sync_connection)
+ return engine.AsyncConnection(
+ engine.AsyncEngine(sync_connection.engine), sync_connection
+ )
def begin(self, **kw):
"""Return an :class:`_asyncio.AsyncSessionTransaction` object.
@@ -218,14 +237,22 @@ class AsyncSession:
return AsyncSessionTransaction(self, nested=True)
async def rollback(self):
+ """Rollback the current transaction in progress."""
return await greenlet_spawn(self.sync_session.rollback)
async def commit(self):
+ """Commit the current transaction in progress."""
return await greenlet_spawn(self.sync_session.commit)
async def close(self):
+ """Close this :class:`_asyncio.AsyncSession`."""
return await greenlet_spawn(self.sync_session.close)
+ @classmethod
+ async def close_all(self):
+ """Close all :class:`_asyncio.AsyncSession` sessions."""
+ return await greenlet_spawn(self.sync_session.close_all)
+
async def __aenter__(self):
return self
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
index 29a509cb9..4e11ebb8c 100644
--- a/lib/sqlalchemy/orm/events.py
+++ b/lib/sqlalchemy/orm/events.py
@@ -1371,7 +1371,8 @@ class SessionEvents(event.Events):
elif isinstance(target, Session):
return target
else:
- return None
+ # allows alternate SessionEvents-like-classes to be consulted
+ return event.Events._accept_with(target)
@classmethod
def _listen(cls, event_key, raw=False, restore_load_context=False, **kw):
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
index 1090501ca..29d845c0a 100644
--- a/lib/sqlalchemy/orm/scoping.py
+++ b/lib/sqlalchemy/orm/scoping.py
@@ -9,14 +9,60 @@ from . import class_mapper
from . import exc as orm_exc
from .session import Session
from .. import exc as sa_exc
+from ..util import create_proxy_methods
from ..util import ScopedRegistry
from ..util import ThreadLocalRegistry
from ..util import warn
-
__all__ = ["scoped_session"]
+@create_proxy_methods(
+ Session,
+ ":class:`_orm.Session`",
+ ":class:`_orm.scoping.scoped_session`",
+ classmethods=["close_all", "object_session", "identity_key"],
+ methods=[
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "begin",
+ "begin_nested",
+ "close",
+ "commit",
+ "connection",
+ "delete",
+ "execute",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "flush",
+ "get_bind",
+ "is_modified",
+ "bulk_save_objects",
+ "bulk_insert_mappings",
+ "bulk_update_mappings",
+ "merge",
+ "query",
+ "refresh",
+ "rollback",
+ "scalar",
+ ],
+ attributes=[
+ "bind",
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ "autocommit",
+ ],
+)
class scoped_session(object):
"""Provides scoped management of :class:`.Session` objects.
@@ -53,6 +99,10 @@ class scoped_session(object):
else:
self.registry = ThreadLocalRegistry(session_factory)
+ @property
+ def _proxied(self):
+ return self.registry()
+
def __call__(self, **kw):
r"""Return the current :class:`.Session`, creating it
using the :attr:`.scoped_session.session_factory` if not present.
@@ -156,50 +206,3 @@ class scoped_session(object):
ScopedSession = scoped_session
"""Old name for backwards compatibility."""
-
-
-def instrument(name):
- def do(self, *args, **kwargs):
- return getattr(self.registry(), name)(*args, **kwargs)
-
- return do
-
-
-for meth in Session.public_methods:
- setattr(scoped_session, meth, instrument(meth))
-
-
-def makeprop(name):
- def set_(self, attr):
- setattr(self.registry(), name, attr)
-
- def get(self):
- return getattr(self.registry(), name)
-
- return property(get, set_)
-
-
-for prop in (
- "bind",
- "dirty",
- "deleted",
- "new",
- "identity_map",
- "is_active",
- "autoflush",
- "no_autoflush",
- "info",
- "autocommit",
-):
- setattr(scoped_session, prop, makeprop(prop))
-
-
-def clslevel(name):
- def do(cls, *args, **kwargs):
- return getattr(Session, name)(*args, **kwargs)
-
- return classmethod(do)
-
-
-for prop in ("close_all", "object_session", "identity_key"):
- setattr(scoped_session, prop, clslevel(prop))
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index e32e05510..af0ac63e0 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -835,35 +835,6 @@ class Session(_SessionClassMethods):
"""
- public_methods = (
- "__contains__",
- "__iter__",
- "add",
- "add_all",
- "begin",
- "begin_nested",
- "close",
- "commit",
- "connection",
- "delete",
- "execute",
- "expire",
- "expire_all",
- "expunge",
- "expunge_all",
- "flush",
- "get_bind",
- "is_modified",
- "bulk_save_objects",
- "bulk_insert_mappings",
- "bulk_update_mappings",
- "merge",
- "query",
- "refresh",
- "rollback",
- "scalar",
- )
-
@util.deprecated_params(
autocommit=(
"2.0",
@@ -3028,7 +2999,14 @@ class Session(_SessionClassMethods):
will unexpire attributes on access.
"""
- state = attributes.instance_state(obj)
+ try:
+ state = attributes.instance_state(obj)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(obj),
+ replace_context=err,
+ )
+
to_attach = self._before_attach(state, obj)
state._load_pending = True
if to_attach:
diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py
index 87383fef7..68fa5fe85 100644
--- a/lib/sqlalchemy/pool/base.py
+++ b/lib/sqlalchemy/pool/base.py
@@ -509,6 +509,7 @@ class _ConnectionRecord(object):
"Soft " if soft else "",
self.connection,
)
+
if soft:
self._soft_invalidate_time = time.time()
else:
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index dfefd3b95..644ea6dc2 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -372,7 +372,7 @@ def _pytest_fn_decorator(target):
if add_positional_parameters:
spec.args.extend(add_positional_parameters)
- metadata = dict(target="target", fn="fn", name=fn.__name__)
+ metadata = dict(target="target", fn="__fn", name=fn.__name__)
metadata.update(format_argspec_plus(spec, grouped=False))
code = (
"""\
@@ -382,7 +382,7 @@ def %(name)s(%(args)s):
% metadata
)
decorated = _exec_code_in_env(
- code, {"target": target, "fn": fn}, fn.__name__
+ code, {"target": target, "__fn": fn}, fn.__name__
)
if not add_positional_parameters:
decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 8ef2f0103..885f62f97 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -123,6 +123,7 @@ from .langhelpers import coerce_kw_type # noqa
from .langhelpers import constructor_copy # noqa
from .langhelpers import constructor_key # noqa
from .langhelpers import counter # noqa
+from .langhelpers import create_proxy_methods # noqa
from .langhelpers import decode_slice # noqa
from .langhelpers import decorator # noqa
from .langhelpers import dictlike_iteritems # noqa
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index e546f196d..4289db812 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -9,6 +9,7 @@
modules, classes, hierarchies, attributes, functions, and methods.
"""
+
from functools import update_wrapper
import hashlib
import inspect
@@ -462,6 +463,8 @@ def format_argspec_plus(fn, grouped=True):
passed positionally.
apply_kw
Like apply_pos, except keyword-ish args are passed as keywords.
+ apply_pos_proxied
+ Like apply_pos but omits the self/cls argument
Example::
@@ -478,16 +481,27 @@ def format_argspec_plus(fn, grouped=True):
spec = fn
args = compat.inspect_formatargspec(*spec)
+
+ apply_pos = compat.inspect_formatargspec(
+ spec[0], spec[1], spec[2], None, spec[4]
+ )
+
if spec[0]:
self_arg = spec[0][0]
+
+ apply_pos_proxied = compat.inspect_formatargspec(
+ spec[0][1:], spec[1], spec[2], None, spec[4]
+ )
+
elif spec[1]:
+ # im not sure what this is
self_arg = "%s[0]" % spec[1]
+
+ apply_pos_proxied = apply_pos
else:
self_arg = None
+ apply_pos_proxied = apply_pos
- apply_pos = compat.inspect_formatargspec(
- spec[0], spec[1], spec[2], None, spec[4]
- )
num_defaults = 0
if spec[3]:
num_defaults += len(spec[3])
@@ -513,6 +527,7 @@ def format_argspec_plus(fn, grouped=True):
self_arg=self_arg,
apply_pos=apply_pos,
apply_kw=apply_kw,
+ apply_pos_proxied=apply_pos_proxied,
)
else:
return dict(
@@ -520,6 +535,7 @@ def format_argspec_plus(fn, grouped=True):
self_arg=self_arg,
apply_pos=apply_pos[1:-1],
apply_kw=apply_kw[1:-1],
+ apply_pos_proxied=apply_pos_proxied[1:-1],
)
@@ -534,17 +550,140 @@ def format_argspec_init(method, grouped=True):
"""
if method is object.__init__:
- args = grouped and "(self)" or "self"
+ args = "(self)" if grouped else "self"
+ proxied = "()" if grouped else ""
else:
try:
return format_argspec_plus(method, grouped=grouped)
except TypeError:
args = (
- grouped
- and "(self, *args, **kwargs)"
- or "self, *args, **kwargs"
+ "(self, *args, **kwargs)"
+ if grouped
+ else "self, *args, **kwargs"
)
- return dict(self_arg="self", args=args, apply_pos=args, apply_kw=args)
+ proxied = "(*args, **kwargs)" if grouped else "*args, **kwargs"
+ return dict(
+ self_arg="self",
+ args=args,
+ apply_pos=args,
+ apply_kw=args,
+ apply_pos_proxied=proxied,
+ )
+
+
+def create_proxy_methods(
+ target_cls,
+ target_cls_sphinx_name,
+ proxy_cls_sphinx_name,
+ classmethods=(),
+ methods=(),
+ attributes=(),
+):
+ """A class decorator that will copy attributes to a proxy class.
+
+ The class to be instrumented must define a single accessor "_proxied".
+
+ """
+
+ def decorate(cls):
+ def instrument(name, clslevel=False):
+ fn = getattr(target_cls, name)
+ spec = compat.inspect_getfullargspec(fn)
+ env = {}
+
+ spec = _update_argspec_defaults_into_env(spec, env)
+ caller_argspec = format_argspec_plus(spec, grouped=False)
+
+ metadata = {
+ "name": fn.__name__,
+ "apply_pos_proxied": caller_argspec["apply_pos_proxied"],
+ "args": caller_argspec["args"],
+ "self_arg": caller_argspec["self_arg"],
+ }
+
+ if clslevel:
+ code = (
+ "def %(name)s(%(args)s):\n"
+ " return target_cls.%(name)s(%(apply_pos_proxied)s)"
+ % metadata
+ )
+ env["target_cls"] = target_cls
+ else:
+ code = (
+ "def %(name)s(%(args)s):\n"
+ " return %(self_arg)s._proxied.%(name)s(%(apply_pos_proxied)s)" # noqa E501
+ % metadata
+ )
+
+ proxy_fn = _exec_code_in_env(code, env, fn.__name__)
+ proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__
+ proxy_fn.__doc__ = inject_docstring_text(
+ fn.__doc__,
+ ".. container:: class_bases\n\n "
+ "Proxied for the %s class on behalf of the %s class."
+ % (target_cls_sphinx_name, proxy_cls_sphinx_name),
+ 1,
+ )
+
+ if clslevel:
+ proxy_fn = classmethod(proxy_fn)
+
+ return proxy_fn
+
+ def makeprop(name):
+ attr = target_cls.__dict__.get(name, None)
+
+ if attr is not None:
+ doc = inject_docstring_text(
+ attr.__doc__,
+ ".. container:: class_bases\n\n "
+ "Proxied for the %s class on behalf of the %s class."
+ % (
+ target_cls_sphinx_name,
+ proxy_cls_sphinx_name,
+ ),
+ 1,
+ )
+ else:
+ doc = None
+
+ code = (
+ "def set_(self, attr):\n"
+ " self._proxied.%(name)s = attr\n"
+ "def get(self):\n"
+ " return self._proxied.%(name)s\n"
+ "get.__doc__ = doc\n"
+ "getset = property(get, set_)"
+ ) % {"name": name}
+
+ getset = _exec_code_in_env(code, {"doc": doc}, "getset")
+
+ return getset
+
+ for meth in methods:
+ if hasattr(cls, meth):
+ raise TypeError(
+ "class %s already has a method %s" % (cls, meth)
+ )
+ setattr(cls, meth, instrument(meth))
+
+ for prop in attributes:
+ if hasattr(cls, prop):
+ raise TypeError(
+ "class %s already has a method %s" % (cls, prop)
+ )
+ setattr(cls, prop, makeprop(prop))
+
+ for prop in classmethods:
+ if hasattr(cls, prop):
+ raise TypeError(
+ "class %s already has a method %s" % (cls, prop)
+ )
+ setattr(cls, prop, instrument(prop, clslevel=True))
+
+ return cls
+
+ return decorate
def getargspec_init(method):
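Finally, a minimal sketch of the create_proxy_methods() contract using toy classes (not library code, and assuming the helper behaves as shown above): the decorated class only needs a _proxied accessor, and the generated members delegate through it while keeping the target's signatures:

    from sqlalchemy import util


    class Target:
        """Class whose interface is copied onto the proxy."""

        def greet(self, name, punctuation="!"):
            """Return a greeting."""
            return "hello %s%s" % (name, punctuation)

        @property
        def status(self):
            """Current status string."""
            return "ready"


    @util.create_proxy_methods(
        Target,
        ":class:`Target`",
        ":class:`Proxy`",
        methods=["greet"],
        attributes=["status"],
    )
    class Proxy:
        def __init__(self, target):
            self._proxied = target


    p = Proxy(Target())
    assert p.greet("world") == "hello world!"  # delegates with the real signature
    assert p.status == "ready"                 # generated property reads _proxied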