diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 16 | ||||
-rw-r--r-- | lib/sqlalchemy/event/attr.py | 35 |
2 files changed, 46 insertions, 5 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 4bae94317..aa9358cd6 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -2189,6 +2189,8 @@ class Engine(Connectable, log.Identified): class OptionEngine(Engine): + _sa_propagate_class_events = False + def __init__(self, proxied, execution_options): self._proxied = proxied self.url = proxied.url @@ -2196,7 +2198,21 @@ class OptionEngine(Engine): self.logging_name = proxied.logging_name self.echo = proxied.echo log.instance_logger(self, echoflag=self.echo) + + # note: this will propagate events that are assigned to the parent + # engine after this OptionEngine is created. Since we share + # the events of the parent we also disallow class-level events + # to apply to the OptionEngine class directly. + # + # the other way this can work would be to transfer existing + # events only, using: + # self.dispatch._update(proxied.dispatch) + # + # that might be more appropriate however it would be a behavioral + # change for logic that assigns events to the parent engine and + # would like it to take effect for the already-created sub-engine. self.dispatch = self.dispatch._join(proxied.dispatch) + self._execution_options = proxied._execution_options self.update_execution_options(**execution_options) diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index 1068257cb..efa8fab42 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -30,7 +30,7 @@ as well as support for subclass propagation (e.g. events assigned to """ from __future__ import absolute_import, with_statement - +from .. import exc from .. import util from ..util import threading from . import registry @@ -47,6 +47,20 @@ class RefCollection(util.MemoizedSlots): return weakref.ref(self, registry._collection_gced) +class _empty_collection(object): + def append(self, element): + pass + + def extend(self, other): + pass + + def __iter__(self): + return iter([]) + + def clear(self): + pass + + class _ClsLevelDispatch(RefCollection): """Class-level events on :class:`._Dispatch` classes.""" @@ -91,6 +105,9 @@ class _ClsLevelDispatch(RefCollection): target = event_key.dispatch_target assert isinstance(target, type), \ "Class-level Event targets must be classes." + if not getattr(target, '_sa_propagate_class_events', True): + raise exc.InvalidRequestError( + "Can't assign an event directly to the %s class" % target) stack = [target] while stack: cls = stack.pop(0) @@ -99,7 +116,7 @@ class _ClsLevelDispatch(RefCollection): self.update_subclass(cls) else: if cls not in self._clslevel: - self._clslevel[cls] = collections.deque() + self._assign_cls_collection(cls) self._clslevel[cls].appendleft(event_key._listen_fn) registry._stored_in_collection(event_key, self) @@ -107,7 +124,9 @@ class _ClsLevelDispatch(RefCollection): target = event_key.dispatch_target assert isinstance(target, type), \ "Class-level Event targets must be classes." - + if not getattr(target, '_sa_propagate_class_events', True): + raise exc.InvalidRequestError( + "Can't assign an event directly to the %s class" % target) stack = [target] while stack: cls = stack.pop(0) @@ -116,13 +135,19 @@ class _ClsLevelDispatch(RefCollection): self.update_subclass(cls) else: if cls not in self._clslevel: - self._clslevel[cls] = collections.deque() + self._assign_cls_collection(cls) self._clslevel[cls].append(event_key._listen_fn) registry._stored_in_collection(event_key, self) + def _assign_cls_collection(self, target): + if getattr(target, '_sa_propagate_class_events', True): + self._clslevel[target] = collections.deque() + else: + self._clslevel[target] = _empty_collection() + def update_subclass(self, target): if target not in self._clslevel: - self._clslevel[target] = collections.deque() + self._assign_cls_collection(target) clslevel = self._clslevel[target] for cls in target.__mro__[1:]: if cls in self._clslevel: |