summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-12-16 17:06:43 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2020-07-03 23:39:51 -0400
commit3dc9a4a2392d033f9d1bd79dd6b6ecea6281a61c (patch)
tree1041bccb37422f526dccb5b1e57ffad1c702549b /lib/sqlalchemy/sql
parent5060043e8e95ab0aab5f63ed288c1426c46da66e (diff)
downloadsqlalchemy-3dc9a4a2392d033f9d1bd79dd6b6ecea6281a61c.tar.gz
introduce deferred lambdas
The coercions system allows us to add in lambdas as arguments to Core and ORM elements without changing them at all. Allowing the lambda to produce a deterministic cache key, from which we can also cheat and yank out literal parameters, means we can move towards having 90% of "baked" functionality in a clearer way right in Core / ORM. As a second step, we can have whole statements inside the lambda, and can then add generation with __add__(), so that we have 100% of "baked" functionality with full support of ad-hoc literal values. Adds some more short_selects tests for the moment for comparison. Other tweaks are made inside cache key generation, as we're trying to approach a level of performance at which we can remove the use of "baked" from the loader strategies. Although we have not yet closed #4639, the caching feature has been fully integrated as of b0cfa7379cf8513a821a3dbe3028c4965d9f85bd, so we will also add complete caching documentation here and close that issue as well. Closes: #4639 Fixes: #5380 Change-Id: If91f61527236fd4d7ae3cad1f24c38be921c90ba
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/__init__.py7
-rw-r--r--lib/sqlalchemy/sql/base.py27
-rw-r--r--lib/sqlalchemy/sql/coercions.py86
-rw-r--r--lib/sqlalchemy/sql/compiler.py27
-rw-r--r--lib/sqlalchemy/sql/elements.py14
-rw-r--r--lib/sqlalchemy/sql/expression.py6
-rw-r--r--lib/sqlalchemy/sql/functions.py4
-rw-r--r--lib/sqlalchemy/sql/lambdas.py607
-rw-r--r--lib/sqlalchemy/sql/roles.py28
-rw-r--r--lib/sqlalchemy/sql/selectable.py2
-rw-r--r--lib/sqlalchemy/sql/traversals.py91
11 files changed, 817 insertions, 82 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index a25c1b083..2fe6f35d2 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -46,6 +46,8 @@ from .expression import intersect_all # noqa
from .expression import Join # noqa
from .expression import join # noqa
from .expression import label # noqa
+from .expression import lambda_stmt # noqa
+from .expression import LambdaElement # noqa
from .expression import lateral # noqa
from .expression import literal # noqa
from .expression import literal_column # noqa
@@ -62,6 +64,7 @@ from .expression import quoted_name # noqa
from .expression import Select # noqa
from .expression import select # noqa
from .expression import Selectable # noqa
+from .expression import StatementLambdaElement # noqa
from .expression import Subquery # noqa
from .expression import subquery # noqa
from .expression import table # noqa
@@ -106,18 +109,22 @@ def __go(lcls):
from . import coercions
from . import elements
from . import events # noqa
+ from . import lambdas
from . import selectable
from . import schema
from . import sqltypes
+ from . import traversals
from . import type_api
base.coercions = elements.coercions = coercions
base.elements = elements
base.type_api = type_api
coercions.elements = elements
+ coercions.lambdas = lambdas
coercions.schema = schema
coercions.selectable = selectable
coercions.sqltypes = sqltypes
+ coercions.traversals = traversals
_prepare_annotations(ColumnElement, AnnotatedColumnElement)
_prepare_annotations(FromClause, AnnotatedFromClause)
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 9dcd7dca9..6cdab8eac 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -14,6 +14,7 @@ import itertools
import operator
import re
+from . import roles
from .traversals import HasCacheKey # noqa
from .traversals import MemoizedHasCacheKey # noqa
from .visitors import ClauseVisitor
@@ -447,13 +448,17 @@ class CompileState(object):
"compile_state_plugin", "default"
)
klass = cls.plugins.get(
- (plugin_name, statement.__visit_name__), None
+ (plugin_name, statement._effective_plugin_target), None
)
if klass is None:
- klass = cls.plugins[("default", statement.__visit_name__)]
+ klass = cls.plugins[
+ ("default", statement._effective_plugin_target)
+ ]
else:
- klass = cls.plugins[("default", statement.__visit_name__)]
+ klass = cls.plugins[
+ ("default", statement._effective_plugin_target)
+ ]
if klass is cls:
return cls(statement, compiler, **kw)
@@ -469,14 +474,18 @@ class CompileState(object):
"compile_state_plugin", "default"
)
try:
- return cls.plugins[(plugin_name, statement.__visit_name__)]
+ return cls.plugins[
+ (plugin_name, statement._effective_plugin_target)
+ ]
except KeyError:
return None
@classmethod
def _get_plugin_class_for_plugin(cls, statement, plugin_name):
try:
- return cls.plugins[(plugin_name, statement.__visit_name__)]
+ return cls.plugins[
+ (plugin_name, statement._effective_plugin_target)
+ ]
except KeyError:
return None
@@ -637,6 +646,10 @@ class Executable(Generative):
("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs),
]
+ @property
+ def _effective_plugin_target(self):
+ return self.__visit_name__
+
@_generative
def options(self, *options):
"""Apply options to this statement.
@@ -667,7 +680,9 @@ class Executable(Generative):
to the usage of ORM queries
"""
- self._with_options += options
+ self._with_options += tuple(
+ coercions.expect(roles.HasCacheKeyRole, opt) for opt in options
+ )
@_generative
def _add_context_option(self, callable_, cache_args):
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index 4c6a0317a..be412c770 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -21,9 +21,11 @@ if util.TYPE_CHECKING:
from types import ModuleType
elements = None # type: ModuleType
+lambdas = None # type: ModuleType
schema = None # type: ModuleType
selectable = None # type: ModuleType
sqltypes = None # type: ModuleType
+traversals = None # type: ModuleType
def _is_literal(element):
@@ -51,6 +53,23 @@ def _document_text_coercion(paramname, meth_rst, param_rst):
def expect(role, element, apply_propagate_attrs=None, argname=None, **kw):
+ if (
+ role.allows_lambda
+ # note callable() will not invoke a __getattr__() method, whereas
+ # hasattr(obj, "__call__") will. by keeping the callable() check here
+ # we prevent most needless calls to hasattr() and therefore
+ # __getattr__(), which is present on ColumnElement.
+ and callable(element)
+ and hasattr(element, "__code__")
+ ):
+ return lambdas.LambdaElement(
+ element,
+ role,
+ apply_propagate_attrs=apply_propagate_attrs,
+ argname=argname,
+ **kw
+ )
+
# major case is that we are given a ClauseElement already, skip more
# elaborate logic up front if possible
impl = _impl_lookup[role]
@@ -106,7 +125,12 @@ def expect(role, element, apply_propagate_attrs=None, argname=None, **kw):
if impl._role_class in resolved.__class__.__mro__:
if impl._post_coercion:
- resolved = impl._post_coercion(resolved, argname=argname, **kw)
+ resolved = impl._post_coercion(
+ resolved,
+ argname=argname,
+ original_element=original_element,
+ **kw
+ )
return resolved
else:
return impl._implicit_coercions(
@@ -230,6 +254,8 @@ class _ColumnCoercions(object):
):
self._warn_for_scalar_subquery_coercion()
return resolved.element.scalar_subquery()
+ elif self._role_class.allows_lambda and resolved._is_lambda_element:
+ return resolved
else:
self._raise_for_expected(original_element, argname, resolved)
@@ -319,6 +345,21 @@ class _SelectIsNotFrom(object):
)
+class HasCacheKeyImpl(RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(original_element, traversals.HasCacheKey):
+ return original_element
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+ def _literal_coercion(self, element, **kw):
+ return element
+
+
class ExpressionElementImpl(_ColumnCoercions, RoleImpl):
__slots__ = ()
@@ -420,7 +461,14 @@ class InElementImpl(RoleImpl):
assert not len(element.clauses) == 0
return element.self_group(against=operator)
- elif isinstance(element, elements.BindParameter) and element.expanding:
+ elif isinstance(element, elements.BindParameter):
+ if not element.expanding:
+ # coercing to expanding at the moment to work with the
+ # lambda system. not sure if this is the right approach.
+ # is there a valid use case to send a single non-expanding
+ # param to IN? check for ARRAY type?
+ element = element._clone(maintain_key=True)
+ element.expanding = True
if isinstance(expr, elements.Tuple):
element = element._with_expanding_in_types(
[elem.type for elem in expr]
@@ -431,6 +479,22 @@ class InElementImpl(RoleImpl):
return element
+class OnClauseImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl):
+ __slots__ = ()
+
+ _coerce_consts = True
+
+ def _post_coercion(self, resolved, original_element=None, **kw):
+ # this is a hack right now as we want to use coercion on an
+ # ORM InstrumentedAttribute, but we want to return the object
+ # itself if it is one, not its clause element.
+ # ORM context _join and _legacy_join() would need to be improved
+ # to look for annotations in a clause element form.
+ if isinstance(original_element, roles.JoinTargetRole):
+ return original_element
+ return resolved
+
+
class WhereHavingImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl):
__slots__ = ()
@@ -635,6 +699,24 @@ class StatementImpl(_NoTextCoercion, RoleImpl):
class CoerceTextStatementImpl(_CoerceLiterals, RoleImpl):
__slots__ = ()
+ def _literal_coercion(self, element, **kw):
+ if callable(element) and hasattr(element, "__code__"):
+ return lambdas.StatementLambdaElement(element, self._role_class)
+ else:
+ return super(CoerceTextStatementImpl, self)._literal_coercion(
+ element, **kw
+ )
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if resolved._is_lambda_element:
+ return resolved
+ else:
+ return super(CoerceTextStatementImpl, self)._implicit_coercions(
+ original_element, resolved, argname=argname, **kw
+ )
+
def _text_coercion(self, element, argname=None):
# TODO: this should emit deprecation warning,
# see deprecation warning in engine/base.py execute()
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 6152a28e7..3a3ce5c45 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1296,6 +1296,10 @@ class SQLCompiler(Compiled):
"Cannot compile Column object until " "its 'name' is assigned."
)
+ def visit_lambda_element(self, element, **kw):
+ sql_element = element._resolved
+ return self.process(sql_element, **kw)
+
def visit_column(
self,
column,
@@ -1624,7 +1628,7 @@ class SQLCompiler(Compiled):
return func.clause_expr._compiler_dispatch(self, **kwargs)
def visit_compound_select(
- self, cs, asfrom=False, compound_index=0, **kwargs
+ self, cs, asfrom=False, compound_index=None, **kwargs
):
toplevel = not self.stack
@@ -1635,10 +1639,14 @@ class SQLCompiler(Compiled):
entry = self._default_stack_entry if toplevel else self.stack[-1]
need_result_map = toplevel or (
- compound_index == 0
+ not compound_index
and entry.get("need_result_map_for_compound", False)
)
+ # indicates there is already a CompoundSelect in play
+ if compound_index == 0:
+ entry["select_0"] = cs
+
self.stack.append(
{
"correlate_froms": entry["correlate_froms"],
@@ -2654,7 +2662,7 @@ class SQLCompiler(Compiled):
select_stmt,
asfrom=False,
fromhints=None,
- compound_index=0,
+ compound_index=None,
select_wraps_for=None,
lateral=False,
from_linter=None,
@@ -2709,7 +2717,9 @@ class SQLCompiler(Compiled):
or entry.get("need_result_map_for_nested", False)
)
- if compound_index > 0:
+ # indicates there is a CompoundSelect in play and we are not the
+ # first select
+ if compound_index:
populate_result_map = False
# this was first proposed as part of #3372; however, it is not
@@ -2844,11 +2854,10 @@ class SQLCompiler(Compiled):
correlate_froms = entry["correlate_froms"]
asfrom_froms = entry["asfrom_froms"]
- if compound_index > 0:
- # note this is cached
- select_0 = entry["selectable"].selects[0]
- if select_0._is_select_container:
- select_0 = select_0.element
+ if compound_index == 0:
+ entry["select_0"] = select
+ elif compound_index:
+ select_0 = entry["select_0"]
numcols = len(select_0.selected_columns)
if len(compile_state.columns_plus_names) != numcols:
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index af5eab257..6ce505412 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -215,6 +215,7 @@ class ClauseElement(
_is_select_statement = False
_is_bind_parameter = False
_is_clause_list = False
+ _is_lambda_element = False
_order_by_label_element = None
@@ -1337,9 +1338,6 @@ class BindParameter(roles.InElementRole, ColumnElement):
:ref:`change_4808`.
-
-
-
"""
if required is NO_ARG:
@@ -1406,15 +1404,15 @@ class BindParameter(roles.InElementRole, ColumnElement):
the context of an expanding IN against a tuple.
"""
- cloned = self._clone()
+ cloned = self._clone(maintain_key=True)
cloned._expanding_in_types = types
return cloned
- def _with_value(self, value):
+ def _with_value(self, value, maintain_key=False):
"""Return a copy of this :class:`.BindParameter` with the given value
set.
"""
- cloned = self._clone()
+ cloned = self._clone(maintain_key=maintain_key)
cloned.value = value
cloned.callable = None
cloned.required = False
@@ -1442,9 +1440,9 @@ class BindParameter(roles.InElementRole, ColumnElement):
c.type = type_
return c
- def _clone(self):
+ def _clone(self, maintain_key=False):
c = ClauseElement._clone(self)
- if self.unique:
+ if not maintain_key and self.unique:
c.key = _anonymous_label(
"%%(%d %s)s" % (id(c), c._orig_key or "param")
)
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index e25063372..37441a125 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -29,6 +29,8 @@ __all__ = [
"Insert",
"Join",
"Lateral",
+ "LambdaElement",
+ "StatementLambdaElement",
"Select",
"Selectable",
"TableClause",
@@ -59,6 +61,7 @@ __all__ = [
"join",
"label",
"lateral",
+ "lambda_stmt",
"literal",
"literal_column",
"not_",
@@ -135,6 +138,9 @@ from .functions import func # noqa
from .functions import Function # noqa
from .functions import FunctionElement # noqa
from .functions import modifier # noqa
+from .lambdas import lambda_stmt # noqa
+from .lambdas import LambdaElement # noqa
+from .lambdas import StatementLambdaElement # noqa
from .selectable import Alias # noqa
from .selectable import AliasedReturnsRows # noqa
from .selectable import CompoundSelect # noqa
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 6fff26842..c1b8bbd27 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -614,7 +614,7 @@ class Function(FunctionElement):
new :class:`.Function` instances.
"""
- self.packagenames = kw.pop("packagenames", None) or []
+ self.packagenames = kw.pop("packagenames", None) or ()
self.name = name
self._bind = kw.get("bind", None)
self.type = sqltypes.to_instance(kw.get("type_", None))
@@ -759,7 +759,7 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
for c in args
]
self._has_args = self._has_args or bool(parsed_args)
- self.packagenames = []
+ self.packagenames = ()
self._bind = kwargs.get("bind", None)
self.clause_expr = ClauseList(
operator=operators.comma_op, group_contents=True, *parsed_args
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
new file mode 100644
index 000000000..792411189
--- /dev/null
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -0,0 +1,607 @@
+# sql/lambdas.py
+# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+import itertools
+import operator
+import sys
+import weakref
+
+from . import coercions
+from . import elements
+from . import roles
+from . import schema
+from . import traversals
+from . import type_api
+from . import visitors
+from .operators import ColumnOperators
+from .. import exc
+from .. import inspection
+from .. import util
+from ..util import collections_abc
+
+_trackers = weakref.WeakKeyDictionary()
+
+
+_TRACKERS = 0
+_STALE_CHECK = 1
+_REAL_FN = 2
+_EXPR = 3
+_IS_SEQUENCE = 4
+_PROPAGATE_ATTRS = 5
+
+
+def lambda_stmt(lmb):
+ """Produce a SQL statement that is cached as a lambda.
+
+ This SQL statement will only be constructed if element has not been
+ compiled yet. The approach is used to save on Python function overhead
+ when constructing statements that will be cached.
+
+ E.g.::
+
+ from sqlalchemy import lambda_stmt
+
+ stmt = lambda_stmt(lambda: table.select())
+ stmt += lambda s: s.where(table.c.id == 5)
+
+ result = connection.execute(stmt)
+
+ The object returned is an instance of :class:`_sql.StatementLambdaElement`.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`engine_lambda_caching`
+
+
+ """
+ return coercions.expect(roles.CoerceTextStatementRole, lmb)
+
+
+class LambdaElement(elements.ClauseElement):
+ """A SQL construct where the state is stored as an un-invoked lambda.
+
+ The :class:`_sql.LambdaElement` is produced transparently whenever
+ passing lambda expressions into SQL constructs, such as::
+
+ stmt = select(table).where(lambda: table.c.col == parameter)
+
+ The :class:`_sql.LambdaElement` is the base of the
+ :class:`_sql.StatementLambdaElement` which represents a full statement
+ within a lambda.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`engine_lambda_caching`
+
+ """
+
+ __visit_name__ = "lambda_element"
+
+ _is_lambda_element = True
+
+ _resolved_bindparams = ()
+
+ _traverse_internals = [
+ ("_resolved", visitors.InternalTraversal.dp_clauseelement)
+ ]
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)
+
+ def __init__(self, fn, role, apply_propagate_attrs=None, **kw):
+ self.fn = fn
+ self.role = role
+ self.parent_lambda = None
+
+ if apply_propagate_attrs is None and (
+ role is roles.CoerceTextStatementRole
+ ):
+ apply_propagate_attrs = self
+
+ if fn.__code__ not in _trackers:
+ rec = self._initialize_var_trackers(
+ role, apply_propagate_attrs, kw
+ )
+ else:
+ rec = _trackers[self.fn.__code__]
+ closure = fn.__closure__
+
+ # check if the objects fixed inside the lambda that we've cached
+ # have been changed. This can apply to things like mappers that
+ # were recreated in test suites. if so, re-initialize.
+ #
+ # this is a small performance hit on every use for a not very
+ # common situation, however it's very hard to debug if the
+ # condition does occur.
+ for idx, obj in rec[_STALE_CHECK]:
+ if closure[idx].cell_contents is not obj:
+ rec = self._initialize_var_trackers(
+ role, apply_propagate_attrs, kw
+ )
+ break
+ self._rec = rec
+
+ if apply_propagate_attrs is not None:
+ propagate_attrs = rec[_PROPAGATE_ATTRS]
+ if propagate_attrs:
+ apply_propagate_attrs._propagate_attrs = propagate_attrs
+
+ if rec[_TRACKERS]:
+ self._resolved_bindparams = bindparams = []
+ for tracker in rec[_TRACKERS]:
+ tracker(self.fn, bindparams)
+
+ def __getattr__(self, key):
+ return getattr(self._rec[_EXPR], key)
+
+ @property
+ def _is_sequence(self):
+ return self._rec[_IS_SEQUENCE]
+
+ @property
+ def _select_iterable(self):
+ if self._is_sequence:
+ return itertools.chain.from_iterable(
+ [element._select_iterable for element in self._resolved]
+ )
+
+ else:
+ return self._resolved._select_iterable
+
+ @property
+ def _from_objects(self):
+ if self._is_sequence:
+ return itertools.chain.from_iterable(
+ [element._from_objects for element in self._resolved]
+ )
+
+ else:
+ return self._resolved._from_objects
+
+ def _param_dict(self):
+ return {b.key: b.value for b in self._resolved_bindparams}
+
+ @util.memoized_property
+ def _resolved(self):
+ bindparam_lookup = {b.key: b for b in self._resolved_bindparams}
+
+ def replace(thing):
+ if (
+ isinstance(thing, elements.BindParameter)
+ and thing.key in bindparam_lookup
+ ):
+ bind = bindparam_lookup[thing.key]
+ # TODO: consider
+ # if we should clone the bindparam here, re-cache the new
+ # version, etc. also we make an assumption about "expanding"
+ # in this case.
+ if thing.expanding:
+ bind.expanding = True
+ return bind
+
+ expr = self._rec[_EXPR]
+
+ if self._rec[_IS_SEQUENCE]:
+ expr = [
+ visitors.replacement_traverse(sub_expr, {}, replace)
+ for sub_expr in expr
+ ]
+ elif getattr(expr, "is_clause_element", False):
+ expr = visitors.replacement_traverse(expr, {}, replace)
+
+ return expr
+
+ def _gen_cache_key(self, anon_map, bindparams):
+
+ cache_key = (self.fn.__code__, self.__class__)
+
+ if self._resolved_bindparams:
+ bindparams.extend(self._resolved_bindparams)
+
+ return cache_key
+
+ def _invoke_user_fn(self, fn, *arg):
+ return fn()
+
+ def _initialize_var_trackers(self, role, apply_propagate_attrs, coerce_kw):
+ fn = self.fn
+
+ # track objects referenced inside of lambdas, create bindparams
+ # ahead of time for literal values. If bindparams are produced,
+ # then rewrite the function globals and closure as necessary so that
+ # it refers to the bindparams, then invoke the function
+ new_closure = {}
+ new_globals = fn.__globals__.copy()
+ tracker_collection = []
+ check_closure_for_stale = []
+
+ for name in fn.__code__.co_names:
+ if name not in new_globals:
+ continue
+
+ bound_value = _roll_down_to_literal(new_globals[name])
+
+ if coercions._is_literal(bound_value):
+ new_globals[name] = bind = PyWrapper(name, bound_value)
+ tracker_collection.append(_globals_tracker(name, bind))
+
+ if fn.__closure__:
+ for closure_index, (fv, cell) in enumerate(
+ zip(fn.__code__.co_freevars, fn.__closure__)
+ ):
+
+ bound_value = _roll_down_to_literal(cell.cell_contents)
+
+ if coercions._is_literal(bound_value):
+ new_closure[fv] = bind = PyWrapper(fv, bound_value)
+ tracker_collection.append(
+ _closure_tracker(fv, bind, closure_index)
+ )
+ else:
+ new_closure[fv] = cell.cell_contents
+ # for normal cell contents, add them to a list that
+ # we can compare later when we get new lambdas. if
+ # any identities have changed, then we will recalculate
+ # the whole lambda and run it again.
+ check_closure_for_stale.append(
+ (closure_index, cell.cell_contents)
+ )
+
+ if tracker_collection:
+ new_fn = _rewrite_code_obj(
+ fn,
+ [new_closure[name] for name in fn.__code__.co_freevars],
+ new_globals,
+ )
+ expr = self._invoke_user_fn(new_fn)
+
+ else:
+ new_fn = fn
+ expr = self._invoke_user_fn(new_fn)
+ tracker_collection = []
+
+ if self.parent_lambda is None:
+ if isinstance(expr, collections_abc.Sequence):
+ expected_expr = [
+ coercions.expect(
+ role,
+ sub_expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ **coerce_kw
+ )
+ for sub_expr in expr
+ ]
+ is_sequence = True
+ else:
+ expected_expr = coercions.expect(
+ role,
+ expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ **coerce_kw
+ )
+ is_sequence = False
+ else:
+ expected_expr = expr
+ is_sequence = False
+
+ if apply_propagate_attrs is not None:
+ propagate_attrs = apply_propagate_attrs._propagate_attrs
+ else:
+ propagate_attrs = util.immutabledict()
+
+ rec = _trackers[self.fn.__code__] = (
+ tracker_collection,
+ check_closure_for_stale,
+ new_fn,
+ expected_expr,
+ is_sequence,
+ propagate_attrs,
+ )
+ return rec
+
+
+class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
+ """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.
+
+ The :class:`_sql.StatementLambdaElement` is constructed using the
+ :func:`_sql.lambda_stmt` function::
+
+
+ from sqlalchemy import lambda_stmt
+
+ stmt = lambda_stmt(lambda: select(table))
+
+ Once constructed, additional criteria can be built onto the statement
+ by adding subsequent lambdas, which accept the existing statement
+ object as a single parameter::
+
+ stmt += lambda s: s.where(table.c.col == parameter)
+
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`engine_lambda_caching`
+
+ """
+
+ def __add__(self, other):
+ return LinkedLambdaElement(other, parent_lambda=self)
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ if self._rec[_EXPR].supports_execution:
+ return connection._execute_clauseelement(
+ self, multiparams, params, execution_options
+ )
+ else:
+ raise exc.ObjectNotExecutableError(self)
+
+ @property
+ def _with_options(self):
+ return self._rec[_EXPR]._with_options
+
+ @property
+ def _effective_plugin_target(self):
+ return self._rec[_EXPR]._effective_plugin_target
+
+ @property
+ def _is_future(self):
+ return self._rec[_EXPR]._is_future
+
+ @property
+ def _execution_options(self):
+ return self._rec[_EXPR]._execution_options
+
+
+class LinkedLambdaElement(StatementLambdaElement):
+ def __init__(self, fn, parent_lambda, **kw):
+ self.fn = fn
+ self.parent_lambda = parent_lambda
+ role = None
+
+ apply_propagate_attrs = self
+
+ if fn.__code__ not in _trackers:
+ rec = self._initialize_var_trackers(
+ role, apply_propagate_attrs, kw
+ )
+ else:
+ rec = _trackers[self.fn.__code__]
+
+ closure = fn.__closure__
+
+ # check if objects referred to by the lambda have changed and
+ # re-scan the lambda if so. see comments for this same section in
+ # LambdaElement.
+ for idx, obj in rec[_STALE_CHECK]:
+ if closure[idx].cell_contents is not obj:
+ rec = self._initialize_var_trackers(
+ role, apply_propagate_attrs, kw
+ )
+ break
+
+ self._rec = rec
+
+ self._propagate_attrs = parent_lambda._propagate_attrs
+
+ self._resolved_bindparams = bindparams = []
+ rec = self._rec
+ while True:
+ if rec[_TRACKERS]:
+ for tracker in rec[_TRACKERS]:
+ tracker(self.fn, bindparams)
+ if self.parent_lambda is not None:
+ self = self.parent_lambda
+ rec = self._rec
+ else:
+ break
+
+ def _invoke_user_fn(self, fn, *arg):
+ return fn(self.parent_lambda._rec[_EXPR])
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ if self._resolved_bindparams:
+ bindparams.extend(self._resolved_bindparams)
+
+ cache_key = (self.fn.__code__, self.__class__)
+
+ parent = self.parent_lambda
+ while parent is not None:
+ cache_key = (parent.fn.__code__,) + cache_key
+ parent = parent.parent_lambda
+
+ return cache_key
+
+
+class PyWrapper(ColumnOperators):
+ def __init__(self, name, to_evaluate, getter=None):
+ self._name = name
+ self._to_evaluate = to_evaluate
+ self._param = None
+ self._bind_paths = {}
+ self._getter = getter
+
+ def __call__(self, *arg, **kw):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ value = elem(*arg, **kw)
+ if coercions._is_literal(value) and not isinstance(
+ # TODO: coverage where an ORM option or similar is here
+ value,
+ traversals.HasCacheKey,
+ ):
+ # TODO: we can instead scan the arguments and make sure they
+ # are all Python literals
+
+ # TODO: coverage
+ name = object.__getattribute__(self, "_name")
+ raise exc.InvalidRequestError(
+ "Can't invoke Python callable %s() inside of lambda "
+ "expression argument; lambda cache keys should not call "
+ "regular functions since the caching "
+ "system does not track the values of the arguments passed "
+ "to the functions. Call the function outside of the lambda "
+ "and assign to a local variable that is used in the lambda."
+ % (name)
+ )
+ else:
+ return value
+
+ def operate(self, op, *other, **kwargs):
+ elem = object.__getattribute__(self, "__clause_element__")()
+ return op(elem, *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ elem = object.__getattribute__(self, "__clause_element__")()
+ return op(other, elem, **kwargs)
+
+ def _extract_bound_parameters(self, starting_point, result_list):
+ param = object.__getattribute__(self, "_param")
+ if param is not None:
+ param = param._with_value(starting_point, maintain_key=True)
+ result_list.append(param)
+ for pywrapper in object.__getattribute__(self, "_bind_paths").values():
+ getter = object.__getattribute__(pywrapper, "_getter")
+ element = getter(starting_point)
+ pywrapper._sa__extract_bound_parameters(element, result_list)
+
+ def __clause_element__(self):
+ param = object.__getattribute__(self, "_param")
+ to_evaluate = object.__getattribute__(self, "_to_evaluate")
+ if param is None:
+ name = object.__getattribute__(self, "_name")
+ self._param = param = elements.BindParameter(name, unique=True)
+ param.type = type_api._resolve_value_to_type(to_evaluate)
+
+ return param._with_value(to_evaluate, maintain_key=True)
+
+ def __getattribute__(self, key):
+ if key.startswith("_sa_"):
+ return object.__getattribute__(self, key[4:])
+ elif key in ("__clause_element__", "operate", "reverse_operate"):
+ return object.__getattribute__(self, key)
+
+ if key.startswith("__"):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ return getattr(elem, key)
+ else:
+ return self._sa__add_getter(key, operator.attrgetter)
+
+ def __getitem__(self, key):
+ if isinstance(key, PyWrapper):
+ # TODO: coverage
+ raise exc.InvalidRequestError(
+ "Dictionary keys / list indexes inside of a cached "
+ "lambda must be Python literals only"
+ )
+ return self._sa__add_getter(key, operator.itemgetter)
+
+ def _add_getter(self, key, getter_fn):
+
+ bind_paths = object.__getattribute__(self, "_bind_paths")
+
+ bind_path_key = (key, getter_fn)
+ if bind_path_key in bind_paths:
+ return bind_paths[bind_path_key]
+
+ getter = getter_fn(key)
+ elem = object.__getattribute__(self, "_to_evaluate")
+ value = getter(elem)
+
+ if coercions._is_literal(value):
+ wrapper = PyWrapper(key, value, getter)
+ bind_paths[bind_path_key] = wrapper
+ return wrapper
+ else:
+ return value
+
+
+def _roll_down_to_literal(element):
+ is_clause_element = hasattr(element, "__clause_element__")
+
+ if is_clause_element:
+ while not isinstance(
+ element, (elements.ClauseElement, schema.SchemaItem)
+ ):
+ try:
+ element = element.__clause_element__()
+ except AttributeError:
+ break
+
+ if not is_clause_element:
+ insp = inspection.inspect(element, raiseerr=False)
+ if insp is not None:
+ try:
+ return insp.__clause_element__()
+ except AttributeError:
+ return insp
+
+ # TODO: should we coerce consts None/True/False here?
+ return element
+ else:
+ return element
+
+
+def _globals_tracker(name, wrapper):
+ def extract_parameter_value(current_fn, result):
+ object.__getattribute__(wrapper, "_extract_bound_parameters")(
+ current_fn.__globals__[name], result
+ )
+
+ return extract_parameter_value
+
+
+def _closure_tracker(name, wrapper, closure_index):
+ def extract_parameter_value(current_fn, result):
+ object.__getattribute__(wrapper, "_extract_bound_parameters")(
+ current_fn.__closure__[closure_index].cell_contents, result
+ )
+
+ return extract_parameter_value
+
+
+def _rewrite_code_obj(f, cell_values, globals_):
+ """Return a copy of f, with a new closure and new globals
+
+ yes it works in pypy :P
+
+ """
+
+ argrange = range(len(cell_values))
+
+ code = "def make_cells():\n"
+ if cell_values:
+ code += " (%s) = (%s)\n" % (
+ ", ".join("i%d" % i for i in argrange),
+ ", ".join("o%d" % i for i in argrange),
+ )
+ code += " def closure():\n"
+ code += " return %s\n" % ", ".join("i%d" % i for i in argrange)
+ code += " return closure.__closure__"
+ vars_ = {"o%d" % i: cell_values[i] for i in argrange}
+ exec(code, vars_, vars_)
+ closure = vars_["make_cells"]()
+
+ func = type(f)(f.__code__, globals_, f.__name__, f.__defaults__, closure)
+ if sys.version_info >= (3,):
+ func.__annotations__ = f.__annotations__
+ func.__kwdefaults__ = f.__kwdefaults__
+ func.__doc__ = f.__doc__
+ func.__module__ = f.__module__
+
+ return func
+
+
+@inspection._inspects(LambdaElement)
+def insp(lmb):
+ return inspection.inspect(lmb._resolved)
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 3d94ec9ff..4205d9f0d 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -19,9 +19,21 @@ class SQLRole(object):
"""
+ allows_lambda = False
+ uses_inspection = False
+
class UsesInspection(object):
_post_inspect = None
+ uses_inspection = True
+
+
+class AllowsLambdaRole(object):
+ allows_lambda = True
+
+
+class HasCacheKeyRole(SQLRole):
+ _role_name = "Cacheable Core or ORM object"
class ColumnArgumentRole(SQLRole):
@@ -40,7 +52,7 @@ class TruncatedLabelRole(SQLRole):
_role_name = "String SQL identifier"
-class ColumnsClauseRole(UsesInspection, ColumnListRole):
+class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
_role_name = "Column expression or FROM clause"
@property
@@ -56,7 +68,7 @@ class ByOfRole(ColumnListRole):
_role_name = "GROUP BY / OF / etc. expression"
-class GroupByRole(UsesInspection, ByOfRole):
+class GroupByRole(AllowsLambdaRole, UsesInspection, ByOfRole):
# note there's a special case right now where you can pass a whole
# ORM entity to group_by() and it splits out. we may not want to keep
# this around
@@ -64,7 +76,7 @@ class GroupByRole(UsesInspection, ByOfRole):
_role_name = "GROUP BY expression"
-class OrderByRole(ByOfRole):
+class OrderByRole(AllowsLambdaRole, ByOfRole):
_role_name = "ORDER BY expression"
@@ -76,7 +88,11 @@ class StatementOptionRole(StructuralRole):
_role_name = "statement sub-expression element"
-class WhereHavingRole(StructuralRole):
+class OnClauseRole(AllowsLambdaRole, StructuralRole):
+ _role_name = "SQL expression for ON clause"
+
+
+class WhereHavingRole(OnClauseRole):
_role_name = "SQL expression for WHERE/HAVING role"
@@ -102,7 +118,7 @@ class InElementRole(SQLRole):
)
-class JoinTargetRole(UsesInspection, StructuralRole):
+class JoinTargetRole(AllowsLambdaRole, UsesInspection, StructuralRole):
_role_name = (
"Join target, typically a FROM expression, or ORM "
"relationship attribute"
@@ -176,7 +192,7 @@ class HasCTERole(ReturnsRowsRole):
pass
-class CompoundElementRole(SQLRole):
+class CompoundElementRole(AllowsLambdaRole, SQLRole):
"""SELECT statements inside a CompoundSelect, e.g. UNION, EXCEPT, etc."""
_role_name = (
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 59c292a07..832da1a57 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -847,7 +847,7 @@ class Join(roles.DMLTableRole, FromClause):
# note: taken from If91f61527236fd4d7ae3cad1f24c38be921c90ba
# not merged yet
self.onclause = coercions.expect(
- roles.WhereHavingRole, onclause
+ roles.OnClauseRole, onclause
).self_group(against=operators._asbool)
self.isouter = isouter
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 8d01b7ff7..f41480a94 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -115,45 +115,37 @@ class HasCacheKey(object):
in the structures that would affect the SQL string or the type handlers
should result in a different cache key.
- If a structure cannot produce a useful cache key, it should raise
- NotImplementedError, which will result in the entire structure
- for which it's part of not being useful as a cache key.
-
+ If a structure cannot produce a useful cache key, the NO_CACHE
+ symbol should be added to the anon_map and the method should
+ return None.
"""
- elements = util.preloaded.sql_elements
-
idself = id(self)
+ cls = self.__class__
- if anon_map is not None:
- if idself in anon_map:
- return (anon_map[idself], self.__class__)
- else:
- # inline of
- # id_ = anon_map[idself]
- anon_map[idself] = id_ = str(anon_map.index)
- anon_map.index += 1
+ if idself in anon_map:
+ return (anon_map[idself], cls)
else:
- id_ = None
+ # inline of
+ # id_ = anon_map[idself]
+ anon_map[idself] = id_ = str(anon_map.index)
+ anon_map.index += 1
try:
- dispatcher = self.__class__.__dict__[
- "_generated_cache_key_traversal"
- ]
+ dispatcher = cls.__dict__["_generated_cache_key_traversal"]
except KeyError:
# most of the dispatchers are generated up front
# in sqlalchemy/sql/__init__.py ->
# traversals.py-> _preconfigure_traversals().
# this block will generate any remaining dispatchers.
- dispatcher = self.__class__._generate_cache_attrs()
+ dispatcher = cls._generate_cache_attrs()
if dispatcher is NO_CACHE:
- if anon_map is not None:
- anon_map[NO_CACHE] = True
+ anon_map[NO_CACHE] = True
return None
- result = (id_, self.__class__)
+ result = (id_, cls)
# inline of _cache_key_traversal_visitor.run_generated_dispatch()
@@ -163,15 +155,12 @@ class HasCacheKey(object):
if obj is not None:
# TODO: see if C code can help here as Python lacks an
# efficient switch construct
- if meth is CACHE_IN_PLACE:
- # cache in place is always going to be a Python
- # tuple, dict, list, etc. so we can do a boolean check
- if obj:
- result += (attrname, obj)
- elif meth is STATIC_CACHE_KEY:
+
+ if meth is STATIC_CACHE_KEY:
result += (attrname, obj._static_cache_key)
elif meth is ANON_NAME:
- if elements._anonymous_label in obj.__class__.__mro__:
+ elements = util.preloaded.sql_elements
+ if isinstance(obj, elements._anonymous_label):
obj = obj.apply_map(anon_map)
result += (attrname, obj)
elif meth is CALL_GEN_CACHE_KEY:
@@ -179,8 +168,14 @@ class HasCacheKey(object):
attrname,
obj._gen_cache_key(anon_map, bindparams),
)
- elif meth is PROPAGATE_ATTRS:
- if obj:
+
+ # remaining cache functions are against
+ # Python tuples, dicts, lists, etc. so we can skip
+ # if they are empty
+ elif obj:
+ if meth is CACHE_IN_PLACE:
+ result += (attrname, obj)
+ elif meth is PROPAGATE_ATTRS:
result += (
attrname,
obj["compile_state_plugin"],
@@ -188,16 +183,14 @@ class HasCacheKey(object):
anon_map, bindparams
),
)
- elif meth is InternalTraversal.dp_annotations_key:
- # obj is here is the _annotations dict. however,
- # we want to use the memoized cache key version of it.
- # for Columns, this should be long lived. For select()
- # statements, not so much, but they usually won't have
- # annotations.
- if obj:
+ elif meth is InternalTraversal.dp_annotations_key:
+ # obj here is the _annotations dict. however, we
+ # want to use the memoized cache key version of it. for
+ # Columns, this should be long lived. For select()
+ # statements, not so much, but they usually won't have
+ # annotations.
result += self._annotations_cache_key
- elif meth is InternalTraversal.dp_clauseelement_list:
- if obj:
+ elif meth is InternalTraversal.dp_clauseelement_list:
result += (
attrname,
tuple(
@@ -207,14 +200,7 @@ class HasCacheKey(object):
]
),
)
- else:
- # note that all the "ClauseElement" standalone cases
- # here have been handled by inlines above; so we can
- # safely assume the object is a standard list/tuple/dict
- # which we can skip if it evaluates to false.
- # improvement would be to have this as a flag delivered
- # up front in the dispatcher list
- if obj:
+ else:
result += meth(
attrname, obj, self, anon_map, bindparams
)
@@ -384,6 +370,14 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
return "CacheKey(key=%s)" % ("\n".join(output),)
+ def _generate_param_dict(self):
+ """used for testing"""
+
+ from .compiler import prefix_anon_map
+
+ _anon_map = prefix_anon_map()
+ return {b.key % _anon_map: b.effective_value for b in self.bindparams}
+
def _clone(element, **kw):
return element._clone()
@@ -506,6 +500,7 @@ class _CacheKey(ExtendedInternalTraversal):
):
if not obj:
return ()
+
return (
attrname,
tuple(