summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/engine/base.py16
-rw-r--r--lib/sqlalchemy/event/attr.py35
2 files changed, 46 insertions, 5 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 4bae94317..aa9358cd6 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -2189,6 +2189,8 @@ class Engine(Connectable, log.Identified):
class OptionEngine(Engine):
+ _sa_propagate_class_events = False
+
def __init__(self, proxied, execution_options):
self._proxied = proxied
self.url = proxied.url
@@ -2196,7 +2198,21 @@ class OptionEngine(Engine):
self.logging_name = proxied.logging_name
self.echo = proxied.echo
log.instance_logger(self, echoflag=self.echo)
+
+ # note: this will propagate events that are assigned to the parent
+ # engine after this OptionEngine is created. Since we share
+ # the events of the parent we also disallow class-level events
+ # to apply to the OptionEngine class directly.
+ #
+ # the other way this can work would be to transfer existing
+ # events only, using:
+ # self.dispatch._update(proxied.dispatch)
+ #
+ # that might be more appropriate however it would be a behavioral
+ # change for logic that assigns events to the parent engine and
+ # would like it to take effect for the already-created sub-engine.
self.dispatch = self.dispatch._join(proxied.dispatch)
+
self._execution_options = proxied._execution_options
self.update_execution_options(**execution_options)
diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py
index 1068257cb..efa8fab42 100644
--- a/lib/sqlalchemy/event/attr.py
+++ b/lib/sqlalchemy/event/attr.py
@@ -30,7 +30,7 @@ as well as support for subclass propagation (e.g. events assigned to
"""
from __future__ import absolute_import, with_statement
-
+from .. import exc
from .. import util
from ..util import threading
from . import registry
@@ -47,6 +47,20 @@ class RefCollection(util.MemoizedSlots):
return weakref.ref(self, registry._collection_gced)
+class _empty_collection(object):
+ def append(self, element):
+ pass
+
+ def extend(self, other):
+ pass
+
+ def __iter__(self):
+ return iter([])
+
+ def clear(self):
+ pass
+
+
class _ClsLevelDispatch(RefCollection):
"""Class-level events on :class:`._Dispatch` classes."""
@@ -91,6 +105,9 @@ class _ClsLevelDispatch(RefCollection):
target = event_key.dispatch_target
assert isinstance(target, type), \
"Class-level Event targets must be classes."
+ if not getattr(target, '_sa_propagate_class_events', True):
+ raise exc.InvalidRequestError(
+ "Can't assign an event directly to the %s class" % target)
stack = [target]
while stack:
cls = stack.pop(0)
@@ -99,7 +116,7 @@ class _ClsLevelDispatch(RefCollection):
self.update_subclass(cls)
else:
if cls not in self._clslevel:
- self._clslevel[cls] = collections.deque()
+ self._assign_cls_collection(cls)
self._clslevel[cls].appendleft(event_key._listen_fn)
registry._stored_in_collection(event_key, self)
@@ -107,7 +124,9 @@ class _ClsLevelDispatch(RefCollection):
target = event_key.dispatch_target
assert isinstance(target, type), \
"Class-level Event targets must be classes."
-
+ if not getattr(target, '_sa_propagate_class_events', True):
+ raise exc.InvalidRequestError(
+ "Can't assign an event directly to the %s class" % target)
stack = [target]
while stack:
cls = stack.pop(0)
@@ -116,13 +135,19 @@ class _ClsLevelDispatch(RefCollection):
self.update_subclass(cls)
else:
if cls not in self._clslevel:
- self._clslevel[cls] = collections.deque()
+ self._assign_cls_collection(cls)
self._clslevel[cls].append(event_key._listen_fn)
registry._stored_in_collection(event_key, self)
+ def _assign_cls_collection(self, target):
+ if getattr(target, '_sa_propagate_class_events', True):
+ self._clslevel[target] = collections.deque()
+ else:
+ self._clslevel[target] = _empty_collection()
+
def update_subclass(self, target):
if target not in self._clslevel:
- self._clslevel[target] = collections.deque()
+ self._assign_cls_collection(target)
clslevel = self._clslevel[target]
for cls in target.__mro__[1:]:
if cls in self._clslevel: