diff options
-rw-r--r-- | lib/sqlalchemy/schema.py | 10 | ||||
-rw-r--r-- | test/sql/defaults.py | 27 |
2 files changed, 27 insertions, 10 deletions
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index b6f345be2..1dacad3de 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -735,10 +735,14 @@ class ColumnDefault(DefaultGenerator): argspec = inspect.getargspec(arg) if len(argspec[0]) == 0: self.arg = lambda ctx: arg() - elif len(argspec[0]) != 1: - raise exceptions.ArgumentError("ColumnDefault Python function takes zero or one positional arguments") else: - self.arg = arg + defaulted = argspec[3] is not None and len(argspec[3]) or 0 + if len(argspec[0]) - defaulted > 1: + raise exceptions.ArgumentError( + "ColumnDefault Python function takes zero or one " + "positional arguments") + else: + self.arg = arg else: self.arg = arg diff --git a/test/sql/defaults.py b/test/sql/defaults.py index 953eb7a35..854c9dc69 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -99,13 +99,26 @@ class DefaultTest(PersistTest): t.delete().execute() def testargsignature(self): - def mydefault(x, y): - pass - try: - c = ColumnDefault(mydefault) - assert False - except exceptions.ArgumentError, e: - assert str(e) == "ColumnDefault Python function takes zero or one positional arguments", str(e) + ex_msg = \ + "ColumnDefault Python function takes zero or one positional arguments" + + def fn1(x, y): pass + def fn2(x, y, z=3): pass + for fn in fn1, fn2: + try: + c = ColumnDefault(fn) + assert False + except exceptions.ArgumentError, e: + assert str(e) == ex_msg + + def fn3(): pass + def fn4(): pass + def fn5(x=1): pass + def fn6(x=1, y=2, z=3): pass + fn7 = list + + for fn in fn3, fn4, fn5, fn6, fn7: + c = ColumnDefault(fn) def teststandalone(self): c = testbase.db.engine.contextual_connect() |