diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-12-16 17:06:43 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-07-03 23:39:51 -0400 |
commit | 3dc9a4a2392d033f9d1bd79dd6b6ecea6281a61c (patch) | |
tree | 1041bccb37422f526dccb5b1e57ffad1c702549b /lib/sqlalchemy/sql | |
parent | 5060043e8e95ab0aab5f63ed288c1426c46da66e (diff) | |
download | sqlalchemy-3dc9a4a2392d033f9d1bd79dd6b6ecea6281a61c.tar.gz |
introduce deferred lambdas
The coercions system allows us to add in lambdas as arguments
to Core and ORM elements without changing them at all. Allowing
the lambda to produce a deterministic cache key, from which we can
also cheat and yank out literal parameters, means we can move towards
having 90% of "baked" functionality in a clearer way right in
Core / ORM.
As a second step, we can have whole statements inside the lambda,
and can then add generation with __add__(), so then we have
100% of "baked" functionality with full support of ad-hoc
literal values.
Adds some more short_selects tests for the moment for comparison.
Other tweaks inside cache key generation as we're trying to
approach a certain level of performance such that we can
remove the use of "baked" from the loader strategies.
Although #4639 has not yet been closed, the caching feature
has been fully integrated as of
b0cfa7379cf8513a821a3dbe3028c4965d9f85bd, so we will also
add complete caching documentation here and close that issue
as well.
Closes: #4639
Fixes: #5380
Change-Id: If91f61527236fd4d7ae3cad1f24c38be921c90ba
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/__init__.py | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 27 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 86 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 27 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 14 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/lambdas.py | 607 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/roles.py | 28 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 91 |
11 files changed, 817 insertions, 82 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index a25c1b083..2fe6f35d2 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -46,6 +46,8 @@ from .expression import intersect_all # noqa from .expression import Join # noqa from .expression import join # noqa from .expression import label # noqa +from .expression import lambda_stmt # noqa +from .expression import LambdaElement # noqa from .expression import lateral # noqa from .expression import literal # noqa from .expression import literal_column # noqa @@ -62,6 +64,7 @@ from .expression import quoted_name # noqa from .expression import Select # noqa from .expression import select # noqa from .expression import Selectable # noqa +from .expression import StatementLambdaElement # noqa from .expression import Subquery # noqa from .expression import subquery # noqa from .expression import table # noqa @@ -106,18 +109,22 @@ def __go(lcls): from . import coercions from . import elements from . import events # noqa + from . import lambdas from . import selectable from . import schema from . import sqltypes + from . import traversals from . import type_api base.coercions = elements.coercions = coercions base.elements = elements base.type_api = type_api coercions.elements = elements + coercions.lambdas = lambdas coercions.schema = schema coercions.selectable = selectable coercions.sqltypes = sqltypes + coercions.traversals = traversals _prepare_annotations(ColumnElement, AnnotatedColumnElement) _prepare_annotations(FromClause, AnnotatedFromClause) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 9dcd7dca9..6cdab8eac 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -14,6 +14,7 @@ import itertools import operator import re +from . 
import roles from .traversals import HasCacheKey # noqa from .traversals import MemoizedHasCacheKey # noqa from .visitors import ClauseVisitor @@ -447,13 +448,17 @@ class CompileState(object): "compile_state_plugin", "default" ) klass = cls.plugins.get( - (plugin_name, statement.__visit_name__), None + (plugin_name, statement._effective_plugin_target), None ) if klass is None: - klass = cls.plugins[("default", statement.__visit_name__)] + klass = cls.plugins[ + ("default", statement._effective_plugin_target) + ] else: - klass = cls.plugins[("default", statement.__visit_name__)] + klass = cls.plugins[ + ("default", statement._effective_plugin_target) + ] if klass is cls: return cls(statement, compiler, **kw) @@ -469,14 +474,18 @@ class CompileState(object): "compile_state_plugin", "default" ) try: - return cls.plugins[(plugin_name, statement.__visit_name__)] + return cls.plugins[ + (plugin_name, statement._effective_plugin_target) + ] except KeyError: return None @classmethod def _get_plugin_class_for_plugin(cls, statement, plugin_name): try: - return cls.plugins[(plugin_name, statement.__visit_name__)] + return cls.plugins[ + (plugin_name, statement._effective_plugin_target) + ] except KeyError: return None @@ -637,6 +646,10 @@ class Executable(Generative): ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs), ] + @property + def _effective_plugin_target(self): + return self.__visit_name__ + @_generative def options(self, *options): """Apply options to this statement. 
@@ -667,7 +680,9 @@ class Executable(Generative): to the usage of ORM queries """ - self._with_options += options + self._with_options += tuple( + coercions.expect(roles.HasCacheKeyRole, opt) for opt in options + ) @_generative def _add_context_option(self, callable_, cache_args): diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 4c6a0317a..be412c770 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -21,9 +21,11 @@ if util.TYPE_CHECKING: from types import ModuleType elements = None # type: ModuleType +lambdas = None # type: ModuleType schema = None # type: ModuleType selectable = None # type: ModuleType sqltypes = None # type: ModuleType +traversals = None # type: ModuleType def _is_literal(element): @@ -51,6 +53,23 @@ def _document_text_coercion(paramname, meth_rst, param_rst): def expect(role, element, apply_propagate_attrs=None, argname=None, **kw): + if ( + role.allows_lambda + # note callable() will not invoke a __getattr__() method, whereas + # hasattr(obj, "__call__") will. by keeping the callable() check here + # we prevent most needless calls to hasattr() and therefore + # __getattr__(), which is present on ColumnElement. 
+ and callable(element) + and hasattr(element, "__code__") + ): + return lambdas.LambdaElement( + element, + role, + apply_propagate_attrs=apply_propagate_attrs, + argname=argname, + **kw + ) + # major case is that we are given a ClauseElement already, skip more # elaborate logic up front if possible impl = _impl_lookup[role] @@ -106,7 +125,12 @@ def expect(role, element, apply_propagate_attrs=None, argname=None, **kw): if impl._role_class in resolved.__class__.__mro__: if impl._post_coercion: - resolved = impl._post_coercion(resolved, argname=argname, **kw) + resolved = impl._post_coercion( + resolved, + argname=argname, + original_element=original_element, + **kw + ) return resolved else: return impl._implicit_coercions( @@ -230,6 +254,8 @@ class _ColumnCoercions(object): ): self._warn_for_scalar_subquery_coercion() return resolved.element.scalar_subquery() + elif self._role_class.allows_lambda and resolved._is_lambda_element: + return resolved else: self._raise_for_expected(original_element, argname, resolved) @@ -319,6 +345,21 @@ class _SelectIsNotFrom(object): ) +class HasCacheKeyImpl(RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(original_element, traversals.HasCacheKey): + return original_element + else: + self._raise_for_expected(original_element, argname, resolved) + + def _literal_coercion(self, element, **kw): + return element + + class ExpressionElementImpl(_ColumnCoercions, RoleImpl): __slots__ = () @@ -420,7 +461,14 @@ class InElementImpl(RoleImpl): assert not len(element.clauses) == 0 return element.self_group(against=operator) - elif isinstance(element, elements.BindParameter) and element.expanding: + elif isinstance(element, elements.BindParameter): + if not element.expanding: + # coercing to expanding at the moment to work with the + # lambda system. not sure if this is the right approach. 
+ # is there a valid use case to send a single non-expanding + # param to IN? check for ARRAY type? + element = element._clone(maintain_key=True) + element.expanding = True if isinstance(expr, elements.Tuple): element = element._with_expanding_in_types( [elem.type for elem in expr] @@ -431,6 +479,22 @@ class InElementImpl(RoleImpl): return element +class OnClauseImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl): + __slots__ = () + + _coerce_consts = True + + def _post_coercion(self, resolved, original_element=None, **kw): + # this is a hack right now as we want to use coercion on an + # ORM InstrumentedAttribute, but we want to return the object + # itself if it is one, not its clause element. + # ORM context _join and _legacy_join() would need to be improved + # to look for annotations in a clause element form. + if isinstance(original_element, roles.JoinTargetRole): + return original_element + return resolved + + class WhereHavingImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl): __slots__ = () @@ -635,6 +699,24 @@ class StatementImpl(_NoTextCoercion, RoleImpl): class CoerceTextStatementImpl(_CoerceLiterals, RoleImpl): __slots__ = () + def _literal_coercion(self, element, **kw): + if callable(element) and hasattr(element, "__code__"): + return lambdas.StatementLambdaElement(element, self._role_class) + else: + return super(CoerceTextStatementImpl, self)._literal_coercion( + element, **kw + ) + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_lambda_element: + return resolved + else: + return super(CoerceTextStatementImpl, self)._implicit_coercions( + original_element, resolved, argname=argname, **kw + ) + def _text_coercion(self, element, argname=None): # TODO: this should emit deprecation warning, # see deprecation warning in engine/base.py execute() diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6152a28e7..3a3ce5c45 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ 
b/lib/sqlalchemy/sql/compiler.py @@ -1296,6 +1296,10 @@ class SQLCompiler(Compiled): "Cannot compile Column object until " "its 'name' is assigned." ) + def visit_lambda_element(self, element, **kw): + sql_element = element._resolved + return self.process(sql_element, **kw) + def visit_column( self, column, @@ -1624,7 +1628,7 @@ class SQLCompiler(Compiled): return func.clause_expr._compiler_dispatch(self, **kwargs) def visit_compound_select( - self, cs, asfrom=False, compound_index=0, **kwargs + self, cs, asfrom=False, compound_index=None, **kwargs ): toplevel = not self.stack @@ -1635,10 +1639,14 @@ class SQLCompiler(Compiled): entry = self._default_stack_entry if toplevel else self.stack[-1] need_result_map = toplevel or ( - compound_index == 0 + not compound_index and entry.get("need_result_map_for_compound", False) ) + # indicates there is already a CompoundSelect in play + if compound_index == 0: + entry["select_0"] = cs + self.stack.append( { "correlate_froms": entry["correlate_froms"], @@ -2654,7 +2662,7 @@ class SQLCompiler(Compiled): select_stmt, asfrom=False, fromhints=None, - compound_index=0, + compound_index=None, select_wraps_for=None, lateral=False, from_linter=None, @@ -2709,7 +2717,9 @@ class SQLCompiler(Compiled): or entry.get("need_result_map_for_nested", False) ) - if compound_index > 0: + # indicates there is a CompoundSelect in play and we are not the + # first select + if compound_index: populate_result_map = False # this was first proposed as part of #3372; however, it is not @@ -2844,11 +2854,10 @@ class SQLCompiler(Compiled): correlate_froms = entry["correlate_froms"] asfrom_froms = entry["asfrom_froms"] - if compound_index > 0: - # note this is cached - select_0 = entry["selectable"].selects[0] - if select_0._is_select_container: - select_0 = select_0.element + if compound_index == 0: + entry["select_0"] = select + elif compound_index: + select_0 = entry["select_0"] numcols = len(select_0.selected_columns) if 
len(compile_state.columns_plus_names) != numcols: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index af5eab257..6ce505412 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -215,6 +215,7 @@ class ClauseElement( _is_select_statement = False _is_bind_parameter = False _is_clause_list = False + _is_lambda_element = False _order_by_label_element = None @@ -1337,9 +1338,6 @@ class BindParameter(roles.InElementRole, ColumnElement): :ref:`change_4808`. - - - """ if required is NO_ARG: @@ -1406,15 +1404,15 @@ class BindParameter(roles.InElementRole, ColumnElement): the context of an expanding IN against a tuple. """ - cloned = self._clone() + cloned = self._clone(maintain_key=True) cloned._expanding_in_types = types return cloned - def _with_value(self, value): + def _with_value(self, value, maintain_key=False): """Return a copy of this :class:`.BindParameter` with the given value set. """ - cloned = self._clone() + cloned = self._clone(maintain_key=maintain_key) cloned.value = value cloned.callable = None cloned.required = False @@ -1442,9 +1440,9 @@ class BindParameter(roles.InElementRole, ColumnElement): c.type = type_ return c - def _clone(self): + def _clone(self, maintain_key=False): c = ClauseElement._clone(self) - if self.unique: + if not maintain_key and self.unique: c.key = _anonymous_label( "%%(%d %s)s" % (id(c), c._orig_key or "param") ) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index e25063372..37441a125 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -29,6 +29,8 @@ __all__ = [ "Insert", "Join", "Lateral", + "LambdaElement", + "StatementLambdaElement", "Select", "Selectable", "TableClause", @@ -59,6 +61,7 @@ __all__ = [ "join", "label", "lateral", + "lambda_stmt", "literal", "literal_column", "not_", @@ -135,6 +138,9 @@ from .functions import func # noqa from .functions import Function # noqa from .functions import 
FunctionElement # noqa from .functions import modifier # noqa +from .lambdas import lambda_stmt # noqa +from .lambdas import LambdaElement # noqa +from .lambdas import StatementLambdaElement # noqa from .selectable import Alias # noqa from .selectable import AliasedReturnsRows # noqa from .selectable import CompoundSelect # noqa diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 6fff26842..c1b8bbd27 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -614,7 +614,7 @@ class Function(FunctionElement): new :class:`.Function` instances. """ - self.packagenames = kw.pop("packagenames", None) or [] + self.packagenames = kw.pop("packagenames", None) or () self.name = name self._bind = kw.get("bind", None) self.type = sqltypes.to_instance(kw.get("type_", None)) @@ -759,7 +759,7 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)): for c in args ] self._has_args = self._has_args or bool(parsed_args) - self.packagenames = [] + self.packagenames = () self._bind = kwargs.get("bind", None) self.clause_expr = ClauseList( operator=operators.comma_op, group_contents=True, *parsed_args diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py new file mode 100644 index 000000000..792411189 --- /dev/null +++ b/lib/sqlalchemy/sql/lambdas.py @@ -0,0 +1,607 @@ +# sql/lambdas.py +# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import itertools +import operator +import sys +import weakref + +from . import coercions +from . import elements +from . import roles +from . import schema +from . import traversals +from . import type_api +from . import visitors +from .operators import ColumnOperators +from .. import exc +from .. import inspection +from .. 
import util +from ..util import collections_abc + +_trackers = weakref.WeakKeyDictionary() + + +_TRACKERS = 0 +_STALE_CHECK = 1 +_REAL_FN = 2 +_EXPR = 3 +_IS_SEQUENCE = 4 +_PROPAGATE_ATTRS = 5 + + +def lambda_stmt(lmb): + """Produce a SQL statement that is cached as a lambda. + + This SQL statement will only be constructed if element has not been + compiled yet. The approach is used to save on Python function overhead + when constructing statements that will be cached. + + E.g.:: + + from sqlalchemy import lambda_stmt + + stmt = lambda_stmt(lambda: table.select()) + stmt += lambda s: s.where(table.c.id == 5) + + result = connection.execute(stmt) + + The object returned is an instance of :class:`_sql.StatementLambdaElement`. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`engine_lambda_caching` + + + """ + return coercions.expect(roles.CoerceTextStatementRole, lmb) + + +class LambdaElement(elements.ClauseElement): + """A SQL construct where the state is stored as an un-invoked lambda. + + The :class:`_sql.LambdaElement` is produced transparently whenever + passing lambda expressions into SQL constructs, such as:: + + stmt = select(table).where(lambda: table.c.col == parameter) + + The :class:`_sql.LambdaElement` is the base of the + :class:`_sql.StatementLambdaElement` which represents a full statement + within a lambda. + + .. versionadded:: 1.4 + + .. 
seealso:: + + :ref:`engine_lambda_caching` + + """ + + __visit_name__ = "lambda_element" + + _is_lambda_element = True + + _resolved_bindparams = () + + _traverse_internals = [ + ("_resolved", visitors.InternalTraversal.dp_clauseelement) + ] + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self.fn.__code__) + + def __init__(self, fn, role, apply_propagate_attrs=None, **kw): + self.fn = fn + self.role = role + self.parent_lambda = None + + if apply_propagate_attrs is None and ( + role is roles.CoerceTextStatementRole + ): + apply_propagate_attrs = self + + if fn.__code__ not in _trackers: + rec = self._initialize_var_trackers( + role, apply_propagate_attrs, kw + ) + else: + rec = _trackers[self.fn.__code__] + closure = fn.__closure__ + + # check if the objects fixed inside the lambda that we've cached + # have been changed. This can apply to things like mappers that + # were recreated in test suites. if so, re-initialize. + # + # this is a small performance hit on every use for a not very + # common situation, however it's very hard to debug if the + # condition does occur. 
+ for idx, obj in rec[_STALE_CHECK]: + if closure[idx].cell_contents is not obj: + rec = self._initialize_var_trackers( + role, apply_propagate_attrs, kw + ) + break + self._rec = rec + + if apply_propagate_attrs is not None: + propagate_attrs = rec[_PROPAGATE_ATTRS] + if propagate_attrs: + apply_propagate_attrs._propagate_attrs = propagate_attrs + + if rec[_TRACKERS]: + self._resolved_bindparams = bindparams = [] + for tracker in rec[_TRACKERS]: + tracker(self.fn, bindparams) + + def __getattr__(self, key): + return getattr(self._rec[_EXPR], key) + + @property + def _is_sequence(self): + return self._rec[_IS_SEQUENCE] + + @property + def _select_iterable(self): + if self._is_sequence: + return itertools.chain.from_iterable( + [element._select_iterable for element in self._resolved] + ) + + else: + return self._resolved._select_iterable + + @property + def _from_objects(self): + if self._is_sequence: + return itertools.chain.from_iterable( + [element._from_objects for element in self._resolved] + ) + + else: + return self._resolved._from_objects + + def _param_dict(self): + return {b.key: b.value for b in self._resolved_bindparams} + + @util.memoized_property + def _resolved(self): + bindparam_lookup = {b.key: b for b in self._resolved_bindparams} + + def replace(thing): + if ( + isinstance(thing, elements.BindParameter) + and thing.key in bindparam_lookup + ): + bind = bindparam_lookup[thing.key] + # TODO: consider + # if we should clone the bindparam here, re-cache the new + # version, etc. also we make an assumption about "expanding" + # in this case. 
+ if thing.expanding: + bind.expanding = True + return bind + + expr = self._rec[_EXPR] + + if self._rec[_IS_SEQUENCE]: + expr = [ + visitors.replacement_traverse(sub_expr, {}, replace) + for sub_expr in expr + ] + elif getattr(expr, "is_clause_element", False): + expr = visitors.replacement_traverse(expr, {}, replace) + + return expr + + def _gen_cache_key(self, anon_map, bindparams): + + cache_key = (self.fn.__code__, self.__class__) + + if self._resolved_bindparams: + bindparams.extend(self._resolved_bindparams) + + return cache_key + + def _invoke_user_fn(self, fn, *arg): + return fn() + + def _initialize_var_trackers(self, role, apply_propagate_attrs, coerce_kw): + fn = self.fn + + # track objects referenced inside of lambdas, create bindparams + # ahead of time for literal values. If bindparams are produced, + # then rewrite the function globals and closure as necessary so that + # it refers to the bindparams, then invoke the function + new_closure = {} + new_globals = fn.__globals__.copy() + tracker_collection = [] + check_closure_for_stale = [] + + for name in fn.__code__.co_names: + if name not in new_globals: + continue + + bound_value = _roll_down_to_literal(new_globals[name]) + + if coercions._is_literal(bound_value): + new_globals[name] = bind = PyWrapper(name, bound_value) + tracker_collection.append(_globals_tracker(name, bind)) + + if fn.__closure__: + for closure_index, (fv, cell) in enumerate( + zip(fn.__code__.co_freevars, fn.__closure__) + ): + + bound_value = _roll_down_to_literal(cell.cell_contents) + + if coercions._is_literal(bound_value): + new_closure[fv] = bind = PyWrapper(fv, bound_value) + tracker_collection.append( + _closure_tracker(fv, bind, closure_index) + ) + else: + new_closure[fv] = cell.cell_contents + # for normal cell contents, add them to a list that + # we can compare later when we get new lambdas. if + # any identities have changed, then we will recalculate + # the whole lambda and run it again. 
+ check_closure_for_stale.append( + (closure_index, cell.cell_contents) + ) + + if tracker_collection: + new_fn = _rewrite_code_obj( + fn, + [new_closure[name] for name in fn.__code__.co_freevars], + new_globals, + ) + expr = self._invoke_user_fn(new_fn) + + else: + new_fn = fn + expr = self._invoke_user_fn(new_fn) + tracker_collection = [] + + if self.parent_lambda is None: + if isinstance(expr, collections_abc.Sequence): + expected_expr = [ + coercions.expect( + role, + sub_expr, + apply_propagate_attrs=apply_propagate_attrs, + **coerce_kw + ) + for sub_expr in expr + ] + is_sequence = True + else: + expected_expr = coercions.expect( + role, + expr, + apply_propagate_attrs=apply_propagate_attrs, + **coerce_kw + ) + is_sequence = False + else: + expected_expr = expr + is_sequence = False + + if apply_propagate_attrs is not None: + propagate_attrs = apply_propagate_attrs._propagate_attrs + else: + propagate_attrs = util.immutabledict() + + rec = _trackers[self.fn.__code__] = ( + tracker_collection, + check_closure_for_stale, + new_fn, + expected_expr, + is_sequence, + propagate_attrs, + ) + return rec + + +class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): + """Represent a composable SQL statement as a :class:`_sql.LambdaElement`. + + The :class:`_sql.StatementLambdaElement` is constructed using the + :func:`_sql.lambda_stmt` function:: + + + from sqlalchemy import lambda_stmt + + stmt = lambda_stmt(lambda: select(table)) + + Once constructed, additional criteria can be built onto the statement + by adding subsequent lambdas, which accept the existing statement + object as a single parameter:: + + stmt += lambda s: s.where(table.c.col == parameter) + + + .. versionadded:: 1.4 + + .. 
seealso:: + + :ref:`engine_lambda_caching` + + """ + + def __add__(self, other): + return LinkedLambdaElement(other, parent_lambda=self) + + def _execute_on_connection( + self, connection, multiparams, params, execution_options + ): + if self._rec[_EXPR].supports_execution: + return connection._execute_clauseelement( + self, multiparams, params, execution_options + ) + else: + raise exc.ObjectNotExecutableError(self) + + @property + def _with_options(self): + return self._rec[_EXPR]._with_options + + @property + def _effective_plugin_target(self): + return self._rec[_EXPR]._effective_plugin_target + + @property + def _is_future(self): + return self._rec[_EXPR]._is_future + + @property + def _execution_options(self): + return self._rec[_EXPR]._execution_options + + +class LinkedLambdaElement(StatementLambdaElement): + def __init__(self, fn, parent_lambda, **kw): + self.fn = fn + self.parent_lambda = parent_lambda + role = None + + apply_propagate_attrs = self + + if fn.__code__ not in _trackers: + rec = self._initialize_var_trackers( + role, apply_propagate_attrs, kw + ) + else: + rec = _trackers[self.fn.__code__] + + closure = fn.__closure__ + + # check if objects referred to by the lambda have changed and + # re-scan the lambda if so. see comments for this same section in + # LambdaElement. 
+ for idx, obj in rec[_STALE_CHECK]: + if closure[idx].cell_contents is not obj: + rec = self._initialize_var_trackers( + role, apply_propagate_attrs, kw + ) + break + + self._rec = rec + + self._propagate_attrs = parent_lambda._propagate_attrs + + self._resolved_bindparams = bindparams = [] + rec = self._rec + while True: + if rec[_TRACKERS]: + for tracker in rec[_TRACKERS]: + tracker(self.fn, bindparams) + if self.parent_lambda is not None: + self = self.parent_lambda + rec = self._rec + else: + break + + def _invoke_user_fn(self, fn, *arg): + return fn(self.parent_lambda._rec[_EXPR]) + + def _gen_cache_key(self, anon_map, bindparams): + if self._resolved_bindparams: + bindparams.extend(self._resolved_bindparams) + + cache_key = (self.fn.__code__, self.__class__) + + parent = self.parent_lambda + while parent is not None: + cache_key = (parent.fn.__code__,) + cache_key + parent = parent.parent_lambda + + return cache_key + + +class PyWrapper(ColumnOperators): + def __init__(self, name, to_evaluate, getter=None): + self._name = name + self._to_evaluate = to_evaluate + self._param = None + self._bind_paths = {} + self._getter = getter + + def __call__(self, *arg, **kw): + elem = object.__getattribute__(self, "_to_evaluate") + value = elem(*arg, **kw) + if coercions._is_literal(value) and not isinstance( + # TODO: coverage where an ORM option or similar is here + value, + traversals.HasCacheKey, + ): + # TODO: we can instead scan the arguments and make sure they + # are all Python literals + + # TODO: coverage + name = object.__getattribute__(self, "_name") + raise exc.InvalidRequestError( + "Can't invoke Python callable %s() inside of lambda " + "expression argument; lambda cache keys should not call " + "regular functions since the caching " + "system does not track the values of the arguments passed " + "to the functions. Call the function outside of the lambda " + "and assign to a local variable that is used in the lambda." 
+ % (name) + ) + else: + return value + + def operate(self, op, *other, **kwargs): + elem = object.__getattribute__(self, "__clause_element__")() + return op(elem, *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + elem = object.__getattribute__(self, "__clause_element__")() + return op(other, elem, **kwargs) + + def _extract_bound_parameters(self, starting_point, result_list): + param = object.__getattribute__(self, "_param") + if param is not None: + param = param._with_value(starting_point, maintain_key=True) + result_list.append(param) + for pywrapper in object.__getattribute__(self, "_bind_paths").values(): + getter = object.__getattribute__(pywrapper, "_getter") + element = getter(starting_point) + pywrapper._sa__extract_bound_parameters(element, result_list) + + def __clause_element__(self): + param = object.__getattribute__(self, "_param") + to_evaluate = object.__getattribute__(self, "_to_evaluate") + if param is None: + name = object.__getattribute__(self, "_name") + self._param = param = elements.BindParameter(name, unique=True) + param.type = type_api._resolve_value_to_type(to_evaluate) + + return param._with_value(to_evaluate, maintain_key=True) + + def __getattribute__(self, key): + if key.startswith("_sa_"): + return object.__getattribute__(self, key[4:]) + elif key in ("__clause_element__", "operate", "reverse_operate"): + return object.__getattribute__(self, key) + + if key.startswith("__"): + elem = object.__getattribute__(self, "_to_evaluate") + return getattr(elem, key) + else: + return self._sa__add_getter(key, operator.attrgetter) + + def __getitem__(self, key): + if isinstance(key, PyWrapper): + # TODO: coverage + raise exc.InvalidRequestError( + "Dictionary keys / list indexes inside of a cached " + "lambda must be Python literals only" + ) + return self._sa__add_getter(key, operator.itemgetter) + + def _add_getter(self, key, getter_fn): + + bind_paths = object.__getattribute__(self, "_bind_paths") + + bind_path_key = 
(key, getter_fn) + if bind_path_key in bind_paths: + return bind_paths[bind_path_key] + + getter = getter_fn(key) + elem = object.__getattribute__(self, "_to_evaluate") + value = getter(elem) + + if coercions._is_literal(value): + wrapper = PyWrapper(key, value, getter) + bind_paths[bind_path_key] = wrapper + return wrapper + else: + return value + + +def _roll_down_to_literal(element): + is_clause_element = hasattr(element, "__clause_element__") + + if is_clause_element: + while not isinstance( + element, (elements.ClauseElement, schema.SchemaItem) + ): + try: + element = element.__clause_element__() + except AttributeError: + break + + if not is_clause_element: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + try: + return insp.__clause_element__() + except AttributeError: + return insp + + # TODO: should we coerce consts None/True/False here? + return element + else: + return element + + +def _globals_tracker(name, wrapper): + def extract_parameter_value(current_fn, result): + object.__getattribute__(wrapper, "_extract_bound_parameters")( + current_fn.__globals__[name], result + ) + + return extract_parameter_value + + +def _closure_tracker(name, wrapper, closure_index): + def extract_parameter_value(current_fn, result): + object.__getattribute__(wrapper, "_extract_bound_parameters")( + current_fn.__closure__[closure_index].cell_contents, result + ) + + return extract_parameter_value + + +def _rewrite_code_obj(f, cell_values, globals_): + """Return a copy of f, with a new closure and new globals + + yes it works in pypy :P + + """ + + argrange = range(len(cell_values)) + + code = "def make_cells():\n" + if cell_values: + code += " (%s) = (%s)\n" % ( + ", ".join("i%d" % i for i in argrange), + ", ".join("o%d" % i for i in argrange), + ) + code += " def closure():\n" + code += " return %s\n" % ", ".join("i%d" % i for i in argrange) + code += " return closure.__closure__" + vars_ = {"o%d" % i: cell_values[i] for i in argrange} + 
exec(code, vars_, vars_) + closure = vars_["make_cells"]() + + func = type(f)(f.__code__, globals_, f.__name__, f.__defaults__, closure) + if sys.version_info >= (3,): + func.__annotations__ = f.__annotations__ + func.__kwdefaults__ = f.__kwdefaults__ + func.__doc__ = f.__doc__ + func.__module__ = f.__module__ + + return func + + +@inspection._inspects(LambdaElement) +def insp(lmb): + return inspection.inspect(lmb._resolved) diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 3d94ec9ff..4205d9f0d 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -19,9 +19,21 @@ class SQLRole(object): """ + allows_lambda = False + uses_inspection = False + class UsesInspection(object): _post_inspect = None + uses_inspection = True + + +class AllowsLambdaRole(object): + allows_lambda = True + + +class HasCacheKeyRole(SQLRole): + _role_name = "Cacheable Core or ORM object" class ColumnArgumentRole(SQLRole): @@ -40,7 +52,7 @@ class TruncatedLabelRole(SQLRole): _role_name = "String SQL identifier" -class ColumnsClauseRole(UsesInspection, ColumnListRole): +class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole): _role_name = "Column expression or FROM clause" @property @@ -56,7 +68,7 @@ class ByOfRole(ColumnListRole): _role_name = "GROUP BY / OF / etc. expression" -class GroupByRole(UsesInspection, ByOfRole): +class GroupByRole(AllowsLambdaRole, UsesInspection, ByOfRole): # note there's a special case right now where you can pass a whole # ORM entity to group_by() and it splits out. 
we may not want to keep # this around @@ -64,7 +76,7 @@ class GroupByRole(UsesInspection, ByOfRole): _role_name = "GROUP BY expression" -class OrderByRole(ByOfRole): +class OrderByRole(AllowsLambdaRole, ByOfRole): _role_name = "ORDER BY expression" @@ -76,7 +88,11 @@ class StatementOptionRole(StructuralRole): _role_name = "statement sub-expression element" -class WhereHavingRole(StructuralRole): +class OnClauseRole(AllowsLambdaRole, StructuralRole): + _role_name = "SQL expression for ON clause" + + +class WhereHavingRole(OnClauseRole): _role_name = "SQL expression for WHERE/HAVING role" @@ -102,7 +118,7 @@ class InElementRole(SQLRole): ) -class JoinTargetRole(UsesInspection, StructuralRole): +class JoinTargetRole(AllowsLambdaRole, UsesInspection, StructuralRole): _role_name = ( "Join target, typically a FROM expression, or ORM " "relationship attribute" @@ -176,7 +192,7 @@ class HasCTERole(ReturnsRowsRole): pass -class CompoundElementRole(SQLRole): +class CompoundElementRole(AllowsLambdaRole, SQLRole): """SELECT statements inside a CompoundSelect, e.g. UNION, EXTRACT, etc.""" _role_name = ( diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 59c292a07..832da1a57 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -847,7 +847,7 @@ class Join(roles.DMLTableRole, FromClause): # note: taken from If91f61527236fd4d7ae3cad1f24c38be921c90ba # not merged yet self.onclause = coercions.expect( - roles.WhereHavingRole, onclause + roles.OnClauseRole, onclause ).self_group(against=operators._asbool) self.isouter = isouter diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 8d01b7ff7..f41480a94 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -115,45 +115,37 @@ class HasCacheKey(object): in the structures that would affect the SQL string or the type handlers should result in a different cache key. 
- If a structure cannot produce a useful cache key, it should raise - NotImplementedError, which will result in the entire structure - for which it's part of not being useful as a cache key. - + If a structure cannot produce a useful cache key, the NO_CACHE + symbol should be added to the anon_map and the method should + return None. """ - elements = util.preloaded.sql_elements - idself = id(self) + cls = self.__class__ - if anon_map is not None: - if idself in anon_map: - return (anon_map[idself], self.__class__) - else: - # inline of - # id_ = anon_map[idself] - anon_map[idself] = id_ = str(anon_map.index) - anon_map.index += 1 + if idself in anon_map: + return (anon_map[idself], cls) else: - id_ = None + # inline of + # id_ = anon_map[idself] + anon_map[idself] = id_ = str(anon_map.index) + anon_map.index += 1 try: - dispatcher = self.__class__.__dict__[ - "_generated_cache_key_traversal" - ] + dispatcher = cls.__dict__["_generated_cache_key_traversal"] except KeyError: # most of the dispatchers are generated up front # in sqlalchemy/sql/__init__.py -> # traversals.py-> _preconfigure_traversals(). # this block will generate any remaining dispatchers. - dispatcher = self.__class__._generate_cache_attrs() + dispatcher = cls._generate_cache_attrs() if dispatcher is NO_CACHE: - if anon_map is not None: - anon_map[NO_CACHE] = True + anon_map[NO_CACHE] = True return None - result = (id_, self.__class__) + result = (id_, cls) # inline of _cache_key_traversal_visitor.run_generated_dispatch() @@ -163,15 +155,12 @@ class HasCacheKey(object): if obj is not None: # TODO: see if C code can help here as Python lacks an # efficient switch construct - if meth is CACHE_IN_PLACE: - # cache in place is always going to be a Python - # tuple, dict, list, etc. 
so we can do a boolean check - if obj: - result += (attrname, obj) - elif meth is STATIC_CACHE_KEY: + + if meth is STATIC_CACHE_KEY: result += (attrname, obj._static_cache_key) elif meth is ANON_NAME: - if elements._anonymous_label in obj.__class__.__mro__: + elements = util.preloaded.sql_elements + if isinstance(obj, elements._anonymous_label): obj = obj.apply_map(anon_map) result += (attrname, obj) elif meth is CALL_GEN_CACHE_KEY: @@ -179,8 +168,14 @@ class HasCacheKey(object): attrname, obj._gen_cache_key(anon_map, bindparams), ) - elif meth is PROPAGATE_ATTRS: - if obj: + + # remaining cache functions are against + # Python tuples, dicts, lists, etc. so we can skip + # if they are empty + elif obj: + if meth is CACHE_IN_PLACE: + result += (attrname, obj) + elif meth is PROPAGATE_ATTRS: result += ( attrname, obj["compile_state_plugin"], @@ -188,16 +183,14 @@ class HasCacheKey(object): anon_map, bindparams ), ) - elif meth is InternalTraversal.dp_annotations_key: - # obj is here is the _annotations dict. however, - # we want to use the memoized cache key version of it. - # for Columns, this should be long lived. For select() - # statements, not so much, but they usually won't have - # annotations. - if obj: + elif meth is InternalTraversal.dp_annotations_key: + # obj is here is the _annotations dict. however, we + # want to use the memoized cache key version of it. for + # Columns, this should be long lived. For select() + # statements, not so much, but they usually won't have + # annotations. result += self._annotations_cache_key - elif meth is InternalTraversal.dp_clauseelement_list: - if obj: + elif meth is InternalTraversal.dp_clauseelement_list: result += ( attrname, tuple( @@ -207,14 +200,7 @@ class HasCacheKey(object): ] ), ) - else: - # note that all the "ClauseElement" standalone cases - # here have been handled by inlines above; so we can - # safely assume the object is a standard list/tuple/dict - # which we can skip if it evaluates to false. 
- # improvement would be to have this as a flag delivered - # up front in the dispatcher list - if obj: + else: result += meth( attrname, obj, self, anon_map, bindparams ) @@ -384,6 +370,14 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): return "CacheKey(key=%s)" % ("\n".join(output),) + def _generate_param_dict(self): + """used for testing""" + + from .compiler import prefix_anon_map + + _anon_map = prefix_anon_map() + return {b.key % _anon_map: b.effective_value for b in self.bindparams} + def _clone(element, **kw): return element._clone() @@ -506,6 +500,7 @@ class _CacheKey(ExtendedInternalTraversal): ): if not obj: return () + return ( attrname, tuple( |