summaryrefslogtreecommitdiff
path: root/alembic/util/langhelpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'alembic/util/langhelpers.py')
-rw-r--r--alembic/util/langhelpers.py43
1 files changed, 35 insertions, 8 deletions
diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py
index 1fb0942..6c92e3c 100644
--- a/alembic/util/langhelpers.py
+++ b/alembic/util/langhelpers.py
@@ -257,30 +257,57 @@ class immutabledict(dict):
class Dispatcher(object):
- def __init__(self):
+ def __init__(self, uselist=False):
self._registry = {}
+ self.uselist = uselist
def dispatch_for(self, target, qualifier='default'):
def decorate(fn):
- assert isinstance(target, type)
- assert target not in self._registry
- self._registry[(target, qualifier)] = fn
+ if self.uselist:
+ assert target not in self._registry
+ self._registry.setdefault((target, qualifier), []).append(fn)
+ else:
+ assert target not in self._registry
+ self._registry[(target, qualifier)] = fn
return fn
return decorate
def dispatch(self, obj, qualifier='default'):
- for spcls in type(obj).__mro__:
+
+ if isinstance(obj, string_types):
+ targets = [obj]
+ elif isinstance(obj, type):
+ targets = obj.__mro__
+ else:
+ targets = type(obj).__mro__
+
+ for spcls in targets:
if qualifier != 'default' and (spcls, qualifier) in self._registry:
- return self._registry[(spcls, qualifier)]
+ return self._fn_or_list(self._registry[(spcls, qualifier)])
elif (spcls, 'default') in self._registry:
- return self._registry[(spcls, 'default')]
+ return self._fn_or_list(self._registry[(spcls, 'default')])
else:
raise ValueError("no dispatch function for object: %s" % obj)
+ def _fn_or_list(self, fn_or_list):
+ if self.uselist:
+ def go(*arg, **kw):
+ for fn in fn_or_list:
+ fn(*arg, **kw)
+ return go
+ else:
+ return fn_or_list
+
def branch(self):
"""Return a copy of this dispatcher that is independently
writable."""
d = Dispatcher()
- d._registry.update(self._registry)
+ if self.uselist:
+ d._registry.update(
+ (k, [fn for fn in self._registry[k]])
+ for k in self._registry
+ )
+ else:
+ d._registry.update(self._registry)
return d