diff options
Diffstat (limited to 'lib/sqlalchemy/ext/declarative/clsregistry.py')
-rw-r--r-- | lib/sqlalchemy/ext/declarative/clsregistry.py | 79 |
1 files changed, 56 insertions, 23 deletions
diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py index 08b487db3..47450c5b7 100644 --- a/lib/sqlalchemy/ext/declarative/clsregistry.py +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -38,14 +38,28 @@ def add_class(classname, cls): cls._decl_class_registry[classname] = cls try: - module = cls._decl_class_registry['_sa_module_registry'] + root_module = cls._decl_class_registry['_sa_module_registry'] except KeyError: cls._decl_class_registry['_sa_module_registry'] = \ - module = _ModuleMarker('_sa_module_registry', None) - for token in cls.__module__.split("."): - module = module.get_module(token) + root_module = _ModuleMarker('_sa_module_registry', None) + + tokens = cls.__module__.split(".") + + # build up a tree like this: + # modulename: myapp.snacks.nuts + # + # myapp->snack->nuts->(classes) + # snack->nuts->(classes) + # nuts->(classes) + # + # this allows partial token paths to be used. + while tokens: + token = tokens.pop(0) + module = root_module.get_module(token) + for token in tokens: + module = module.get_module(token) + module.add_class(classname, cls) - module.add_class(classname, cls) class _MultipleClassMarker(object): """refers to multiple classes of the same name @@ -53,7 +67,8 @@ class _MultipleClassMarker(object): """ - def __init__(self, classes): + def __init__(self, classes, on_remove=None): + self.on_remove = on_remove self.contents = set([ weakref.ref(item, self._remove_item) for item in classes]) _registries.add(self) @@ -61,12 +76,14 @@ class _MultipleClassMarker(object): def __iter__(self): return (ref() for ref in self.contents) - def attempt_get(self, key): + def attempt_get(self, path, key): if len(self.contents) > 1: raise exc.InvalidRequestError( - "Multiple classes with the classname " - "%r are in the registry of this declarative " - "base. Please use a fully module-qualified path." % key) + "Multiple classes found for path \"%s\" " + "in the registry of this declarative " + "base. Please use a fully module-qualified path." % + (".".join(path + [key])) + ) else: ref = list(self.contents)[0] cls = ref() @@ -78,8 +95,20 @@ class _MultipleClassMarker(object): self.contents.remove(ref) if not self.contents: _registries.discard(self) + if self.on_remove: + self.on_remove() - def add_item(self, item, base): + def add_item(self, item): + modules = set([cls().__module__ for cls in self.contents]) + if item.__module__ in modules: + util.warn( + "This declarative base already contains a class with the " + "same class name and module name as %s.%s, and will " + "be replaced in the string-lookup table." % ( + item.__module__, + item.__name__ + ) + ) self.contents.add(weakref.ref(item, self._remove_item)) class _ModuleMarker(object): @@ -92,13 +121,17 @@ class _ModuleMarker(object): self.name = name self.contents = {} self.mod_ns = _ModNS(self) + if self.parent: + self.path = self.parent.path + [self.name] + else: + self.path = [] _registries.add(self) def __contains__(self, name): return name in self.contents def __getitem__(self, name): - return self.contents[name]() + return self.contents[name] def _remove_item(self, name): self.contents.pop(name, None) @@ -112,20 +145,20 @@ class _ModuleMarker(object): def get_module(self, name): if name not in self.contents: marker = _ModuleMarker(name, self) - self.contents[name] = lambda: marker + self.contents[name] = marker else: - marker = self.contents[name]() + marker = self.contents[name] return marker def add_class(self, name, cls): if name in self.contents: - util.warn( - "This declarative base already contains a class with the " - "same class name and module name as %r, and will be replaced " - "in the string-lookup table." % cls) + existing = self.contents[name] + existing.add_item(cls) + else: + existing = self.contents[name] = \ + _MultipleClassMarker([cls], + on_remove=lambda: self._remove_item(name)) - self.contents[name] = weakref.ref(cls, - lambda ref: self._remove_item(name)) class _ModNS(object): @@ -138,12 +171,12 @@ class _ModNS(object): except KeyError: pass else: - value = value() if value is not None: if isinstance(value, _ModuleMarker): return value.mod_ns else: - return value + assert isinstance(value, _MultipleClassMarker) + return value.attempt_get(self.__parent.path, key) raise AttributeError("Module %r has no mapped classes " "registered under the name %r" % (self.__parent.name, key)) @@ -179,7 +212,7 @@ class _GetTable(object): def _determine_container(key, value): if isinstance(value, _MultipleClassMarker): - value = value.attempt_get(key) + value = value.attempt_get([], key) return _GetColumns(value) def _resolver(cls, prop): |