diff options
-rw-r--r-- | CHANGES | 6 | ||||
-rw-r--r-- | doc/build/core/expression_api.rst | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 45 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 107 | ||||
-rw-r--r-- | test/sql/test_functions.py | 147 |
5 files changed, 241 insertions, 68 deletions
@@ -407,6 +407,12 @@ underneath "0.7.xx". used by combining operators.custom_op() with UnaryExpression(). + - [feature] Enhanced GenericFunction and func.* + to allow for user-defined GenericFunction + subclasses to be available via the func.* + namespace automatically by classname, + optionally using a package name as well. + - [changed] Most classes in expression.sql are no longer preceded with an underscore, i.e. Label, SelectBase, Generative, CompareMixin. diff --git a/doc/build/core/expression_api.rst b/doc/build/core/expression_api.rst index 71d2f1b56..08375078e 100644 --- a/doc/build/core/expression_api.rst +++ b/doc/build/core/expression_api.rst @@ -173,6 +173,10 @@ Classes :members: :show-inheritance: +.. autoclass:: sqlalchemy.sql.functions.GenericFunction + :members: + :show-inheritance: + .. autoclass:: Insert :members: :show-inheritance: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 41f4910a7..ce03c9c52 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1225,12 +1225,21 @@ class _FunctionGenerator(object): def __call__(self, *c, **kwargs): o = self.opts.copy() o.update(kwargs) - if len(self.__names) == 1: - func = getattr(functions, self.__names[-1].lower(), None) - if func is not None and \ - isinstance(func, type) and \ - issubclass(func, Function): - return func(*c, **o) + + tokens = len(self.__names) + + if tokens == 2: + package, fname = self.__names + elif tokens == 1: + package, fname = "_default", self.__names[0] + else: + package = None + + if package is not None and \ + package in functions._registry and \ + fname in functions._registry[package]: + func = functions._registry[package][fname] + return func(*c, **o) return Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o) @@ -3348,7 +3357,19 @@ class Case(ColumnElement): self.get_children()])) class FunctionElement(Executable, ColumnElement, FromClause): - """Base for SQL function-oriented constructs.""" + """Base for SQL function-oriented constructs. + + See also: + + :class:`.Function` - named SQL function. + + :data:`.func` - namespace which produces registered or ad-hoc + :class:`.Function` instances. + + :class:`.GenericFunction` - allows creation of registered function + types. + + """ packagenames = () @@ -3465,6 +3486,16 @@ class Function(FunctionElement): See the superclass :class:`.FunctionElement` for a description of public methods. + See also: + + See also: + + :data:`.func` - namespace which produces registered or ad-hoc + :class:`.Function` instances. + + :class:`.GenericFunction` - allows creation of registered function + types. + """ __visit_name__ = 'function' diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 5a480f0c3..79f1bcde2 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -10,27 +10,95 @@ from .expression import ( ) from . import operators from .visitors import VisitableType +from .. import util + +_registry = util.defaultdict(dict) class _GenericMeta(VisitableType): - def __call__(self, *args, **kwargs): - args = [_literal_as_binds(c) for c in args] - return type.__call__(self, *args, **kwargs) + def __init__(cls, clsname, bases, clsdict): + cls.name = name = clsdict.get('name', clsname) + package = clsdict.pop('package', '_default') + # legacy + if '__return_type__' in clsdict: + cls.type = clsdict['__return_type__'] + reg = _registry[package] + reg[name] = cls + super(_GenericMeta, cls).__init__(clsname, bases, clsdict) + + def __call__(cls, *args, **kwargs): + if cls.coerce_arguments: + args = [_literal_as_binds(c) for c in args] + return type.__call__(cls, *args, **kwargs) class GenericFunction(Function): + """Define a 'generic' function. + + A generic function is a pre-established :class:`.Function` + class that is instantiated automatically when called + by name from the :data:`.func` attribute. Note that + calling any name from :data:`.func` has the effect that + a new :class:`.Function` instance is created automatically, + given that name. The primary use case for defining + a :class:`.GenericFunction` class is so that a function + of a particular name may be given a fixed return type. + It can also include custom argument parsing schemes as well + as additional methods. + + Subclasses of :class:`.GenericFunction` are automatically + registered under the name of the class. For + example, a user-defined function ``as_utc()`` would + be available immediately:: + + from sqlalchemy.sql.functions import GenericFunction + from sqlalchemy.types import DateTime + + class as_utc(GenericFunction): + type = DateTime + + print select([func.as_utc()]) + + User-defined generic functions can be organized into + packages by specifying the "package" attribute when defining + :class:`.GenericFunction`. Third party libraries + containing many functions may want to use this in order + to avoid name conflicts with other systems. For example, + if our ``as_utc()`` function were part of a package + "time":: + + class as_utc(GenericFunction): + type = DateTime + package = "time" + + The above function would be available from :data:`.func` + using the package name ``time``:: + + print select([func.time.as_utc()]) + + .. versionadded:: 0.8 :class:`.GenericFunction` now supports + automatic registration of new functions as well as package + support. + + .. versionchanged:: 0.8 The attribute name ``type`` is used + to specify the function's return type at the class level. + Previously, the name ``__return_type__`` was used. This + name is still recognized for backwards-compatibility. + + """ __metaclass__ = _GenericMeta + coerce_arguments = True def __init__(self, type_=None, args=(), **kwargs): + args = [_literal_as_binds(c) for c in args] self.packagenames = [] - self.name = self.__class__.__name__ self._bind = kwargs.get('bind', None) self.clause_expr = ClauseList( operator=operators.comma_op, group_contents=True, *args).self_group() self.type = sqltypes.to_instance( - type_ or getattr(self, '__return_type__', None)) + type_ or getattr(self, 'type', None)) -class next_value(Function): +class next_value(GenericFunction): """Represent the 'next value', given a :class:`.Sequence` as it's single argument. @@ -41,6 +109,7 @@ class next_value(Function): """ type = sqltypes.Integer() name = "next_value" + coerce_arguments = False def __init__(self, seq, **kw): assert isinstance(seq, schema.Sequence), \ @@ -77,15 +146,15 @@ class sum(ReturnTypeFromArgs): class now(GenericFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime class concat(GenericFunction): - __return_type__ = sqltypes.String + type = sqltypes.String def __init__(self, *args, **kwargs): GenericFunction.__init__(self, args=args, **kwargs) class char_length(GenericFunction): - __return_type__ = sqltypes.Integer + type = sqltypes.Integer def __init__(self, arg, **kwargs): GenericFunction.__init__(self, args=[arg], **kwargs) @@ -98,7 +167,7 @@ class random(GenericFunction): class count(GenericFunction): """The ANSI COUNT aggregate function. With no arguments, emits COUNT \*.""" - __return_type__ = sqltypes.Integer + type = sqltypes.Integer def __init__(self, expression=None, **kwargs): if expression is None: @@ -106,29 +175,29 @@ class count(GenericFunction): GenericFunction.__init__(self, args=(expression,), **kwargs) class current_date(AnsiFunction): - __return_type__ = sqltypes.Date + type = sqltypes.Date class current_time(AnsiFunction): - __return_type__ = sqltypes.Time + type = sqltypes.Time class current_timestamp(AnsiFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime class current_user(AnsiFunction): - __return_type__ = sqltypes.String + type = sqltypes.String class localtime(AnsiFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime class localtimestamp(AnsiFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime class session_user(AnsiFunction): - __return_type__ = sqltypes.String + type = sqltypes.String class sysdate(AnsiFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime class user(AnsiFunction): - __return_type__ = sqltypes.String + type = sqltypes.String diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 2f9c6f908..5769e4a1a 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -2,31 +2,36 @@ from test.lib.testing import eq_ import datetime from sqlalchemy import * from sqlalchemy.sql import table, column -from sqlalchemy import databases, sql, util +from sqlalchemy import sql, util from sqlalchemy.sql.compiler import BIND_TEMPLATES -from sqlalchemy.engine import default from test.lib.engines import all_dialects from sqlalchemy import types as sqltypes -from test.lib import * +from sqlalchemy.sql import functions from sqlalchemy.sql.functions import GenericFunction -from test.lib.testing import eq_ from sqlalchemy.util.compat import decimal -from test.lib import testing -from sqlalchemy.databases import * +from test.lib import testing, fixtures, AssertsCompiledSQL, engines +from sqlalchemy.dialects import sqlite, postgresql, mysql, oracle class CompileTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = 'default' + def tear_down(self): + functions._registry.clear() + def test_compile(self): - for dialect in all_dialects(exclude=('sybase', 'access', 'informix', 'maxdb')): + for dialect in all_dialects(exclude=('sybase', 'access', + 'informix', 'maxdb')): bindtemplate = BIND_TEMPLATES[dialect.paramstyle] - self.assert_compile(func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect) + self.assert_compile(func.current_timestamp(), + "CURRENT_TIMESTAMP", dialect=dialect) self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect) - if isinstance(dialect, (firebird.dialect, maxdb.dialect)): - self.assert_compile(func.nosuchfunction(), "nosuchfunction", dialect=dialect) + if dialect.name in ('firebird', 'maxdb'): + self.assert_compile(func.nosuchfunction(), + "nosuchfunction", dialect=dialect) else: - self.assert_compile(func.nosuchfunction(), "nosuchfunction()", dialect=dialect) + self.assert_compile(func.nosuchfunction(), + "nosuchfunction()", dialect=dialect) # test generic function compile class fake_func(GenericFunction): @@ -38,7 +43,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( fake_func('foo'), "fake_func(%s)" % - bindtemplate % {'name':'param_1', 'position':1}, + bindtemplate % {'name': 'param_1', 'position': 1}, dialect=dialect) def test_use_labels(self): @@ -71,6 +76,44 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ]: self.assert_compile(func.random(), ret, dialect=dialect) + def test_custom_default_namespace(self): + class myfunc(GenericFunction): + pass + + assert isinstance(func.myfunc(), myfunc) + + def test_custom_type(self): + class myfunc(GenericFunction): + type = DateTime + + assert isinstance(func.myfunc().type, DateTime) + + def test_custom_legacy_type(self): + # in case someone was using this system + class myfunc(GenericFunction): + __return_type__ = DateTime + + assert isinstance(func.myfunc().type, DateTime) + + def test_custom_w_custom_name(self): + class myfunc(GenericFunction): + name = "notmyfunc" + + assert isinstance(func.notmyfunc(), myfunc) + assert not isinstance(func.myfunc(), myfunc) + + def test_custom_package_namespace(self): + def cls1(pk_name): + class myfunc(GenericFunction): + package = pk_name + return myfunc + + f1 = cls1("mypackage") + f2 = cls1("myotherpackage") + + assert isinstance(func.mypackage.myfunc(), f1) + assert isinstance(func.myotherpackage.myfunc(), f2) + def test_namespacing_conflicts(self): self.assert_compile(func.text('foo'), 'text(:text_1)') @@ -108,12 +151,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ((datetime.date(2007, 10, 5), datetime.date(2005, 10, 15)), sqltypes.Date), ((3, 5), sqltypes.Integer), - ((decimal.Decimal(3), decimal.Decimal(5)), sqltypes.Numeric), + ((decimal.Decimal(3), decimal.Decimal(5)), + sqltypes.Numeric), (("foo", "bar"), sqltypes.String), ((datetime.datetime(2007, 10, 5, 8, 3, 34), - datetime.datetime(2005, 10, 15, 14, 45, 33)), sqltypes.DateTime) + datetime.datetime(2005, 10, 15, 14, 45, 33)), + sqltypes.DateTime) ]: - assert isinstance(fn(*args).type, type_), "%s / %s" % (fn(), type_) + assert isinstance(fn(*args).type, type_), \ + "%s / %s" % (fn(), type_) assert isinstance(func.concat("foo", "bar").type, sqltypes.String) @@ -129,8 +175,10 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) # test an expression with a function - self.assert_compile(func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid, - "lala(:lala_1, :lala_2, :param_1, mytable.myid) * myothertable.otherid") + self.assert_compile(func.lala(3, 4, literal("five"), + table1.c.myid) * table2.c.otherid, + "lala(:lala_1, :lala_2, :param_1, mytable.myid) * " + "myothertable.otherid") # test it in a SELECT self.assert_compile(select([func.count(table1.c.myid)]), @@ -140,8 +188,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(select([func.foo.bar.lala(table1.c.myid)]), "SELECT foo.bar.lala(mytable.myid) AS lala_1 FROM mytable") - # test the bind parameter name with a "dotted" function name is only the name - # (limits the length of the bind param name) + # test the bind parameter name with a "dotted" function name is + # only the name (limits the length of the bind param name) self.assert_compile(select([func.foo.bar.lala(12)]), "SELECT foo.bar.lala(:lala_2) AS lala_1") @@ -149,16 +197,17 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(func.lala.hoho(7), "lala.hoho(:hoho_1)") # test None becomes NULL - self.assert_compile(func.my_func(1,2,None,3), + self.assert_compile(func.my_func(1, 2, None, 3), "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)") # test pickling self.assert_compile( - util.pickle.loads(util.pickle.dumps(func.my_func(1, 2, None, 3))), + util.pickle.loads(util.pickle.dumps( + func.my_func(1, 2, None, 3))), "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)") - # assert func raises AttributeError for __bases__ attribute, since its not a class - # fixes pydoc + # assert func raises AttributeError for __bases__ attribute, since + # its not a class fixes pydoc try: func.__bases__ assert False @@ -186,8 +235,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "FROM users, (SELECT q, z, r " "FROM calculate(:x_1, :y_1)) AS c1, (SELECT q, z, r " "FROM calculate(:x_2, :y_2)) AS c2 " - "WHERE users.id BETWEEN c1.z AND c2.z" - , checkparams={'y_1': 45, 'x_1': 17, 'y_2': 12, 'x_2': 5}) + "WHERE users.id BETWEEN c1.z AND c2.z", + checkparams={'y_1': 45, 'x_1': 17, 'y_2': 12, 'x_2': 5}) class ExecuteTest(fixtures.TestBase): @@ -233,12 +282,12 @@ class ExecuteTest(fixtures.TestBase): eq_(f._execution_options, {}) f = f.execution_options(foo='bar') - eq_(f._execution_options, {'foo':'bar'}) + eq_(f._execution_options, {'foo': 'bar'}) s = f.select() - eq_(s._execution_options, {'foo':'bar'}) + eq_(s._execution_options, {'foo': 'bar'}) ret = testing.db.execute(func.now().execution_options(foo='bar')) - eq_(ret.context.execution_options, {'foo':'bar'}) + eq_(ret.context.execution_options, {'foo': 'bar'}) ret.close() @@ -252,11 +301,13 @@ class ExecuteTest(fixtures.TestBase): meta = MetaData(testing.db) t = Table('t1', meta, - Column('id', Integer, Sequence('t1idseq', optional=True), primary_key=True), + Column('id', Integer, Sequence('t1idseq', optional=True), + primary_key=True), Column('value', Integer) ) t2 = Table('t2', meta, - Column('id', Integer, Sequence('t2idseq', optional=True), primary_key=True), + Column('id', Integer, Sequence('t2idseq', optional=True), + primary_key=True), Column('value', Integer, default=7), Column('stuff', String(20), onupdate="thisisstuff") ) @@ -269,20 +320,23 @@ class ExecuteTest(fixtures.TestBase): r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute() id = r.inserted_primary_key[0] - assert t.select(t.c.id==id).execute().first()['value'] == 9 - t.update(values={t.c.value:func.length("asdf")}).execute() + assert t.select(t.c.id == id).execute().first()['value'] == 9 + t.update(values={t.c.value: func.length("asdf")}).execute() assert t.select().execute().first()['value'] == 4 print "--------------------------" t2.insert().execute() t2.insert(values=dict(value=func.length("one"))).execute() - t2.insert(values=dict(value=func.length("asfda") + -19)).execute(stuff="hi") + t2.insert(values=dict(value=func.length("asfda") + -19)).\ + execute(stuff="hi") res = exec_sorted(select([t2.c.value, t2.c.stuff])) eq_(res, [(-14, 'hi'), (3, None), (7, None)]) - t2.update(values=dict(value=func.length("asdsafasd"))).execute(stuff="some stuff") + t2.update(values=dict(value=func.length("asdsafasd"))).\ + execute(stuff="some stuff") assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == \ - [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")] + [(9, "some stuff"), (9, "some stuff"), + (9, "some stuff")] t2.delete().execute() @@ -290,11 +344,17 @@ class ExecuteTest(fixtures.TestBase): assert t2.select().execute().first()['value'] == 11 t2.update(values=dict(value=func.length("asfda"))).execute() - assert select([t2.c.value, t2.c.stuff]).execute().first() == (5, "thisisstuff") + eq_( + select([t2.c.value, t2.c.stuff]).execute().first(), + (5, "thisisstuff") + ) - t2.update(values={t2.c.value:func.length("asfdaasdf"), t2.c.stuff:"foo"}).execute() + t2.update(values={t2.c.value: func.length("asfdaasdf"), + t2.c.stuff: "foo"}).execute() print "HI", select([t2.c.value, t2.c.stuff]).execute().first() - assert select([t2.c.value, t2.c.stuff]).execute().first() == (9, "foo") + eq_(select([t2.c.value, t2.c.stuff]).execute().first(), + (9, "foo") + ) finally: meta.drop_all() @@ -304,10 +364,13 @@ class ExecuteTest(fixtures.TestBase): x = func.current_date(bind=testing.db).execute().scalar() y = func.current_date(bind=testing.db).select().execute().scalar() z = func.current_date(bind=testing.db).scalar() - w = select(['*'], from_obj=[func.current_date(bind=testing.db)]).scalar() + w = select(['*'], from_obj=[func.current_date(bind=testing.db)]).\ + scalar() - # construct a column-based FROM object out of a function, like in [ticket:172] - s = select([sql.column('date', type_=DateTime)], from_obj=[func.current_date(bind=testing.db)]) + # construct a column-based FROM object out of a function, + # like in [ticket:172] + s = select([sql.column('date', type_=DateTime)], + from_obj=[func.current_date(bind=testing.db)]) q = s.execute().first()[s.c.date] r = s.alias('datequery').select().scalar() @@ -340,7 +403,7 @@ class ExecuteTest(fixtures.TestBase): try: table.insert().execute( {'dt': datetime.datetime(2010, 5, 1, 12, 11, 10), - 'd': datetime.date(2010, 5, 1) }) + 'd': datetime.date(2010, 5, 1)}) rs = select([extract('year', table.c.dt), extract('month', table.c.d)]).execute() row = rs.first() |