diff options
author | Jason Kirtland <jek@discorporate.us> | 2008-08-15 22:54:35 +0000 |
---|---|---|
committer | Jason Kirtland <jek@discorporate.us> | 2008-08-15 22:54:35 +0000 |
commit | aaf72e05f199d7d29a039aa6d08a7e005a01448a (patch) | |
tree | 37f854b0945cb9622febac1607bf0959c3f6291f | |
parent | d70ed586c74feb947a412f2b5aa1496e11465cdf (diff) | |
download | sqlalchemy-aaf72e05f199d7d29a039aa6d08a7e005a01448a.tar.gz |
- Ignore old-style classes when building inheritance graphs. [ticket:1078]
-rw-r--r-- | lib/sqlalchemy/util.py | 20 | ||||
-rw-r--r-- | test/base/utils.py | 35 |
2 files changed, 49 insertions, 6 deletions
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index dd1045311..c2d4bf6ab 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -397,15 +397,23 @@ def class_hierarchy(cls): class_hierarchy(class A(object)) returns (A, object), not A plus every class systemwide that derives from object. + Old-style classes are discarded and hierarchies rooted on them + will not be descended. + """ + if isinstance(cls, types.ClassType): + return list() hier = set([cls]) process = list(cls.__mro__) while process: c = process.pop() - for b in [_ for _ in c.__bases__ if _ not in hier]: + if isinstance(c, types.ClassType): + continue + for b in (_ for _ in c.__bases__ + if _ not in hier and not isinstance(_, types.ClassType)): process.append(b) hier.add(b) - if c.__module__ == '__builtin__': + if c.__module__ == '__builtin__' or not hasattr(c, '__subclasses__'): continue for s in [_ for _ in c.__subclasses__() if _ not in hier]: process.append(s) @@ -414,10 +422,10 @@ def class_hierarchy(cls): def iterate_attributes(cls): """iterate all the keys and attributes associated with a class, without using getattr(). - + Does not use getattr() so that class-sensitive descriptors (i.e. property.__get__()) are not called. - + """ keys = dir(cls) for key in keys: @@ -425,7 +433,7 @@ def iterate_attributes(cls): if key in c.__dict__: yield (key, c.__dict__[key]) break - + # from paste.deploy.converters def asbool(obj): if isinstance(obj, (str, unicode)): @@ -1121,7 +1129,7 @@ class ScopedRegistry(object): return object.__new__(_TLocalRegistry) else: return object.__new__(cls) - + def __init__(self, createfunc, scopefunc): self.createfunc = createfunc self.scopefunc = scopefunc diff --git a/test/base/utils.py b/test/base/utils.py index 3ce956a16..2c4edc692 100644 --- a/test/base/utils.py +++ b/test/base/utils.py @@ -884,5 +884,40 @@ class AsInterfaceTest(TestBase): obj = {'foo': 123} self.assertRaises(TypeError, util.as_interface, obj, cls=self.Something) + +class TestClassHierarchy(TestBase): + def test_object(self): + eq_(set(util.class_hierarchy(object)), set((object,))) + + def test_single(self): + class A(object): + pass + + class B(object): + pass + + eq_(set(util.class_hierarchy(A)), set((A, object))) + eq_(set(util.class_hierarchy(B)), set((B, object))) + + class C(A, B): + pass + + eq_(set(util.class_hierarchy(A)), set((A, B, C, object))) + eq_(set(util.class_hierarchy(B)), set((A, B, C, object))) + + def test_oldstyle_mixin(self): + class A(object): + pass + + class Mixin: + pass + + class B(A, Mixin): + pass + + eq_(set(util.class_hierarchy(B)), set((A, B, object))) + eq_(set(util.class_hierarchy(Mixin)), set()) + eq_(set(util.class_hierarchy(A)), set((A, B, object))) + if __name__ == "__main__": testenv.main() |