summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/util.py38
-rw-r--r--test/base/utils.py67
2 files changed, 95 insertions, 10 deletions
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index 4f30f76ba..bdcaf37f0 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -4,8 +4,9 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import itertools, sys, warnings, sets, weakref
+import inspect, itertools, sets, sys, warnings, weakref
import __builtin__
+types = __import__('types')
from sqlalchemy import exceptions
@@ -181,20 +182,37 @@ class ArgSingleton(type):
return instance
def get_cls_kwargs(cls):
- """Return the full set of legal kwargs for the given `cls`."""
+ """Return the full set of inherited kwargs for the given `cls`.
+
+ Probes a class's __init__ method, collecting all named arguments. If the
+ __init__ defines a **kwargs catch-all, then the constructor is presumed to
+ pass along unrecognized keywords to it's base classes, and the collection
+ process is repeated recursively on each of the bases.
+ """
- kw = []
for c in cls.__mro__:
- cons = c.__init__
- if hasattr(cons, 'func_code'):
- for vn in cons.func_code.co_varnames:
- if vn != 'self':
- kw.append(vn)
- return kw
+ if '__init__' in c.__dict__:
+ stack = [c]
+ break
+ else:
+ return []
+
+ args = Set()
+ while stack:
+ class_ = stack.pop()
+ ctr = class_.__dict__.get('__init__', False)
+ if not ctr or not isinstance(ctr, types.FunctionType):
+ continue
+ names, _, has_kw, _ = inspect.getargspec(ctr)
+ args |= Set(names)
+ if has_kw:
+ stack.extend(class_.__bases__)
+ args.discard('self')
+ return list(args)
def get_func_kwargs(func):
"""Return the full set of legal kwargs for the given `func`."""
- return [vn for vn in func.func_code.co_varnames]
+ return inspect.getargspec(func)[0]
# from paste.deploy.converters
def asbool(obj):
diff --git a/test/base/utils.py b/test/base/utils.py
index 5a034e0b0..837eb058f 100644
--- a/test/base/utils.py
+++ b/test/base/utils.py
@@ -305,5 +305,72 @@ class DictlikeIteritemsTest(unittest.TestCase):
self._notok(duck6())
+class ArgInspectionTest(PersistTest):
+ def test_get_cls_kwargs(self):
+ class A(object):
+ def __init__(self, a):
+ pass
+ class A1(A):
+ def __init__(self, a1):
+ pass
+ class A11(A1):
+ def __init__(self, a11, **kw):
+ pass
+ class B(object):
+ def __init__(self, b, **kw):
+ pass
+ class B1(B):
+ def __init__(self, b1, **kw):
+ pass
+ class AB(A, B):
+ def __init__(self, ab):
+ pass
+ class BA(B, A):
+ def __init__(self, ba, **kwargs):
+ pass
+ class BA1(BA):
+ pass
+ class CAB(A, B):
+ pass
+ class CBA(B, A):
+ pass
+ class CAB1(A, B1):
+ pass
+ class CB1A(B1, A):
+ pass
+ class D(object):
+ pass
+
+ def test(cls, *expected):
+ self.assertEquals(set(util.get_cls_kwargs(cls)), set(expected))
+
+ test(A, 'a')
+ test(A1, 'a1')
+ test(A11, 'a11', 'a1')
+ test(B, 'b')
+ test(B1, 'b1', 'b')
+ test(AB, 'ab')
+ test(BA, 'ba', 'b', 'a')
+ test(BA1, 'ba', 'b', 'a')
+ test(CAB, 'a')
+ test(CBA, 'b')
+ test(CAB1, 'a')
+ test(CB1A, 'b1', 'b')
+ test(D)
+
+ def test_get_func_kwargs(self):
+ def f1(): pass
+ def f2(foo): pass
+ def f3(*foo): pass
+ def f4(**foo): pass
+
+ def test(fn, *expected):
+ self.assertEquals(set(util.get_func_kwargs(fn)), set(expected))
+
+ test(f1)
+ test(f2, 'foo')
+ test(f3)
+ test(f4)
+
if __name__ == "__main__":
testenv.main()