diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/event.py | 35 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 2 |
2 files changed, 27 insertions, 10 deletions
diff --git a/lib/sqlalchemy/event.py b/lib/sqlalchemy/event.py index 9cc3139af..cd70b3a7c 100644 --- a/lib/sqlalchemy/event.py +++ b/lib/sqlalchemy/event.py @@ -13,12 +13,12 @@ NO_RETVAL = util.symbol('NO_RETVAL') def listen(target, identifier, fn, *args, **kw): """Register a listener function for the given target. - + e.g.:: - + from sqlalchemy import event from sqlalchemy.schema import UniqueConstraint - + def unique_constraint_name(const, table): const.name = "uq_%s_%s" % ( table.name, @@ -41,12 +41,12 @@ def listen(target, identifier, fn, *args, **kw): def listens_for(target, identifier, *args, **kw): """Decorate a function as a listener for the given target + identifier. - + e.g.:: - + from sqlalchemy import event from sqlalchemy.schema import UniqueConstraint - + @event.listens_for(UniqueConstraint, "after_parent_attach") def unique_constraint_name(const, table): const.name = "uq_%s_%s" % ( @@ -205,12 +205,14 @@ class _DispatchDescriptor(object): def insert(self, obj, target, propagate): assert isinstance(target, type), \ "Class-level Event targets must be classes." - stack = [target] while stack: cls = stack.pop(0) stack.extend(cls.__subclasses__()) - self._clslevel[cls].insert(0, obj) + if cls is not target and cls not in self._clslevel: + self.update_subclass(cls) + else: + self._clslevel[cls].insert(0, obj) def append(self, obj, target, propagate): assert isinstance(target, type), \ @@ -220,7 +222,20 @@ class _DispatchDescriptor(object): while stack: cls = stack.pop(0) stack.extend(cls.__subclasses__()) - self._clslevel[cls].append(obj) + if cls is not target and cls not in self._clslevel: + self.update_subclass(cls) + else: + self._clslevel[cls].append(obj) + + def update_subclass(self, target): + clslevel = self._clslevel[target] + for cls in target.__mro__[1:]: + if cls in self._clslevel: + clslevel.extend([ + fn for fn + in self._clslevel[cls] + if fn not in clslevel + ]) def remove(self, obj, target): stack = [target] @@ -252,6 +267,8 @@ class _ListenerCollection(object): _exec_once = False def __init__(self, parent, target_cls): + if target_cls not in parent._clslevel: + parent.update_subclass(target_cls) self.parent_listeners = parent._clslevel[target_cls] self.name = parent.__name__ self.listeners = [] diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index d01c1598a..14778705d 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -99,7 +99,7 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, kwargs.update(new_kwargs) - return type("Session", (Sess, class_), {}) + return type("SessionMaker", (Sess, class_), {}) class SessionTransaction(object): |
