summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/functions.py')
-rw-r--r--lib/sqlalchemy/sql/functions.py70
1 files changed, 25 insertions, 45 deletions
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index cbc8e539f..96e64dc28 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -17,7 +17,6 @@ from . import sqltypes
from . import util as sqlutil
from .base import ColumnCollection
from .base import Executable
-from .elements import _clone
from .elements import _type_from_args
from .elements import BinaryExpression
from .elements import BindParameter
@@ -33,7 +32,8 @@ from .elements import WithinGroup
from .selectable import Alias
from .selectable import FromClause
from .selectable import Select
-from .visitors import VisitableType
+from .visitors import InternalTraversal
+from .visitors import TraversibleType
from .. import util
@@ -78,10 +78,14 @@ class FunctionElement(Executable, ColumnElement, FromClause):
"""
+ _traverse_internals = [("clause_expr", InternalTraversal.dp_clauseelement)]
+
packagenames = ()
_has_args = False
+ _memoized_property = FromClause._memoized_property
+
def __init__(self, *clauses, **kwargs):
r"""Construct a :class:`.FunctionElement`.
@@ -136,7 +140,7 @@ class FunctionElement(Executable, ColumnElement, FromClause):
col = self.label(None)
return ColumnCollection(columns=[(col.key, col)])
- @util.memoized_property
+ @_memoized_property
def clauses(self):
"""Return the underlying :class:`.ClauseList` which contains
the arguments for this :class:`.FunctionElement`.
@@ -283,17 +287,6 @@ class FunctionElement(Executable, ColumnElement, FromClause):
def _from_objects(self):
return self.clauses._from_objects
- def get_children(self, **kwargs):
- return (self.clause_expr,)
-
- def _cache_key(self, **kw):
- return (FunctionElement, self.clause_expr._cache_key(**kw))
-
- def _copy_internals(self, clone=_clone, **kw):
- self.clause_expr = clone(self.clause_expr, **kw)
- self._reset_exported()
- FunctionElement.clauses._reset(self)
-
def within_group_type(self, within_group):
"""For types that define their return type as based on the criteria
within a WITHIN GROUP (ORDER BY) expression, called by the
@@ -404,6 +397,13 @@ class FunctionElement(Executable, ColumnElement, FromClause):
class FunctionAsBinary(BinaryExpression):
+ _traverse_internals = [
+ ("sql_function", InternalTraversal.dp_clauseelement),
+ ("left_index", InternalTraversal.dp_plain_obj),
+ ("right_index", InternalTraversal.dp_plain_obj),
+ ("modifiers", InternalTraversal.dp_plain_dict),
+ ]
+
def __init__(self, fn, left_index, right_index):
self.sql_function = fn
self.left_index = left_index
@@ -431,20 +431,6 @@ class FunctionAsBinary(BinaryExpression):
def right(self, value):
self.sql_function.clauses.clauses[self.right_index - 1] = value
- def _copy_internals(self, clone=_clone, **kw):
- self.sql_function = clone(self.sql_function, **kw)
-
- def get_children(self, **kw):
- yield self.sql_function
-
- def _cache_key(self, **kw):
- return (
- FunctionAsBinary,
- self.sql_function._cache_key(**kw),
- self.left_index,
- self.right_index,
- )
-
class _FunctionGenerator(object):
"""Generate SQL function expressions.
@@ -606,6 +592,12 @@ class Function(FunctionElement):
__visit_name__ = "function"
+ _traverse_internals = FunctionElement._traverse_internals + [
+ ("packagenames", InternalTraversal.dp_plain_obj),
+ ("name", InternalTraversal.dp_string),
+ ("type", InternalTraversal.dp_type),
+ ]
+
def __init__(self, name, *clauses, **kw):
"""Construct a :class:`.Function`.
@@ -630,15 +622,8 @@ class Function(FunctionElement):
unique=True,
)
- def _cache_key(self, **kw):
- return (
- (Function,) + tuple(self.packagenames)
- if self.packagenames
- else () + (self.name, self.clause_expr._cache_key(**kw))
- )
-
-class _GenericMeta(VisitableType):
+class _GenericMeta(TraversibleType):
def __init__(cls, clsname, bases, clsdict):
if annotation.Annotated not in cls.__mro__:
cls.name = name = clsdict.get("name", clsname)
@@ -764,6 +749,10 @@ class next_value(GenericFunction):
type = sqltypes.Integer()
name = "next_value"
+ _traverse_internals = [
+ ("sequence", InternalTraversal.dp_named_ddl_element)
+ ]
+
def __init__(self, seq, **kw):
assert isinstance(
seq, schema.Sequence
@@ -771,21 +760,12 @@ class next_value(GenericFunction):
self._bind = kw.get("bind", None)
self.sequence = seq
- def _cache_key(self, **kw):
- return (next_value, self.sequence.name)
-
def compare(self, other, **kw):
return (
isinstance(other, next_value)
and self.sequence.name == other.sequence.name
)
- def get_children(self, **kwargs):
- return []
-
- def _copy_internals(self, **kw):
- pass
-
@property
def _from_objects(self):
return []