summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/event.py35
-rw-r--r--lib/sqlalchemy/orm/session.py2
2 files changed, 27 insertions, 10 deletions
diff --git a/lib/sqlalchemy/event.py b/lib/sqlalchemy/event.py
index 9cc3139af..cd70b3a7c 100644
--- a/lib/sqlalchemy/event.py
+++ b/lib/sqlalchemy/event.py
@@ -13,12 +13,12 @@ NO_RETVAL = util.symbol('NO_RETVAL')
def listen(target, identifier, fn, *args, **kw):
"""Register a listener function for the given target.
-
+
e.g.::
-
+
from sqlalchemy import event
from sqlalchemy.schema import UniqueConstraint
-
+
def unique_constraint_name(const, table):
const.name = "uq_%s_%s" % (
table.name,
@@ -41,12 +41,12 @@ def listen(target, identifier, fn, *args, **kw):
def listens_for(target, identifier, *args, **kw):
"""Decorate a function as a listener for the given target + identifier.
-
+
e.g.::
-
+
from sqlalchemy import event
from sqlalchemy.schema import UniqueConstraint
-
+
@event.listens_for(UniqueConstraint, "after_parent_attach")
def unique_constraint_name(const, table):
const.name = "uq_%s_%s" % (
@@ -205,12 +205,14 @@ class _DispatchDescriptor(object):
def insert(self, obj, target, propagate):
assert isinstance(target, type), \
"Class-level Event targets must be classes."
-
stack = [target]
while stack:
cls = stack.pop(0)
stack.extend(cls.__subclasses__())
- self._clslevel[cls].insert(0, obj)
+ if cls is not target and cls not in self._clslevel:
+ self.update_subclass(cls)
+ else:
+ self._clslevel[cls].insert(0, obj)
def append(self, obj, target, propagate):
assert isinstance(target, type), \
@@ -220,7 +222,20 @@ class _DispatchDescriptor(object):
while stack:
cls = stack.pop(0)
stack.extend(cls.__subclasses__())
- self._clslevel[cls].append(obj)
+ if cls is not target and cls not in self._clslevel:
+ self.update_subclass(cls)
+ else:
+ self._clslevel[cls].append(obj)
+
+ def update_subclass(self, target):
+ clslevel = self._clslevel[target]
+ for cls in target.__mro__[1:]:
+ if cls in self._clslevel:
+ clslevel.extend([
+ fn for fn
+ in self._clslevel[cls]
+ if fn not in clslevel
+ ])
def remove(self, obj, target):
stack = [target]
@@ -252,6 +267,8 @@ class _ListenerCollection(object):
_exec_once = False
def __init__(self, parent, target_cls):
+ if target_cls not in parent._clslevel:
+ parent.update_subclass(target_cls)
self.parent_listeners = parent._clslevel[target_cls]
self.name = parent.__name__
self.listeners = []
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index d01c1598a..14778705d 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -99,7 +99,7 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False,
kwargs.update(new_kwargs)
- return type("Session", (Sess, class_), {})
+ return type("SessionMaker", (Sess, class_), {})
class SessionTransaction(object):