diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-10-08 15:20:48 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-10-10 01:17:25 -0400 |
commit | 2665a0c4cb3e94e6545d0b9bbcbcc39ccffebaba (patch) | |
tree | ed25383ce7e5899d7d643a11df0f8aee9f2ab959 /lib | |
parent | bcc17b1d6e2cac3b0e45c0b17a62cf2d5fc5c5ab (diff) | |
download | sqlalchemy-2665a0c4cb3e94e6545d0b9bbcbcc39ccffebaba.tar.gz |
generalize scoped_session proxying and apply to asyncio elements
Reworked the proxy creation used by scoped_session() to be
based on fully copied code with augmented docstrings and
moved it into langhelpers. asyncio session, engine,
connection can now take
advantage of it so that all non-async methods are availble.
Overall implementation of most important accessors / methods
on AsyncConnection, etc. , including awaitable versions
of invalidate, execution_options, etc.
In order to support an event dispatcher on the async
classes while still allowing them to hold __slots__,
make some adjustments to the event system to allow
that to be present, at least rudimentally.
Fixes: #5628
Change-Id: I5eb6929fc1e4fdac99e4b767dcfd49672d56e2b2
Diffstat (limited to 'lib')
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 30 | ||||
-rw-r--r-- | lib/sqlalchemy/event/base.py | 35 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/__init__.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/engine.py | 153 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/events.py | 29 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/session.py | 97 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/events.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/scoping.py | 99 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 38 | ||||
-rw-r--r-- | lib/sqlalchemy/pool/base.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/plugin/pytestplugin.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/util/__init__.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/util/langhelpers.py | 155 |
13 files changed, 497 insertions, 150 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 9a6bdd7f3..4fbdec145 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -184,14 +184,17 @@ class Connection(Connectable): r""" Set non-SQL options for the connection which take effect during execution. - The method returns a copy of this :class:`_engine.Connection` - which references - the same underlying DBAPI connection, but also defines the given - execution options which will take effect for a call to - :meth:`execute`. As the new :class:`_engine.Connection` - references the same - underlying resource, it's usually a good idea to ensure that the copies - will be discarded immediately, which is implicit if used as in:: + For a "future" style connection, this method returns this same + :class:`_future.Connection` object with the new options added. + + For a legacy connection, this method returns a copy of this + :class:`_engine.Connection` which references the same underlying DBAPI + connection, but also defines the given execution options which will + take effect for a call to + :meth:`execute`. As the new :class:`_engine.Connection` references the + same underlying resource, it's usually a good idea to ensure that + the copies will be discarded immediately, which is implicit if used + as in:: result = connection.execution_options(stream_results=True).\ execute(stmt) @@ -549,9 +552,10 @@ class Connection(Connectable): """Invalidate the underlying DBAPI connection associated with this :class:`_engine.Connection`. - The underlying DBAPI connection is literally closed (if - possible), and is discarded. Its source connection pool will - typically lazily create a new connection to replace it. + An attempt will be made to close the underlying DBAPI connection + immediately; however if this operation fails, the error is logged + but not raised. The connection is then discarded whether or not + close() succeeded. Upon the next use (where "use" typically means using the :meth:`_engine.Connection.execute` method or similar), @@ -580,6 +584,10 @@ class Connection(Connectable): will at the connection pool level invoke the :meth:`_events.PoolEvents.invalidate` event. + :param exception: an optional ``Exception`` instance that's the + reason for the invalidation. is passed along to event handlers + and logging functions. + .. seealso:: :ref:`pool_connection_invalidation` diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index daa6f9aea..1ba88f3d2 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -195,7 +195,14 @@ def _create_dispatcher_class(cls, classname, bases, dict_): dispatch_cls._event_names.append(ls.name) if getattr(cls, "_dispatch_target", None): - cls._dispatch_target.dispatch = dispatcher(cls) + the_cls = cls._dispatch_target + if ( + hasattr(the_cls, "__slots__") + and "_slots_dispatch" in the_cls.__slots__ + ): + cls._dispatch_target.dispatch = slots_dispatcher(cls) + else: + cls._dispatch_target.dispatch = dispatcher(cls) def _remove_dispatcher(cls): @@ -304,5 +311,29 @@ class dispatcher(object): def __get__(self, obj, cls): if obj is None: return self.dispatch - obj.__dict__["dispatch"] = disp = self.dispatch._for_instance(obj) + + disp = self.dispatch._for_instance(obj) + try: + obj.__dict__["dispatch"] = disp + except AttributeError as ae: + util.raise_( + TypeError( + "target %r doesn't have __dict__, should it be " + "defining _slots_dispatch?" % (obj,) + ), + replace_context=ae, + ) + return disp + + +class slots_dispatcher(dispatcher): + def __get__(self, obj, cls): + if obj is None: + return self.dispatch + + if hasattr(obj, "_slots_dispatch"): + return obj._slots_dispatch + + disp = self.dispatch._for_instance(obj) + obj._slots_dispatch = disp return disp diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py index fbbc958d4..9c7d6443c 100644 --- a/lib/sqlalchemy/ext/asyncio/__init__.py +++ b/lib/sqlalchemy/ext/asyncio/__init__.py @@ -2,6 +2,8 @@ from .engine import AsyncConnection # noqa from .engine import AsyncEngine # noqa from .engine import AsyncTransaction # noqa from .engine import create_async_engine # noqa +from .events import AsyncConnectionEvents # noqa +from .events import AsyncSessionEvents # noqa from .result import AsyncMappingResult # noqa from .result import AsyncResult # noqa from .result import AsyncScalarResult # noqa diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 4a92fb1f2..9e4851dfc 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -8,12 +8,11 @@ from .base import StartableContext from .result import AsyncResult from ... import exc from ... import util -from ...engine import Connection from ...engine import create_engine as _create_engine -from ...engine import Engine from ...engine import Result from ...engine import Transaction -from ...engine.base import OptionEngineMixin +from ...future import Connection +from ...future import Engine from ...sql import Executable from ...util.concurrency import greenlet_spawn @@ -41,7 +40,24 @@ def create_async_engine(*arg, **kw): return AsyncEngine(sync_engine) -class AsyncConnection(StartableContext): +class AsyncConnectable: + __slots__ = "_slots_dispatch" + + +@util.create_proxy_methods( + Connection, + ":class:`_future.Connection`", + ":class:`_asyncio.AsyncConnection`", + classmethods=[], + methods=[], + attributes=[ + "closed", + "invalidated", + "dialect", + "default_isolation_level", + ], +) +class AsyncConnection(StartableContext, AsyncConnectable): """An asyncio proxy for a :class:`_engine.Connection`. :class:`_asyncio.AsyncConnection` is acquired using the @@ -58,15 +74,23 @@ class AsyncConnection(StartableContext): """ # noqa + # AsyncConnection is a thin proxy; no state should be added here + # that is not retrievable from the "sync" engine / connection, e.g. + # current transaction, info, etc. It should be possible to + # create a new AsyncConnection that matches this one given only the + # "sync" elements. __slots__ = ( "sync_engine", "sync_connection", ) def __init__( - self, sync_engine: Engine, sync_connection: Optional[Connection] = None + self, + async_engine: "AsyncEngine", + sync_connection: Optional[Connection] = None, ): - self.sync_engine = sync_engine + self.engine = async_engine + self.sync_engine = async_engine.sync_engine self.sync_connection = sync_connection async def start(self): @@ -79,6 +103,34 @@ class AsyncConnection(StartableContext): self.sync_connection = await (greenlet_spawn(self.sync_engine.connect)) return self + @property + def connection(self): + """Not implemented for async; call + :meth:`_asyncio.AsyncConnection.get_raw_connection`. + + """ + raise exc.InvalidRequestError( + "AsyncConnection.connection accessor is not implemented as the " + "attribute may need to reconnect on an invalidated connection. " + "Use the get_raw_connection() method." + ) + + async def get_raw_connection(self): + """Return the pooled DBAPI-level connection in use by this + :class:`_asyncio.AsyncConnection`. + + This is typically the SQLAlchemy connection-pool proxied connection + which then has an attribute .connection that refers to the actual + DBAPI-level connection. + """ + conn = self._sync_connection() + + return await greenlet_spawn(getattr, conn, "connection") + + @property + def _proxied(self): + return self.sync_connection + def _sync_connection(self): if not self.sync_connection: self._raise_for_not_started() @@ -94,6 +146,43 @@ class AsyncConnection(StartableContext): self._sync_connection() return AsyncTransaction(self, nested=True) + async def invalidate(self, exception=None): + """Invalidate the underlying DBAPI connection associated with + this :class:`_engine.Connection`. + + See the method :meth:`_engine.Connection.invalidate` for full + detail on this method. + + """ + + conn = self._sync_connection() + return await greenlet_spawn(conn.invalidate, exception=exception) + + async def get_isolation_level(self): + conn = self._sync_connection() + return await greenlet_spawn(conn.get_isolation_level) + + async def set_isolation_level(self): + conn = self._sync_connection() + return await greenlet_spawn(conn.get_isolation_level) + + async def execution_options(self, **opt): + r"""Set non-SQL options for the connection which take effect + during execution. + + This returns this :class:`_asyncio.AsyncConnection` object with + the new options added. + + See :meth:`_future.Connection.execution_options` for full details + on this method. + + """ + + conn = self._sync_connection() + c2 = await greenlet_spawn(conn.execution_options, **opt) + assert c2 is conn + return self + async def commit(self): """Commit the transaction that is currently in progress. @@ -287,7 +376,19 @@ class AsyncConnection(StartableContext): await self.close() -class AsyncEngine: +@util.create_proxy_methods( + Engine, + ":class:`_future.Engine`", + ":class:`_asyncio.AsyncEngine`", + classmethods=[], + methods=[ + "clear_compiled_cache", + "update_execution_options", + "get_execution_options", + ], + attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"], +) +class AsyncEngine(AsyncConnectable): """An asyncio proxy for a :class:`_engine.Engine`. :class:`_asyncio.AsyncEngine` is acquired using the @@ -301,7 +402,12 @@ class AsyncEngine: """ # noqa - __slots__ = ("sync_engine",) + # AsyncEngine is a thin proxy; no state should be added here + # that is not retrievable from the "sync" engine / connection, e.g. + # current transaction, info, etc. It should be possible to + # create a new AsyncEngine that matches this one given only the + # "sync" elements. + __slots__ = ("sync_engine", "_proxied") _connection_cls = AsyncConnection @@ -327,7 +433,7 @@ class AsyncEngine: await self.conn.close() def __init__(self, sync_engine: Engine): - self.sync_engine = sync_engine + self.sync_engine = self._proxied = sync_engine def begin(self): """Return a context manager which when entered will deliver an @@ -363,7 +469,7 @@ class AsyncEngine: """ - return self._connection_cls(self.sync_engine) + return self._connection_cls(self) async def raw_connection(self) -> Any: """Return a "raw" DBAPI connection from the connection pool. @@ -375,12 +481,33 @@ class AsyncEngine: """ return await greenlet_spawn(self.sync_engine.raw_connection) + def execution_options(self, **opt): + """Return a new :class:`_asyncio.AsyncEngine` that will provide + :class:`_asyncio.AsyncConnection` objects with the given execution + options. + + Proxied from :meth:`_future.Engine.execution_options`. See that + method for details. + + """ + + return AsyncEngine(self.sync_engine.execution_options(**opt)) -class AsyncOptionEngine(OptionEngineMixin, AsyncEngine): - pass + async def dispose(self): + """Dispose of the connection pool used by this + :class:`_asyncio.AsyncEngine`. + This will close all connection pool connections that are + **currently checked in**. See the documentation for the underlying + :meth:`_future.Engine.dispose` method for further notes. + + .. seealso:: + + :meth:`_future.Engine.dispose` + + """ -AsyncEngine._option_cls = AsyncOptionEngine + return await greenlet_spawn(self.sync_engine.dispose) class AsyncTransaction(StartableContext): diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py new file mode 100644 index 000000000..a8daefc4b --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/events.py @@ -0,0 +1,29 @@ +from .engine import AsyncConnectable +from .session import AsyncSession +from ...engine import events as engine_event +from ...orm import events as orm_event + + +class AsyncConnectionEvents(engine_event.ConnectionEvents): + _target_class_doc = "SomeEngine" + _dispatch_target = AsyncConnectable + + @classmethod + def _listen(cls, event_key, retval=False): + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncEngine.sync_engine or " + "AsyncConnection.sync_connection attributes." + ) + + +class AsyncSessionEvents(orm_event.SessionEvents): + _target_class_doc = "SomeSession" + _dispatch_target = AsyncSession + + @classmethod + def _listen(cls, event_key, retval=False): + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncSession.sync_session." + ) diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index cb06aa26d..4ae1fb385 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -1,6 +1,5 @@ from typing import Any from typing import Callable -from typing import List from typing import Mapping from typing import Optional @@ -15,6 +14,35 @@ from ...sql import Executable from ...util.concurrency import greenlet_spawn +@util.create_proxy_methods( + Session, + ":class:`_orm.Session`", + ":class:`_asyncio.AsyncSession`", + classmethods=["object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "add", + "add_all", + "delete", + "expire", + "expire_all", + "expunge", + "expunge_all", + "get_bind", + "is_modified", + ], + attributes=[ + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + ], +) class AsyncSession: """Asyncio version of :class:`_orm.Session`. @@ -23,6 +51,16 @@ class AsyncSession: """ + __slots__ = ( + "binds", + "bind", + "sync_session", + "_proxied", + "_slots_dispatch", + ) + + dispatch = None + def __init__( self, bind: AsyncEngine = None, @@ -31,46 +69,18 @@ class AsyncSession: ): kw["future"] = True if bind: + self.bind = engine bind = engine._get_sync_engine(bind) if binds: + self.binds = binds binds = { key: engine._get_sync_engine(b) for key, b in binds.items() } - self.sync_session = Session(bind=bind, binds=binds, **kw) - - def add(self, instance: object) -> None: - """Place an object in this :class:`_asyncio.AsyncSession`. - - .. seealso:: - - :meth:`_orm.Session.add` - - """ - self.sync_session.add(instance) - - def add_all(self, instances: List[object]) -> None: - """Add the given collection of instances to this - :class:`_asyncio.AsyncSession`.""" - - self.sync_session.add_all(instances) - - def expire_all(self): - """Expires all persistent instances within this Session. - - See :meth:`_orm.Session.expire_all` for usage details. - - """ - self.sync_session.expire_all() - - def expire(self, instance, attribute_names=None): - """Expire the attributes on an instance. - - See :meth:`._orm.Session.expire` for usage details. - - """ - self.sync_session.expire() + self.sync_session = self._proxied = Session( + bind=bind, binds=binds, **kw + ) async def refresh( self, instance, attribute_names=None, with_for_update=None @@ -178,8 +188,17 @@ class AsyncSession: :class:`.Session` object's transactional state. """ + + # POSSIBLY TODO: here, we see that the sync engine / connection + # that are generated from AsyncEngine / AsyncConnection don't + # provide any backlink from those sync objects back out to the + # async ones. it's not *too* big a deal since AsyncEngine/Connection + # are just proxies and all the state is actually in the sync + # version of things. However! it has to stay that way :) sync_connection = await greenlet_spawn(self.sync_session.connection) - return engine.AsyncConnection(sync_connection.engine, sync_connection) + return engine.AsyncConnection( + engine.AsyncEngine(sync_connection.engine), sync_connection + ) def begin(self, **kw): """Return an :class:`_asyncio.AsyncSessionTransaction` object. @@ -218,14 +237,22 @@ class AsyncSession: return AsyncSessionTransaction(self, nested=True) async def rollback(self): + """Rollback the current transaction in progress.""" return await greenlet_spawn(self.sync_session.rollback) async def commit(self): + """Commit the current transaction in progress.""" return await greenlet_spawn(self.sync_session.commit) async def close(self): + """Close this :class:`_asyncio.AsyncSession`.""" return await greenlet_spawn(self.sync_session.close) + @classmethod + async def close_all(self): + """Close all :class:`_asyncio.AsyncSession` sessions.""" + return await greenlet_spawn(self.sync_session.close_all) + async def __aenter__(self): return self diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 29a509cb9..4e11ebb8c 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1371,7 +1371,8 @@ class SessionEvents(event.Events): elif isinstance(target, Session): return target else: - return None + # allows alternate SessionEvents-like-classes to be consulted + return event.Events._accept_with(target) @classmethod def _listen(cls, event_key, raw=False, restore_load_context=False, **kw): diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 1090501ca..29d845c0a 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -9,14 +9,60 @@ from . import class_mapper from . import exc as orm_exc from .session import Session from .. import exc as sa_exc +from ..util import create_proxy_methods from ..util import ScopedRegistry from ..util import ThreadLocalRegistry from ..util import warn - __all__ = ["scoped_session"] +@create_proxy_methods( + Session, + ":class:`_orm.Session`", + ":class:`_orm.scoping.scoped_session`", + classmethods=["close_all", "object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "add", + "add_all", + "begin", + "begin_nested", + "close", + "commit", + "connection", + "delete", + "execute", + "expire", + "expire_all", + "expunge", + "expunge_all", + "flush", + "get_bind", + "is_modified", + "bulk_save_objects", + "bulk_insert_mappings", + "bulk_update_mappings", + "merge", + "query", + "refresh", + "rollback", + "scalar", + ], + attributes=[ + "bind", + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + "autocommit", + ], +) class scoped_session(object): """Provides scoped management of :class:`.Session` objects. @@ -53,6 +99,10 @@ class scoped_session(object): else: self.registry = ThreadLocalRegistry(session_factory) + @property + def _proxied(self): + return self.registry() + def __call__(self, **kw): r"""Return the current :class:`.Session`, creating it using the :attr:`.scoped_session.session_factory` if not present. @@ -156,50 +206,3 @@ class scoped_session(object): ScopedSession = scoped_session """Old name for backwards compatibility.""" - - -def instrument(name): - def do(self, *args, **kwargs): - return getattr(self.registry(), name)(*args, **kwargs) - - return do - - -for meth in Session.public_methods: - setattr(scoped_session, meth, instrument(meth)) - - -def makeprop(name): - def set_(self, attr): - setattr(self.registry(), name, attr) - - def get(self): - return getattr(self.registry(), name) - - return property(get, set_) - - -for prop in ( - "bind", - "dirty", - "deleted", - "new", - "identity_map", - "is_active", - "autoflush", - "no_autoflush", - "info", - "autocommit", -): - setattr(scoped_session, prop, makeprop(prop)) - - -def clslevel(name): - def do(cls, *args, **kwargs): - return getattr(Session, name)(*args, **kwargs) - - return classmethod(do) - - -for prop in ("close_all", "object_session", "identity_key"): - setattr(scoped_session, prop, clslevel(prop)) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index e32e05510..af0ac63e0 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -835,35 +835,6 @@ class Session(_SessionClassMethods): """ - public_methods = ( - "__contains__", - "__iter__", - "add", - "add_all", - "begin", - "begin_nested", - "close", - "commit", - "connection", - "delete", - "execute", - "expire", - "expire_all", - "expunge", - "expunge_all", - "flush", - "get_bind", - "is_modified", - "bulk_save_objects", - "bulk_insert_mappings", - "bulk_update_mappings", - "merge", - "query", - "refresh", - "rollback", - "scalar", - ) - @util.deprecated_params( autocommit=( "2.0", @@ -3028,7 +2999,14 @@ class Session(_SessionClassMethods): will unexpire attributes on access. """ - state = attributes.instance_state(obj) + try: + state = attributes.instance_state(obj) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(obj), + replace_context=err, + ) + to_attach = self._before_attach(state, obj) state._load_pending = True if to_attach: diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 87383fef7..68fa5fe85 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -509,6 +509,7 @@ class _ConnectionRecord(object): "Soft " if soft else "", self.connection, ) + if soft: self._soft_invalidate_time = time.time() else: diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index dfefd3b95..644ea6dc2 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -372,7 +372,7 @@ def _pytest_fn_decorator(target): if add_positional_parameters: spec.args.extend(add_positional_parameters) - metadata = dict(target="target", fn="fn", name=fn.__name__) + metadata = dict(target="target", fn="__fn", name=fn.__name__) metadata.update(format_argspec_plus(spec, grouped=False)) code = ( """\ @@ -382,7 +382,7 @@ def %(name)s(%(args)s): % metadata ) decorated = _exec_code_in_env( - code, {"target": target, "fn": fn}, fn.__name__ + code, {"target": target, "__fn": fn}, fn.__name__ ) if not add_positional_parameters: decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__ diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 8ef2f0103..885f62f97 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -123,6 +123,7 @@ from .langhelpers import coerce_kw_type # noqa from .langhelpers import constructor_copy # noqa from .langhelpers import constructor_key # noqa from .langhelpers import counter # noqa +from .langhelpers import create_proxy_methods # noqa from .langhelpers import decode_slice # noqa from .langhelpers import decorator # noqa from .langhelpers import dictlike_iteritems # noqa diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index e546f196d..4289db812 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -9,6 +9,7 @@ modules, classes, hierarchies, attributes, functions, and methods. """ + from functools import update_wrapper import hashlib import inspect @@ -462,6 +463,8 @@ def format_argspec_plus(fn, grouped=True): passed positionally. apply_kw Like apply_pos, except keyword-ish args are passed as keywords. + apply_pos_proxied + Like apply_pos but omits the self/cls argument Example:: @@ -478,16 +481,27 @@ def format_argspec_plus(fn, grouped=True): spec = fn args = compat.inspect_formatargspec(*spec) + + apply_pos = compat.inspect_formatargspec( + spec[0], spec[1], spec[2], None, spec[4] + ) + if spec[0]: self_arg = spec[0][0] + + apply_pos_proxied = compat.inspect_formatargspec( + spec[0][1:], spec[1], spec[2], None, spec[4] + ) + elif spec[1]: + # im not sure what this is self_arg = "%s[0]" % spec[1] + + apply_pos_proxied = apply_pos else: self_arg = None + apply_pos_proxied = apply_pos - apply_pos = compat.inspect_formatargspec( - spec[0], spec[1], spec[2], None, spec[4] - ) num_defaults = 0 if spec[3]: num_defaults += len(spec[3]) @@ -513,6 +527,7 @@ def format_argspec_plus(fn, grouped=True): self_arg=self_arg, apply_pos=apply_pos, apply_kw=apply_kw, + apply_pos_proxied=apply_pos_proxied, ) else: return dict( @@ -520,6 +535,7 @@ def format_argspec_plus(fn, grouped=True): self_arg=self_arg, apply_pos=apply_pos[1:-1], apply_kw=apply_kw[1:-1], + apply_pos_proxied=apply_pos_proxied[1:-1], ) @@ -534,17 +550,140 @@ def format_argspec_init(method, grouped=True): """ if method is object.__init__: - args = grouped and "(self)" or "self" + args = "(self)" if grouped else "self" + proxied = "()" if grouped else "" else: try: return format_argspec_plus(method, grouped=grouped) except TypeError: args = ( - grouped - and "(self, *args, **kwargs)" - or "self, *args, **kwargs" + "(self, *args, **kwargs)" + if grouped + else "self, *args, **kwargs" ) - return dict(self_arg="self", args=args, apply_pos=args, apply_kw=args) + proxied = "(*args, **kwargs)" if grouped else "*args, **kwargs" + return dict( + self_arg="self", + args=args, + apply_pos=args, + apply_kw=args, + apply_pos_proxied=proxied, + ) + + +def create_proxy_methods( + target_cls, + target_cls_sphinx_name, + proxy_cls_sphinx_name, + classmethods=(), + methods=(), + attributes=(), +): + """A class decorator that will copy attributes to a proxy class. + + The class to be instrumented must define a single accessor "_proxied". + + """ + + def decorate(cls): + def instrument(name, clslevel=False): + fn = getattr(target_cls, name) + spec = compat.inspect_getfullargspec(fn) + env = {} + + spec = _update_argspec_defaults_into_env(spec, env) + caller_argspec = format_argspec_plus(spec, grouped=False) + + metadata = { + "name": fn.__name__, + "apply_pos_proxied": caller_argspec["apply_pos_proxied"], + "args": caller_argspec["args"], + "self_arg": caller_argspec["self_arg"], + } + + if clslevel: + code = ( + "def %(name)s(%(args)s):\n" + " return target_cls.%(name)s(%(apply_pos_proxied)s)" + % metadata + ) + env["target_cls"] = target_cls + else: + code = ( + "def %(name)s(%(args)s):\n" + " return %(self_arg)s._proxied.%(name)s(%(apply_pos_proxied)s)" # noqa E501 + % metadata + ) + + proxy_fn = _exec_code_in_env(code, env, fn.__name__) + proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__ + proxy_fn.__doc__ = inject_docstring_text( + fn.__doc__, + ".. container:: class_bases\n\n " + "Proxied for the %s class on behalf of the %s class." + % (target_cls_sphinx_name, proxy_cls_sphinx_name), + 1, + ) + + if clslevel: + proxy_fn = classmethod(proxy_fn) + + return proxy_fn + + def makeprop(name): + attr = target_cls.__dict__.get(name, None) + + if attr is not None: + doc = inject_docstring_text( + attr.__doc__, + ".. container:: class_bases\n\n " + "Proxied for the %s class on behalf of the %s class." + % ( + target_cls_sphinx_name, + proxy_cls_sphinx_name, + ), + 1, + ) + else: + doc = None + + code = ( + "def set_(self, attr):\n" + " self._proxied.%(name)s = attr\n" + "def get(self):\n" + " return self._proxied.%(name)s\n" + "get.__doc__ = doc\n" + "getset = property(get, set_)" + ) % {"name": name} + + getset = _exec_code_in_env(code, {"doc": doc}, "getset") + + return getset + + for meth in methods: + if hasattr(cls, meth): + raise TypeError( + "class %s already has a method %s" % (cls, meth) + ) + setattr(cls, meth, instrument(meth)) + + for prop in attributes: + if hasattr(cls, prop): + raise TypeError( + "class %s already has a method %s" % (cls, prop) + ) + setattr(cls, prop, makeprop(prop)) + + for prop in classmethods: + if hasattr(cls, prop): + raise TypeError( + "class %s already has a method %s" % (cls, prop) + ) + setattr(cls, prop, instrument(prop, clslevel=True)) + + return cls + + return decorate def getargspec_init(method): |