summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/functions.py')
-rw-r--r--lib/sqlalchemy/sql/functions.py139
1 files changed, 84 insertions, 55 deletions
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 4b4d2d463..883bb8cc3 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -10,10 +10,22 @@
"""
from . import sqltypes, schema
from .base import Executable, ColumnCollection
-from .elements import ClauseList, Cast, Extract, _literal_as_binds, \
- literal_column, _type_from_args, ColumnElement, _clone,\
- Over, BindParameter, FunctionFilter, Grouping, WithinGroup, \
- BinaryExpression
+from .elements import (
+ ClauseList,
+ Cast,
+ Extract,
+ _literal_as_binds,
+ literal_column,
+ _type_from_args,
+ ColumnElement,
+ _clone,
+ Over,
+ BindParameter,
+ FunctionFilter,
+ Grouping,
+ WithinGroup,
+ BinaryExpression,
+)
from .selectable import FromClause, Select, Alias
from . import util as sqlutil
from . import operators
@@ -62,9 +74,8 @@ class FunctionElement(Executable, ColumnElement, FromClause):
args = [_literal_as_binds(c, self.name) for c in clauses]
self._has_args = self._has_args or bool(args)
self.clause_expr = ClauseList(
- operator=operators.comma_op,
- group_contents=True, *args).\
- self_group()
+ operator=operators.comma_op, group_contents=True, *args
+ ).self_group()
def _execute_on_connection(self, connection, multiparams, params):
return connection._execute_function(self, multiparams, params)
@@ -123,7 +134,7 @@ class FunctionElement(Executable, ColumnElement, FromClause):
partition_by=partition_by,
order_by=order_by,
rows=rows,
- range_=range_
+ range_=range_,
)
def within_group(self, *order_by):
@@ -233,16 +244,14 @@ class FunctionElement(Executable, ColumnElement, FromClause):
.. versionadded:: 1.3
"""
- return FunctionAsBinary(
- self, left_index, right_index
- )
+ return FunctionAsBinary(self, left_index, right_index)
@property
def _from_objects(self):
return self.clauses._from_objects
def get_children(self, **kwargs):
- return self.clause_expr,
+ return (self.clause_expr,)
def _copy_internals(self, clone=_clone, **kw):
self.clause_expr = clone(self.clause_expr, **kw)
@@ -336,24 +345,29 @@ class FunctionElement(Executable, ColumnElement, FromClause):
return self.select().execute()
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(None, obj, _compared_to_operator=operator,
- _compared_to_type=self.type, unique=True,
- type_=type_)
+ return BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ unique=True,
+ type_=type_,
+ )
def self_group(self, against=None):
# for the moment, we are parenthesizing all array-returning
# expressions against getitem. This may need to be made
# more portable if in the future we support other DBs
# besides postgresql.
- if against is operators.getitem and \
- isinstance(self.type, sqltypes.ARRAY):
+ if against is operators.getitem and isinstance(
+ self.type, sqltypes.ARRAY
+ ):
return Grouping(self)
else:
return super(FunctionElement, self).self_group(against=against)
class FunctionAsBinary(BinaryExpression):
-
def __init__(self, fn, left_index, right_index):
left = fn.clauses.clauses[left_index - 1]
right = fn.clauses.clauses[right_index - 1]
@@ -362,8 +376,11 @@ class FunctionAsBinary(BinaryExpression):
self.right_index = right_index
super(FunctionAsBinary, self).__init__(
- left, right, operators.function_as_comparison_op,
- type_=sqltypes.BOOLEANTYPE)
+ left,
+ right,
+ operators.function_as_comparison_op,
+ type_=sqltypes.BOOLEANTYPE,
+ )
@property
def left(self):
@@ -382,7 +399,7 @@ class FunctionAsBinary(BinaryExpression):
self.sql_function.clauses.clauses[self.right_index - 1] = value
def _copy_internals(self, **kw):
- clone = kw.pop('clone')
+ clone = kw.pop("clone")
self.sql_function = clone(self.sql_function, **kw)
super(FunctionAsBinary, self)._copy_internals(**kw)
@@ -396,13 +413,13 @@ class _FunctionGenerator(object):
def __getattr__(self, name):
# passthru __ attributes; fixes pydoc
- if name.startswith('__'):
+ if name.startswith("__"):
try:
return self.__dict__[name]
except KeyError:
raise AttributeError(name)
- elif name.endswith('_'):
+ elif name.endswith("_"):
name = name[0:-1]
f = _FunctionGenerator(**self.opts)
f.__names = list(self.__names) + [name]
@@ -426,8 +443,9 @@ class _FunctionGenerator(object):
if func is not None:
return func(*c, **o)
- return Function(self.__names[-1],
- packagenames=self.__names[0:-1], *c, **o)
+ return Function(
+ self.__names[-1], packagenames=self.__names[0:-1], *c, **o
+ )
func = _FunctionGenerator()
@@ -523,7 +541,7 @@ class Function(FunctionElement):
"""
- __visit_name__ = 'function'
+ __visit_name__ = "function"
def __init__(self, name, *clauses, **kw):
"""Construct a :class:`.Function`.
@@ -532,30 +550,33 @@ class Function(FunctionElement):
new :class:`.Function` instances.
"""
- self.packagenames = kw.pop('packagenames', None) or []
+ self.packagenames = kw.pop("packagenames", None) or []
self.name = name
- self._bind = kw.get('bind', None)
- self.type = sqltypes.to_instance(kw.get('type_', None))
+ self._bind = kw.get("bind", None)
+ self.type = sqltypes.to_instance(kw.get("type_", None))
FunctionElement.__init__(self, *clauses, **kw)
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(self.name, obj,
- _compared_to_operator=operator,
- _compared_to_type=self.type,
- type_=type_,
- unique=True)
+ return BindParameter(
+ self.name,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ type_=type_,
+ unique=True,
+ )
class _GenericMeta(VisitableType):
def __init__(cls, clsname, bases, clsdict):
if annotation.Annotated not in cls.__mro__:
- cls.name = name = clsdict.get('name', clsname)
- cls.identifier = identifier = clsdict.get('identifier', name)
- package = clsdict.pop('package', '_default')
+ cls.name = name = clsdict.get("name", clsname)
+ cls.identifier = identifier = clsdict.get("identifier", name)
+ package = clsdict.pop("package", "_default")
# legacy
- if '__return_type__' in clsdict:
- cls.type = clsdict['__return_type__']
+ if "__return_type__" in clsdict:
+ cls.type = clsdict["__return_type__"]
register_function(identifier, cls, package)
super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
@@ -635,17 +656,19 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
coerce_arguments = True
def __init__(self, *args, **kwargs):
- parsed_args = kwargs.pop('_parsed_args', None)
+ parsed_args = kwargs.pop("_parsed_args", None)
if parsed_args is None:
parsed_args = [_literal_as_binds(c, self.name) for c in args]
self._has_args = self._has_args or bool(parsed_args)
self.packagenames = []
- self._bind = kwargs.get('bind', None)
+ self._bind = kwargs.get("bind", None)
self.clause_expr = ClauseList(
- operator=operators.comma_op,
- group_contents=True, *parsed_args).self_group()
+ operator=operators.comma_op, group_contents=True, *parsed_args
+ ).self_group()
self.type = sqltypes.to_instance(
- kwargs.pop("type_", None) or getattr(self, 'type', None))
+ kwargs.pop("type_", None) or getattr(self, "type", None)
+ )
+
register_function("cast", Cast)
register_function("extract", Extract)
@@ -660,13 +683,15 @@ class next_value(GenericFunction):
that does not provide support for sequences.
"""
+
type = sqltypes.Integer()
name = "next_value"
def __init__(self, seq, **kw):
- assert isinstance(seq, schema.Sequence), \
- "next_value() accepts a Sequence object as input."
- self._bind = kw.get('bind', None)
+ assert isinstance(
+ seq, schema.Sequence
+ ), "next_value() accepts a Sequence object as input."
+ self._bind = kw.get("bind", None)
self.sequence = seq
@property
@@ -684,8 +709,8 @@ class ReturnTypeFromArgs(GenericFunction):
def __init__(self, *args, **kwargs):
args = [_literal_as_binds(c, self.name) for c in args]
- kwargs.setdefault('type_', _type_from_args(args))
- kwargs['_parsed_args'] = args
+ kwargs.setdefault("type_", _type_from_args(args))
+ kwargs["_parsed_args"] = args
super(ReturnTypeFromArgs, self).__init__(*args, **kwargs)
@@ -733,7 +758,7 @@ class count(GenericFunction):
def __init__(self, expression=None, **kwargs):
if expression is None:
- expression = literal_column('*')
+ expression = literal_column("*")
super(count, self).__init__(expression, **kwargs)
@@ -797,15 +822,15 @@ class array_agg(GenericFunction):
def __init__(self, *args, **kwargs):
args = [_literal_as_binds(c) for c in args]
- default_array_type = kwargs.pop('_default_array_type', sqltypes.ARRAY)
- if 'type_' not in kwargs:
+ default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY)
+ if "type_" not in kwargs:
type_from_args = _type_from_args(args)
if isinstance(type_from_args, sqltypes.ARRAY):
- kwargs['type_'] = type_from_args
+ kwargs["type_"] = type_from_args
else:
- kwargs['type_'] = default_array_type(type_from_args)
- kwargs['_parsed_args'] = args
+ kwargs["type_"] = default_array_type(type_from_args)
+ kwargs["_parsed_args"] = args
super(array_agg, self).__init__(*args, **kwargs)
@@ -883,6 +908,7 @@ class rank(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Integer()
@@ -897,6 +923,7 @@ class dense_rank(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Integer()
@@ -911,6 +938,7 @@ class percent_rank(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Numeric()
@@ -925,6 +953,7 @@ class cume_dist(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Numeric()