diff options
Diffstat (limited to 'lib/sqlalchemy/sql/functions.py')
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 139 |
1 files changed, 84 insertions, 55 deletions
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 4b4d2d463..883bb8cc3 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -10,10 +10,22 @@ """ from . import sqltypes, schema from .base import Executable, ColumnCollection -from .elements import ClauseList, Cast, Extract, _literal_as_binds, \ - literal_column, _type_from_args, ColumnElement, _clone,\ - Over, BindParameter, FunctionFilter, Grouping, WithinGroup, \ - BinaryExpression +from .elements import ( + ClauseList, + Cast, + Extract, + _literal_as_binds, + literal_column, + _type_from_args, + ColumnElement, + _clone, + Over, + BindParameter, + FunctionFilter, + Grouping, + WithinGroup, + BinaryExpression, +) from .selectable import FromClause, Select, Alias from . import util as sqlutil from . import operators @@ -62,9 +74,8 @@ class FunctionElement(Executable, ColumnElement, FromClause): args = [_literal_as_binds(c, self.name) for c in clauses] self._has_args = self._has_args or bool(args) self.clause_expr = ClauseList( - operator=operators.comma_op, - group_contents=True, *args).\ - self_group() + operator=operators.comma_op, group_contents=True, *args + ).self_group() def _execute_on_connection(self, connection, multiparams, params): return connection._execute_function(self, multiparams, params) @@ -123,7 +134,7 @@ class FunctionElement(Executable, ColumnElement, FromClause): partition_by=partition_by, order_by=order_by, rows=rows, - range_=range_ + range_=range_, ) def within_group(self, *order_by): @@ -233,16 +244,14 @@ class FunctionElement(Executable, ColumnElement, FromClause): .. versionadded:: 1.3 """ - return FunctionAsBinary( - self, left_index, right_index - ) + return FunctionAsBinary(self, left_index, right_index) @property def _from_objects(self): return self.clauses._from_objects def get_children(self, **kwargs): - return self.clause_expr, + return (self.clause_expr,) def _copy_internals(self, clone=_clone, **kw): self.clause_expr = clone(self.clause_expr, **kw) @@ -336,24 +345,29 @@ class FunctionElement(Executable, ColumnElement, FromClause): return self.select().execute() def _bind_param(self, operator, obj, type_=None): - return BindParameter(None, obj, _compared_to_operator=operator, - _compared_to_type=self.type, unique=True, - type_=type_) + return BindParameter( + None, + obj, + _compared_to_operator=operator, + _compared_to_type=self.type, + unique=True, + type_=type_, + ) def self_group(self, against=None): # for the moment, we are parenthesizing all array-returning # expressions against getitem. This may need to be made # more portable if in the future we support other DBs # besides postgresql. - if against is operators.getitem and \ - isinstance(self.type, sqltypes.ARRAY): + if against is operators.getitem and isinstance( + self.type, sqltypes.ARRAY + ): return Grouping(self) else: return super(FunctionElement, self).self_group(against=against) class FunctionAsBinary(BinaryExpression): - def __init__(self, fn, left_index, right_index): left = fn.clauses.clauses[left_index - 1] right = fn.clauses.clauses[right_index - 1] @@ -362,8 +376,11 @@ class FunctionAsBinary(BinaryExpression): self.right_index = right_index super(FunctionAsBinary, self).__init__( - left, right, operators.function_as_comparison_op, - type_=sqltypes.BOOLEANTYPE) + left, + right, + operators.function_as_comparison_op, + type_=sqltypes.BOOLEANTYPE, + ) @property def left(self): @@ -382,7 +399,7 @@ class FunctionAsBinary(BinaryExpression): self.sql_function.clauses.clauses[self.right_index - 1] = value def _copy_internals(self, **kw): - clone = kw.pop('clone') + clone = kw.pop("clone") self.sql_function = clone(self.sql_function, **kw) super(FunctionAsBinary, self)._copy_internals(**kw) @@ -396,13 +413,13 @@ class _FunctionGenerator(object): def __getattr__(self, name): # passthru __ attributes; fixes pydoc - if name.startswith('__'): + if name.startswith("__"): try: return self.__dict__[name] except KeyError: raise AttributeError(name) - elif name.endswith('_'): + elif name.endswith("_"): name = name[0:-1] f = _FunctionGenerator(**self.opts) f.__names = list(self.__names) + [name] @@ -426,8 +443,9 @@ class _FunctionGenerator(object): if func is not None: return func(*c, **o) - return Function(self.__names[-1], - packagenames=self.__names[0:-1], *c, **o) + return Function( + self.__names[-1], packagenames=self.__names[0:-1], *c, **o + ) func = _FunctionGenerator() @@ -523,7 +541,7 @@ class Function(FunctionElement): """ - __visit_name__ = 'function' + __visit_name__ = "function" def __init__(self, name, *clauses, **kw): """Construct a :class:`.Function`. @@ -532,30 +550,33 @@ class Function(FunctionElement): new :class:`.Function` instances. """ - self.packagenames = kw.pop('packagenames', None) or [] + self.packagenames = kw.pop("packagenames", None) or [] self.name = name - self._bind = kw.get('bind', None) - self.type = sqltypes.to_instance(kw.get('type_', None)) + self._bind = kw.get("bind", None) + self.type = sqltypes.to_instance(kw.get("type_", None)) FunctionElement.__init__(self, *clauses, **kw) def _bind_param(self, operator, obj, type_=None): - return BindParameter(self.name, obj, - _compared_to_operator=operator, - _compared_to_type=self.type, - type_=type_, - unique=True) + return BindParameter( + self.name, + obj, + _compared_to_operator=operator, + _compared_to_type=self.type, + type_=type_, + unique=True, + ) class _GenericMeta(VisitableType): def __init__(cls, clsname, bases, clsdict): if annotation.Annotated not in cls.__mro__: - cls.name = name = clsdict.get('name', clsname) - cls.identifier = identifier = clsdict.get('identifier', name) - package = clsdict.pop('package', '_default') + cls.name = name = clsdict.get("name", clsname) + cls.identifier = identifier = clsdict.get("identifier", name) + package = clsdict.pop("package", "_default") # legacy - if '__return_type__' in clsdict: - cls.type = clsdict['__return_type__'] + if "__return_type__" in clsdict: + cls.type = clsdict["__return_type__"] register_function(identifier, cls, package) super(_GenericMeta, cls).__init__(clsname, bases, clsdict) @@ -635,17 +656,19 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)): coerce_arguments = True def __init__(self, *args, **kwargs): - parsed_args = kwargs.pop('_parsed_args', None) + parsed_args = kwargs.pop("_parsed_args", None) if parsed_args is None: parsed_args = [_literal_as_binds(c, self.name) for c in args] self._has_args = self._has_args or bool(parsed_args) self.packagenames = [] - self._bind = kwargs.get('bind', None) + self._bind = kwargs.get("bind", None) self.clause_expr = ClauseList( - operator=operators.comma_op, - group_contents=True, *parsed_args).self_group() + operator=operators.comma_op, group_contents=True, *parsed_args + ).self_group() self.type = sqltypes.to_instance( - kwargs.pop("type_", None) or getattr(self, 'type', None)) + kwargs.pop("type_", None) or getattr(self, "type", None) + ) + register_function("cast", Cast) register_function("extract", Extract) @@ -660,13 +683,15 @@ class next_value(GenericFunction): that does not provide support for sequences. """ + type = sqltypes.Integer() name = "next_value" def __init__(self, seq, **kw): - assert isinstance(seq, schema.Sequence), \ - "next_value() accepts a Sequence object as input." - self._bind = kw.get('bind', None) + assert isinstance( + seq, schema.Sequence + ), "next_value() accepts a Sequence object as input." + self._bind = kw.get("bind", None) self.sequence = seq @property @@ -684,8 +709,8 @@ class ReturnTypeFromArgs(GenericFunction): def __init__(self, *args, **kwargs): args = [_literal_as_binds(c, self.name) for c in args] - kwargs.setdefault('type_', _type_from_args(args)) - kwargs['_parsed_args'] = args + kwargs.setdefault("type_", _type_from_args(args)) + kwargs["_parsed_args"] = args super(ReturnTypeFromArgs, self).__init__(*args, **kwargs) @@ -733,7 +758,7 @@ class count(GenericFunction): def __init__(self, expression=None, **kwargs): if expression is None: - expression = literal_column('*') + expression = literal_column("*") super(count, self).__init__(expression, **kwargs) @@ -797,15 +822,15 @@ class array_agg(GenericFunction): def __init__(self, *args, **kwargs): args = [_literal_as_binds(c) for c in args] - default_array_type = kwargs.pop('_default_array_type', sqltypes.ARRAY) - if 'type_' not in kwargs: + default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY) + if "type_" not in kwargs: type_from_args = _type_from_args(args) if isinstance(type_from_args, sqltypes.ARRAY): - kwargs['type_'] = type_from_args + kwargs["type_"] = type_from_args else: - kwargs['type_'] = default_array_type(type_from_args) - kwargs['_parsed_args'] = args + kwargs["type_"] = default_array_type(type_from_args) + kwargs["_parsed_args"] = args super(array_agg, self).__init__(*args, **kwargs) @@ -883,6 +908,7 @@ class rank(GenericFunction): .. versionadded:: 1.1 """ + type = sqltypes.Integer() @@ -897,6 +923,7 @@ class dense_rank(GenericFunction): .. versionadded:: 1.1 """ + type = sqltypes.Integer() @@ -911,6 +938,7 @@ class percent_rank(GenericFunction): .. versionadded:: 1.1 """ + type = sqltypes.Numeric() @@ -925,6 +953,7 @@ class cume_dist(GenericFunction): .. versionadded:: 1.1 """ + type = sqltypes.Numeric() |