diff options
Diffstat (limited to 'lib/sqlalchemy/sql/functions.py')
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 107 |
1 files changed, 88 insertions, 19 deletions
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 5a480f0c3..79f1bcde2 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -10,27 +10,95 @@ from .expression import ( ) from . import operators from .visitors import VisitableType +from .. import util + +_registry = util.defaultdict(dict) class _GenericMeta(VisitableType): - def __call__(self, *args, **kwargs): - args = [_literal_as_binds(c) for c in args] - return type.__call__(self, *args, **kwargs) + def __init__(cls, clsname, bases, clsdict): + cls.name = name = clsdict.get('name', clsname) + package = clsdict.pop('package', '_default') + # legacy + if '__return_type__' in clsdict: + cls.type = clsdict['__return_type__'] + reg = _registry[package] + reg[name] = cls + super(_GenericMeta, cls).__init__(clsname, bases, clsdict) + + def __call__(cls, *args, **kwargs): + if cls.coerce_arguments: + args = [_literal_as_binds(c) for c in args] + return type.__call__(cls, *args, **kwargs) class GenericFunction(Function): + """Define a 'generic' function. + + A generic function is a pre-established :class:`.Function` + class that is instantiated automatically when called + by name from the :data:`.func` attribute. Note that + calling any name from :data:`.func` has the effect that + a new :class:`.Function` instance is created automatically, + given that name. The primary use case for defining + a :class:`.GenericFunction` class is so that a function + of a particular name may be given a fixed return type. + It can also include custom argument parsing schemes as well + as additional methods. + + Subclasses of :class:`.GenericFunction` are automatically + registered under the name of the class. For + example, a user-defined function ``as_utc()`` would + be available immediately:: + + from sqlalchemy.sql.functions import GenericFunction + from sqlalchemy.types import DateTime + + class as_utc(GenericFunction): + type = DateTime + + print select([func.as_utc()]) + + User-defined generic functions can be organized into + packages by specifying the "package" attribute when defining + :class:`.GenericFunction`. Third party libraries + containing many functions may want to use this in order + to avoid name conflicts with other systems. For example, + if our ``as_utc()`` function were part of a package + "time":: + + class as_utc(GenericFunction): + type = DateTime + package = "time" + + The above function would be available from :data:`.func` + using the package name ``time``:: + + print select([func.time.as_utc()]) + + .. versionadded:: 0.8 :class:`.GenericFunction` now supports + automatic registration of new functions as well as package + support. + + .. versionchanged:: 0.8 The attribute name ``type`` is used + to specify the function's return type at the class level. + Previously, the name ``__return_type__`` was used. This + name is still recognized for backwards-compatibility. + + """ __metaclass__ = _GenericMeta + coerce_arguments = True def __init__(self, type_=None, args=(), **kwargs): + args = [_literal_as_binds(c) for c in args] self.packagenames = [] - self.name = self.__class__.__name__ self._bind = kwargs.get('bind', None) self.clause_expr = ClauseList( operator=operators.comma_op, group_contents=True, *args).self_group() self.type = sqltypes.to_instance( - type_ or getattr(self, '__return_type__', None)) + type_ or getattr(self, 'type', None)) -class next_value(Function): +class next_value(GenericFunction): """Represent the 'next value', given a :class:`.Sequence` as it's single argument. @@ -41,6 +109,7 @@ class next_value(Function): """ type = sqltypes.Integer() name = "next_value" + coerce_arguments = False def __init__(self, seq, **kw): assert isinstance(seq, schema.Sequence), \ @@ -77,15 +146,15 @@ class sum(ReturnTypeFromArgs): class now(GenericFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime class concat(GenericFunction): - __return_type__ = sqltypes.String + type = sqltypes.String def __init__(self, *args, **kwargs): GenericFunction.__init__(self, args=args, **kwargs) class char_length(GenericFunction): - __return_type__ = sqltypes.Integer + type = sqltypes.Integer def __init__(self, arg, **kwargs): GenericFunction.__init__(self, args=[arg], **kwargs) @@ -98,7 +167,7 @@ class random(GenericFunction): class count(GenericFunction): """The ANSI COUNT aggregate function. With no arguments, emits COUNT \*.""" - __return_type__ = sqltypes.Integer + type = sqltypes.Integer def __init__(self, expression=None, **kwargs): if expression is None: @@ -106,29 +175,29 @@ class count(GenericFunction): GenericFunction.__init__(self, args=(expression,), **kwargs) class current_date(AnsiFunction): - __return_type__ = sqltypes.Date + type = sqltypes.Date class current_time(AnsiFunction): - __return_type__ = sqltypes.Time + type = sqltypes.Time class current_timestamp(AnsiFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime class current_user(AnsiFunction): - __return_type__ = sqltypes.String + type = sqltypes.String class localtime(AnsiFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime class localtimestamp(AnsiFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime class session_user(AnsiFunction): - __return_type__ = sqltypes.String + type = sqltypes.String class sysdate(AnsiFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime class user(AnsiFunction): - __return_type__ = sqltypes.String + type = sqltypes.String |