diff options
Diffstat (limited to 'lib/sqlalchemy/sql/functions.py')
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 70 |
1 files changed, 25 insertions, 45 deletions
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index cbc8e539f..96e64dc28 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -17,7 +17,6 @@ from . import sqltypes from . import util as sqlutil from .base import ColumnCollection from .base import Executable -from .elements import _clone from .elements import _type_from_args from .elements import BinaryExpression from .elements import BindParameter @@ -33,7 +32,8 @@ from .elements import WithinGroup from .selectable import Alias from .selectable import FromClause from .selectable import Select -from .visitors import VisitableType +from .visitors import InternalTraversal +from .visitors import TraversibleType from .. import util @@ -78,10 +78,14 @@ class FunctionElement(Executable, ColumnElement, FromClause): """ + _traverse_internals = [("clause_expr", InternalTraversal.dp_clauseelement)] + packagenames = () _has_args = False + _memoized_property = FromClause._memoized_property + def __init__(self, *clauses, **kwargs): r"""Construct a :class:`.FunctionElement`. @@ -136,7 +140,7 @@ class FunctionElement(Executable, ColumnElement, FromClause): col = self.label(None) return ColumnCollection(columns=[(col.key, col)]) - @util.memoized_property + @_memoized_property def clauses(self): """Return the underlying :class:`.ClauseList` which contains the arguments for this :class:`.FunctionElement`. @@ -283,17 +287,6 @@ class FunctionElement(Executable, ColumnElement, FromClause): def _from_objects(self): return self.clauses._from_objects - def get_children(self, **kwargs): - return (self.clause_expr,) - - def _cache_key(self, **kw): - return (FunctionElement, self.clause_expr._cache_key(**kw)) - - def _copy_internals(self, clone=_clone, **kw): - self.clause_expr = clone(self.clause_expr, **kw) - self._reset_exported() - FunctionElement.clauses._reset(self) - def within_group_type(self, within_group): """For types that define their return type as based on the criteria within a WITHIN GROUP (ORDER BY) expression, called by the @@ -404,6 +397,13 @@ class FunctionElement(Executable, ColumnElement, FromClause): class FunctionAsBinary(BinaryExpression): + _traverse_internals = [ + ("sql_function", InternalTraversal.dp_clauseelement), + ("left_index", InternalTraversal.dp_plain_obj), + ("right_index", InternalTraversal.dp_plain_obj), + ("modifiers", InternalTraversal.dp_plain_dict), + ] + def __init__(self, fn, left_index, right_index): self.sql_function = fn self.left_index = left_index @@ -431,20 +431,6 @@ class FunctionAsBinary(BinaryExpression): def right(self, value): self.sql_function.clauses.clauses[self.right_index - 1] = value - def _copy_internals(self, clone=_clone, **kw): - self.sql_function = clone(self.sql_function, **kw) - - def get_children(self, **kw): - yield self.sql_function - - def _cache_key(self, **kw): - return ( - FunctionAsBinary, - self.sql_function._cache_key(**kw), - self.left_index, - self.right_index, - ) - class _FunctionGenerator(object): """Generate SQL function expressions. @@ -606,6 +592,12 @@ class Function(FunctionElement): __visit_name__ = "function" + _traverse_internals = FunctionElement._traverse_internals + [ + ("packagenames", InternalTraversal.dp_plain_obj), + ("name", InternalTraversal.dp_string), + ("type", InternalTraversal.dp_type), + ] + def __init__(self, name, *clauses, **kw): """Construct a :class:`.Function`. @@ -630,15 +622,8 @@ class Function(FunctionElement): unique=True, ) - def _cache_key(self, **kw): - return ( - (Function,) + tuple(self.packagenames) - if self.packagenames - else () + (self.name, self.clause_expr._cache_key(**kw)) - ) - -class _GenericMeta(VisitableType): +class _GenericMeta(TraversibleType): def __init__(cls, clsname, bases, clsdict): if annotation.Annotated not in cls.__mro__: cls.name = name = clsdict.get("name", clsname) @@ -764,6 +749,10 @@ class next_value(GenericFunction): type = sqltypes.Integer() name = "next_value" + _traverse_internals = [ + ("sequence", InternalTraversal.dp_named_ddl_element) + ] + def __init__(self, seq, **kw): assert isinstance( seq, schema.Sequence @@ -771,21 +760,12 @@ class next_value(GenericFunction): self._bind = kw.get("bind", None) self.sequence = seq - def _cache_key(self, **kw): - return (next_value, self.sequence.name) - def compare(self, other, **kw): return ( isinstance(other, next_value) and self.sequence.name == other.sequence.name ) - def get_children(self, **kwargs): - return [] - - def _copy_internals(self, **kw): - pass - @property def _from_objects(self): return [] |