diff options
-rw-r--r-- | lib/sqlalchemy/util.py | 38 | ||||
-rw-r--r-- | test/base/utils.py | 67 |
2 files changed, 95 insertions, 10 deletions
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 4f30f76ba..bdcaf37f0 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -4,8 +4,9 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import itertools, sys, warnings, sets, weakref +import inspect, itertools, sets, sys, warnings, weakref import __builtin__ +types = __import__('types') from sqlalchemy import exceptions @@ -181,20 +182,37 @@ class ArgSingleton(type): return instance def get_cls_kwargs(cls): - """Return the full set of legal kwargs for the given `cls`.""" + """Return the full set of inherited kwargs for the given `cls`. + + Probes a class's __init__ method, collecting all named arguments. If the + __init__ defines a **kwargs catch-all, then the constructor is presumed to + pass along unrecognized keywords to it's base classes, and the collection + process is repeated recursively on each of the bases. + """ - kw = [] for c in cls.__mro__: - cons = c.__init__ - if hasattr(cons, 'func_code'): - for vn in cons.func_code.co_varnames: - if vn != 'self': - kw.append(vn) - return kw + if '__init__' in c.__dict__: + stack = [c] + break + else: + return [] + + args = Set() + while stack: + class_ = stack.pop() + ctr = class_.__dict__.get('__init__', False) + if not ctr or not isinstance(ctr, types.FunctionType): + continue + names, _, has_kw, _ = inspect.getargspec(ctr) + args |= Set(names) + if has_kw: + stack.extend(class_.__bases__) + args.discard('self') + return list(args) def get_func_kwargs(func): """Return the full set of legal kwargs for the given `func`.""" - return [vn for vn in func.func_code.co_varnames] + return inspect.getargspec(func)[0] # from paste.deploy.converters def asbool(obj): diff --git a/test/base/utils.py b/test/base/utils.py index 5a034e0b0..837eb058f 100644 --- a/test/base/utils.py +++ b/test/base/utils.py @@ -305,5 +305,72 @@ class DictlikeIteritemsTest(unittest.TestCase): self._notok(duck6()) +class ArgInspectionTest(PersistTest): + def test_get_cls_kwargs(self): + class A(object): + def __init__(self, a): + pass + class A1(A): + def __init__(self, a1): + pass + class A11(A1): + def __init__(self, a11, **kw): + pass + class B(object): + def __init__(self, b, **kw): + pass + class B1(B): + def __init__(self, b1, **kw): + pass + class AB(A, B): + def __init__(self, ab): + pass + class BA(B, A): + def __init__(self, ba, **kwargs): + pass + class BA1(BA): + pass + class CAB(A, B): + pass + class CBA(B, A): + pass + class CAB1(A, B1): + pass + class CB1A(B1, A): + pass + class D(object): + pass + + def test(cls, *expected): + self.assertEquals(set(util.get_cls_kwargs(cls)), set(expected)) + + test(A, 'a') + test(A1, 'a1') + test(A11, 'a11', 'a1') + test(B, 'b') + test(B1, 'b1', 'b') + test(AB, 'ab') + test(BA, 'ba', 'b', 'a') + test(BA1, 'ba', 'b', 'a') + test(CAB, 'a') + test(CBA, 'b') + test(CAB1, 'a') + test(CB1A, 'b1', 'b') + test(D) + + def test_get_func_kwargs(self): + def f1(): pass + def f2(foo): pass + def f3(*foo): pass + def f4(**foo): pass + + def test(fn, *expected): + self.assertEquals(set(util.get_func_kwargs(fn)), set(expected)) + + test(f1) + test(f2, 'foo') + test(f3) + test(f4) + if __name__ == "__main__": testenv.main() |