summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Kirtland <jek@discorporate.us>2008-08-15 22:54:35 +0000
committerJason Kirtland <jek@discorporate.us>2008-08-15 22:54:35 +0000
commitaaf72e05f199d7d29a039aa6d08a7e005a01448a (patch)
tree37f854b0945cb9622febac1607bf0959c3f6291f
parentd70ed586c74feb947a412f2b5aa1496e11465cdf (diff)
downloadsqlalchemy-aaf72e05f199d7d29a039aa6d08a7e005a01448a.tar.gz
- Ignore old-style classes when building inheritance graphs. [ticket:1078]
-rw-r--r--lib/sqlalchemy/util.py20
-rw-r--r--test/base/utils.py35
2 files changed, 49 insertions, 6 deletions
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index dd1045311..c2d4bf6ab 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -397,15 +397,23 @@ def class_hierarchy(cls):
class_hierarchy(class A(object)) returns (A, object), not A plus every
class systemwide that derives from object.
+ Old-style classes are discarded and hierarchies rooted on them
+ will not be descended.
+
"""
+ if isinstance(cls, types.ClassType):
+ return list()
hier = set([cls])
process = list(cls.__mro__)
while process:
c = process.pop()
- for b in [_ for _ in c.__bases__ if _ not in hier]:
+ if isinstance(c, types.ClassType):
+ continue
+ for b in (_ for _ in c.__bases__
+ if _ not in hier and not isinstance(_, types.ClassType)):
process.append(b)
hier.add(b)
- if c.__module__ == '__builtin__':
+ if c.__module__ == '__builtin__' or not hasattr(c, '__subclasses__'):
continue
for s in [_ for _ in c.__subclasses__() if _ not in hier]:
process.append(s)
@@ -414,10 +422,10 @@ def class_hierarchy(cls):
def iterate_attributes(cls):
"""iterate all the keys and attributes associated with a class, without using getattr().
-
+
Does not use getattr() so that class-sensitive descriptors (i.e. property.__get__())
are not called.
-
+
"""
keys = dir(cls)
for key in keys:
@@ -425,7 +433,7 @@ def iterate_attributes(cls):
if key in c.__dict__:
yield (key, c.__dict__[key])
break
-
+
# from paste.deploy.converters
def asbool(obj):
if isinstance(obj, (str, unicode)):
@@ -1121,7 +1129,7 @@ class ScopedRegistry(object):
return object.__new__(_TLocalRegistry)
else:
return object.__new__(cls)
-
+
def __init__(self, createfunc, scopefunc):
self.createfunc = createfunc
self.scopefunc = scopefunc
diff --git a/test/base/utils.py b/test/base/utils.py
index 3ce956a16..2c4edc692 100644
--- a/test/base/utils.py
+++ b/test/base/utils.py
@@ -884,5 +884,40 @@ class AsInterfaceTest(TestBase):
obj = {'foo': 123}
self.assertRaises(TypeError, util.as_interface, obj, cls=self.Something)
+
+class TestClassHierarchy(TestBase):
+ def test_object(self):
+ eq_(set(util.class_hierarchy(object)), set((object,)))
+
+ def test_single(self):
+ class A(object):
+ pass
+
+ class B(object):
+ pass
+
+ eq_(set(util.class_hierarchy(A)), set((A, object)))
+ eq_(set(util.class_hierarchy(B)), set((B, object)))
+
+ class C(A, B):
+ pass
+
+ eq_(set(util.class_hierarchy(A)), set((A, B, C, object)))
+ eq_(set(util.class_hierarchy(B)), set((A, B, C, object)))
+
+ def test_oldstyle_mixin(self):
+ class A(object):
+ pass
+
+ class Mixin:
+ pass
+
+ class B(A, Mixin):
+ pass
+
+ eq_(set(util.class_hierarchy(B)), set((A, B, object)))
+ eq_(set(util.class_hierarchy(Mixin)), set())
+ eq_(set(util.class_hierarchy(A)), set((A, B, object)))
+
if __name__ == "__main__":
testenv.main()