summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/lambdas.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/lambdas.py')
-rw-r--r--lib/sqlalchemy/sql/lambdas.py980
1 files changed, 718 insertions, 262 deletions
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
index 792411189..327003902 100644
--- a/lib/sqlalchemy/sql/lambdas.py
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -8,6 +8,7 @@
import itertools
import operator
import sys
+import types
import weakref
from . import coercions
@@ -17,29 +18,23 @@ from . import schema
from . import traversals
from . import type_api
from . import visitors
+from .base import _clone
from .operators import ColumnOperators
from .. import exc
from .. import inspection
from .. import util
from ..util import collections_abc
-_trackers = weakref.WeakKeyDictionary()
+_closure_per_cache_key = util.LRUCache(1000)
-_TRACKERS = 0
-_STALE_CHECK = 1
-_REAL_FN = 2
-_EXPR = 3
-_IS_SEQUENCE = 4
-_PROPAGATE_ATTRS = 5
-
-
-def lambda_stmt(lmb):
+def lambda_stmt(lmb, **opts):
"""Produce a SQL statement that is cached as a lambda.
- This SQL statement will only be constructed if element has not been
- compiled yet. The approach is used to save on Python function overhead
- when constructing statements that will be cached.
+ The Python code object within the lambda is scanned for both Python
+ literals that will become bound parameters as well as closure variables
+ that refer to Core or ORM constructs that may vary. The lambda itself
+ will be invoked only once per particular set of constructs detected.
E.g.::
@@ -60,7 +55,8 @@ def lambda_stmt(lmb):
"""
- return coercions.expect(roles.CoerceTextStatementRole, lmb)
+
+ return StatementLambdaElement(lmb, roles.CoerceTextStatementRole, **opts)
class LambdaElement(elements.ClauseElement):
@@ -87,64 +83,108 @@ class LambdaElement(elements.ClauseElement):
_is_lambda_element = True
- _resolved_bindparams = ()
-
_traverse_internals = [
("_resolved", visitors.InternalTraversal.dp_clauseelement)
]
+ _transforms = ()
+
+ parent_lambda = None
+
def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)
def __init__(self, fn, role, apply_propagate_attrs=None, **kw):
self.fn = fn
self.role = role
- self.parent_lambda = None
+ self.tracker_key = (fn.__code__,)
if apply_propagate_attrs is None and (
role is roles.CoerceTextStatementRole
):
apply_propagate_attrs = self
- if fn.__code__ not in _trackers:
- rec = self._initialize_var_trackers(
- role, apply_propagate_attrs, kw
+ rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, kw)
+
+ if apply_propagate_attrs is not None:
+ propagate_attrs = rec.propagate_attrs
+ if propagate_attrs:
+ apply_propagate_attrs._propagate_attrs = propagate_attrs
+
+ def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, kw):
+ lambda_cache = kw.get("lambda_cache", _closure_per_cache_key)
+
+ tracker_key = self.tracker_key
+
+ fn = self.fn
+ closure = fn.__closure__
+
+ tracker = AnalyzedCode.get(
+ fn,
+ self,
+ kw,
+ track_bound_values=kw.get("track_bound_values", True),
+ enable_tracking=kw.get("enable_tracking", True),
+ track_on=kw.get("track_on", None),
+ )
+
+ self._resolved_bindparams = bindparams = []
+
+ anon_map = traversals.anon_map()
+ cache_key = tuple(
+ [
+ getter(closure, kw, anon_map, bindparams)
+ for getter in tracker.closure_trackers
+ ]
+ )
+ if self.parent_lambda is not None:
+ cache_key = self.parent_lambda.closure_cache_key + cache_key
+
+ self.closure_cache_key = cache_key
+
+ try:
+ rec = lambda_cache[tracker_key + cache_key]
+ except KeyError:
+ rec = None
+
+ if rec is None:
+ rec = AnalyzedFunction(
+ tracker, self, apply_propagate_attrs, kw, fn
)
+ rec.closure_bindparams = bindparams
+ lambda_cache[tracker_key + cache_key] = rec
else:
- rec = _trackers[self.fn.__code__]
- closure = fn.__closure__
+ bindparams[:] = [
+ orig_bind._with_value(new_bind.value, maintain_key=True)
+ for orig_bind, new_bind in zip(
+ rec.closure_bindparams, bindparams
+ )
+ ]
+
+ if self.parent_lambda is not None:
+ bindparams[:0] = self.parent_lambda._resolved_bindparams
- # check if the objects fixed inside the lambda that we've cached
- # have been changed. This can apply to things like mappers that
- # were recreated in test suites. if so, re-initialize.
- #
- # this is a small performance hit on every use for a not very
- # common situation, however it's very hard to debug if the
- # condition does occur.
- for idx, obj in rec[_STALE_CHECK]:
- if closure[idx].cell_contents is not obj:
- rec = self._initialize_var_trackers(
- role, apply_propagate_attrs, kw
- )
- break
self._rec = rec
- if apply_propagate_attrs is not None:
- propagate_attrs = rec[_PROPAGATE_ATTRS]
- if propagate_attrs:
- apply_propagate_attrs._propagate_attrs = propagate_attrs
+ lambda_element = self
+ while lambda_element is not None:
+ rec = lambda_element._rec
+ if rec.bindparam_trackers:
+ tracker_instrumented_fn = rec.tracker_instrumented_fn
+ for tracker in rec.bindparam_trackers:
+ tracker(
+ lambda_element.fn, tracker_instrumented_fn, bindparams
+ )
+ lambda_element = lambda_element.parent_lambda
- if rec[_TRACKERS]:
- self._resolved_bindparams = bindparams = []
- for tracker in rec[_TRACKERS]:
- tracker(self.fn, bindparams)
+ return rec
def __getattr__(self, key):
- return getattr(self._rec[_EXPR], key)
+ return getattr(self._rec.expected_expr, key)
@property
def _is_sequence(self):
- return self._rec[_IS_SEQUENCE]
+ return self._rec.is_sequence
@property
def _select_iterable(self):
@@ -169,8 +209,7 @@ class LambdaElement(elements.ClauseElement):
def _param_dict(self):
return {b.key: b.value for b in self._resolved_bindparams}
- @util.memoized_property
- def _resolved(self):
+ def _setup_binds_for_tracked_expr(self, expr):
bindparam_lookup = {b.key: b for b in self._resolved_bindparams}
def replace(thing):
@@ -179,17 +218,11 @@ class LambdaElement(elements.ClauseElement):
and thing.key in bindparam_lookup
):
bind = bindparam_lookup[thing.key]
- # TODO: consider
- # if we should clone the bindparam here, re-cache the new
- # version, etc. also we make an assumption about "expanding"
- # in this case.
if thing.expanding:
bind.expanding = True
return bind
- expr = self._rec[_EXPR]
-
- if self._rec[_IS_SEQUENCE]:
+ if self._rec.is_sequence:
expr = [
visitors.replacement_traverse(sub_expr, {}, replace)
for sub_expr in expr
@@ -199,9 +232,39 @@ class LambdaElement(elements.ClauseElement):
return expr
+ def _copy_internals(
+ self, clone=_clone, deferred_copy_internals=None, **kw
+ ):
+ # TODO: this needs A LOT of tests
+ self._resolved = clone(
+ self._resolved,
+ deferred_copy_internals=deferred_copy_internals,
+ **kw
+ )
+
+ @util.memoized_property
+ def _resolved(self):
+ expr = self._rec.expected_expr
+
+ if self._resolved_bindparams:
+ expr = self._setup_binds_for_tracked_expr(expr)
+
+ return expr
+
def _gen_cache_key(self, anon_map, bindparams):
- cache_key = (self.fn.__code__, self.__class__)
+ cache_key = (
+ self.fn.__code__,
+ self.__class__,
+ ) + self.closure_cache_key
+
+ parent = self.parent_lambda
+ while parent is not None:
+ cache_key = (
+ (parent.fn.__code__,) + parent.closure_cache_key + cache_key
+ )
+
+ parent = parent.parent_lambda
if self._resolved_bindparams:
bindparams.extend(self._resolved_bindparams)
@@ -211,101 +274,51 @@ class LambdaElement(elements.ClauseElement):
def _invoke_user_fn(self, fn, *arg):
return fn()
- def _initialize_var_trackers(self, role, apply_propagate_attrs, coerce_kw):
- fn = self.fn
- # track objects referenced inside of lambdas, create bindparams
- # ahead of time for literal values. If bindparams are produced,
- # then rewrite the function globals and closure as necessary so that
- # it refers to the bindparams, then invoke the function
- new_closure = {}
- new_globals = fn.__globals__.copy()
- tracker_collection = []
- check_closure_for_stale = []
+class DeferredLambdaElement(LambdaElement):
+ """A LambdaElement where the lambda accepts arguments and is
+ invoked within the compile phase with special context.
- for name in fn.__code__.co_names:
- if name not in new_globals:
- continue
+ This lambda doesn't normally produce its real SQL expression outside of the
+ compile phase. It is passed a fixed set of initial arguments
+ so that it can generate a sample expression.
- bound_value = _roll_down_to_literal(new_globals[name])
+ """
- if coercions._is_literal(bound_value):
- new_globals[name] = bind = PyWrapper(name, bound_value)
- tracker_collection.append(_globals_tracker(name, bind))
+ def __init__(self, fn, role, lambda_args=(), **kw):
+ self.lambda_args = lambda_args
+ self.coerce_kw = kw
+ super(DeferredLambdaElement, self).__init__(fn, role, **kw)
- if fn.__closure__:
- for closure_index, (fv, cell) in enumerate(
- zip(fn.__code__.co_freevars, fn.__closure__)
- ):
+ def _invoke_user_fn(self, fn, *arg):
+ return fn(*self.lambda_args)
- bound_value = _roll_down_to_literal(cell.cell_contents)
+ def _resolve_with_args(self, *lambda_args):
+ tracker_fn = self._rec.tracker_instrumented_fn
+ expr = tracker_fn(*lambda_args)
- if coercions._is_literal(bound_value):
- new_closure[fv] = bind = PyWrapper(fv, bound_value)
- tracker_collection.append(
- _closure_tracker(fv, bind, closure_index)
- )
- else:
- new_closure[fv] = cell.cell_contents
- # for normal cell contents, add them to a list that
- # we can compare later when we get new lambdas. if
- # any identities have changed, then we will recalculate
- # the whole lambda and run it again.
- check_closure_for_stale.append(
- (closure_index, cell.cell_contents)
- )
+ expr = coercions.expect(self.role, expr, **self.coerce_kw)
- if tracker_collection:
- new_fn = _rewrite_code_obj(
- fn,
- [new_closure[name] for name in fn.__code__.co_freevars],
- new_globals,
- )
- expr = self._invoke_user_fn(new_fn)
+ if self._resolved_bindparams:
+ expr = self._setup_binds_for_tracked_expr(expr)
- else:
- new_fn = fn
- expr = self._invoke_user_fn(new_fn)
- tracker_collection = []
+ # TODO: TEST TEST TEST, this is very out there
+ for deferred_copy_internals in self._transforms:
+ expr = deferred_copy_internals(expr)
- if self.parent_lambda is None:
- if isinstance(expr, collections_abc.Sequence):
- expected_expr = [
- coercions.expect(
- role,
- sub_expr,
- apply_propagate_attrs=apply_propagate_attrs,
- **coerce_kw
- )
- for sub_expr in expr
- ]
- is_sequence = True
- else:
- expected_expr = coercions.expect(
- role,
- expr,
- apply_propagate_attrs=apply_propagate_attrs,
- **coerce_kw
- )
- is_sequence = False
- else:
- expected_expr = expr
- is_sequence = False
+ return expr
- if apply_propagate_attrs is not None:
- propagate_attrs = apply_propagate_attrs._propagate_attrs
- else:
- propagate_attrs = util.immutabledict()
-
- rec = _trackers[self.fn.__code__] = (
- tracker_collection,
- check_closure_for_stale,
- new_fn,
- expected_expr,
- is_sequence,
- propagate_attrs,
+ def _copy_internals(
+ self, clone=_clone, deferred_copy_internals=None, **kw
+ ):
+ super(DeferredLambdaElement, self)._copy_internals(
+ clone=clone, deferred_copy_internals=deferred_copy_internals, **kw
)
- return rec
+
+ # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know
+ # our expression yet. so hold onto the replacement
+ if deferred_copy_internals:
+ self._transforms += (deferred_copy_internals,)
class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
@@ -334,13 +347,38 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
"""
+ def __init__(self, fn, parent_lambda, **kw):
+ self._default_kw = default_kw = {}
+ global_track_bound_values = kw.pop("global_track_bound_values", None)
+ if global_track_bound_values is not None:
+ default_kw["track_bound_values"] = global_track_bound_values
+ kw["track_bound_values"] = global_track_bound_values
+
+ if "lambda_cache" in kw:
+ default_kw["lambda_cache"] = kw["lambda_cache"]
+
+ super(StatementLambdaElement, self).__init__(fn, parent_lambda, **kw)
+
def __add__(self, other):
- return LinkedLambdaElement(other, parent_lambda=self)
+ return LinkedLambdaElement(
+ other, parent_lambda=self, **self._default_kw
+ )
+
+ def add_criteria(self, other, **kw):
+ if self._default_kw:
+ if kw:
+ default_kw = self._default_kw.copy()
+ default_kw.update(kw)
+ kw = default_kw
+ else:
+ kw = self._default_kw
+
+ return LinkedLambdaElement(other, parent_lambda=self, **kw)
def _execute_on_connection(
self, connection, multiparams, params, execution_options
):
- if self._rec[_EXPR].supports_execution:
+ if self._rec.expected_expr.supports_execution:
return connection._execute_clauseelement(
self, multiparams, params, execution_options
)
@@ -349,93 +387,579 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
@property
def _with_options(self):
- return self._rec[_EXPR]._with_options
+ return self._rec.expected_expr._with_options
@property
def _effective_plugin_target(self):
- return self._rec[_EXPR]._effective_plugin_target
+ return self._rec.expected_expr._effective_plugin_target
@property
def _is_future(self):
- return self._rec[_EXPR]._is_future
+ return self._rec.expected_expr._is_future
@property
def _execution_options(self):
- return self._rec[_EXPR]._execution_options
+ return self._rec.expected_expr._execution_options
+
+ def spoil(self):
+ """Return a new :class:`.StatementLambdaElement` that will run
+ all lambdas unconditionally each time.
+
+ """
+ return NullLambdaStatement(self.fn())
+
+
+class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
+ """Provides the :class:`.StatementLambdaElement` API but does not
+ cache or analyze lambdas.
+
+ the lambdas are instead invoked immediately.
+
+ The intended use is to isolate issues that may arise when using
+ lambda statements.
+
+ """
+
+ __visit_name__ = "lambda_element"
+
+ _is_lambda_element = True
+
+ _traverse_internals = [
+ ("_resolved", visitors.InternalTraversal.dp_clauseelement)
+ ]
+
+ def __init__(self, statement):
+ self._resolved = statement
+ self._propagate_attrs = statement._propagate_attrs
+
+ def __getattr__(self, key):
+ return getattr(self._resolved, key)
+
+ def __add__(self, other):
+ statement = other(self._resolved)
+
+ return NullLambdaStatement(statement)
+
+ def add_criteria(self, other, **kw):
+ statement = other(self._resolved)
+
+ return NullLambdaStatement(statement)
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ if self._resolved.supports_execution:
+ return connection._execute_clauseelement(
+ self, multiparams, params, execution_options
+ )
+ else:
+ raise exc.ObjectNotExecutableError(self)
class LinkedLambdaElement(StatementLambdaElement):
+ """Represent subsequent links of a :class:`.StatementLambdaElement`."""
+
+ role = None
+
def __init__(self, fn, parent_lambda, **kw):
+ self._default_kw = parent_lambda._default_kw
+
self.fn = fn
self.parent_lambda = parent_lambda
- role = None
- apply_propagate_attrs = self
+ self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)
+ self._retrieve_tracker_rec(fn, self, kw)
+ self._propagate_attrs = parent_lambda._propagate_attrs
+
+ def _invoke_user_fn(self, fn, *arg):
+ return fn(self.parent_lambda._resolved)
+
+
+class AnalyzedCode(object):
+ __slots__ = (
+ "track_closure_variables",
+ "track_bound_values",
+ "bindparam_trackers",
+ "closure_trackers",
+ "build_py_wrappers",
+ )
+ _fns = weakref.WeakKeyDictionary()
+
+ @classmethod
+ def get(cls, fn, lambda_element, lambda_kw, **kw):
+ try:
+ # TODO: validate kw haven't changed?
+ return cls._fns[fn.__code__]
+ except KeyError:
+ pass
+ cls._fns[fn.__code__] = analyzed = AnalyzedCode(
+ fn, lambda_element, lambda_kw, **kw
+ )
+ return analyzed
+
+ def __init__(
+ self,
+ fn,
+ lambda_element,
+ lambda_kw,
+ track_bound_values=True,
+ enable_tracking=True,
+ track_on=None,
+ ):
+ closure = fn.__closure__
+
+ self.track_closure_variables = not track_on
+
+ self.track_bound_values = track_bound_values
+
+ # a list of callables generated from _bound_parameter_getter_*
+ # functions. Each of these uses a PyWrapper object to retrieve
+ # a parameter value
+ self.bindparam_trackers = []
+
+ # a list of callables generated from _cache_key_getter_* functions
+ # these callables work to generate a cache key for the lambda
+ # based on what's inside its closure variables.
+ self.closure_trackers = []
- if fn.__code__ not in _trackers:
- rec = self._initialize_var_trackers(
- role, apply_propagate_attrs, kw
+ self.build_py_wrappers = []
+
+ if enable_tracking:
+ if track_on:
+ self._init_track_on(track_on)
+
+ self._init_globals(fn)
+
+ if closure:
+ self._init_closure(fn)
+
+ self._setup_additional_closure_trackers(fn, lambda_element, lambda_kw)
+
+ def _init_track_on(self, track_on):
+ self.closure_trackers.extend(
+ self._cache_key_getter_track_on(idx, elem)
+ for idx, elem in enumerate(track_on)
+ )
+
+ def _init_globals(self, fn):
+ build_py_wrappers = self.build_py_wrappers
+ bindparam_trackers = self.bindparam_trackers
+ track_bound_values = self.track_bound_values
+
+ for name in fn.__code__.co_names:
+ if name not in fn.__globals__:
+ continue
+
+ _bound_value = self._roll_down_to_literal(fn.__globals__[name])
+
+ if coercions._deep_is_literal(_bound_value):
+ build_py_wrappers.append((name, None))
+ if track_bound_values:
+ bindparam_trackers.append(
+ self._bound_parameter_getter_func_globals(name)
+ )
+
+ def _init_closure(self, fn):
+ build_py_wrappers = self.build_py_wrappers
+ closure = fn.__closure__
+
+ track_bound_values = self.track_bound_values
+ track_closure_variables = self.track_closure_variables
+ bindparam_trackers = self.bindparam_trackers
+ closure_trackers = self.closure_trackers
+
+ for closure_index, (fv, cell) in enumerate(
+ zip(fn.__code__.co_freevars, closure)
+ ):
+ _bound_value = self._roll_down_to_literal(cell.cell_contents)
+
+ if coercions._deep_is_literal(_bound_value):
+ build_py_wrappers.append((fv, closure_index))
+ if track_bound_values:
+ bindparam_trackers.append(
+ self._bound_parameter_getter_func_closure(
+ fv, closure_index
+ )
+ )
+ else:
+ # for normal cell contents, add them to a list that
+ # we can compare later when we get new lambdas. if
+ # any identities have changed, then we will
+ # recalculate the whole lambda and run it again.
+
+ if track_closure_variables:
+ closure_trackers.append(
+ self._cache_key_getter_closure_variable(
+ closure_index, cell.cell_contents
+ )
+ )
+
+ def _setup_additional_closure_trackers(
+ self, fn, lambda_element, lambda_kw
+ ):
+ # an additional step is to actually run the function, then
+ # go through the PyWrapper objects that were set up to catch a bound
+ # parameter. then if they *didn't* make a param, oh they're another
+ # object in the closure we have to track for our cache key. so
+ # create trackers to catch those.
+
+ analyzed_function = AnalyzedFunction(
+ self, lambda_element, None, lambda_kw, fn,
+ )
+
+ closure_trackers = self.closure_trackers
+
+ for pywrapper in analyzed_function.closure_pywrappers:
+ if not pywrapper._sa__has_param:
+ closure_trackers.append(
+ self._cache_key_getter_tracked_literal(pywrapper)
+ )
+
+ @classmethod
+ def _roll_down_to_literal(cls, element):
+ is_clause_element = hasattr(element, "__clause_element__")
+
+ if is_clause_element:
+ while not isinstance(
+ element, (elements.ClauseElement, schema.SchemaItem)
+ ):
+ try:
+ element = element.__clause_element__()
+ except AttributeError:
+ break
+
+ if not is_clause_element:
+ insp = inspection.inspect(element, raiseerr=False)
+ if insp is not None:
+ try:
+ return insp.__clause_element__()
+ except AttributeError:
+ return insp
+
+ # TODO: should we coerce consts None/True/False here?
+ return element
+ else:
+ return element
+
+ def _bound_parameter_getter_func_globals(self, name):
+ """Return a getter that will extend a list of bound parameters
+ with new entries from the ``__globals__`` collection of a particular
+ lambda.
+
+ """
+
+ def extract_parameter_value(
+ current_fn, tracker_instrumented_fn, result
+ ):
+ wrapper = tracker_instrumented_fn.__globals__[name]
+ object.__getattribute__(wrapper, "_extract_bound_parameters")(
+ current_fn.__globals__[name], result
+ )
+
+ return extract_parameter_value
+
+ def _bound_parameter_getter_func_closure(self, name, closure_index):
+ """Return a getter that will extend a list of bound parameters
+ with new entries from the ``__closure__`` collection of a particular
+ lambda.
+
+ """
+
+ def extract_parameter_value(
+ current_fn, tracker_instrumented_fn, result
+ ):
+ wrapper = tracker_instrumented_fn.__closure__[
+ closure_index
+ ].cell_contents
+ object.__getattribute__(wrapper, "_extract_bound_parameters")(
+ current_fn.__closure__[closure_index].cell_contents, result
+ )
+
+ return extract_parameter_value
+
+ def _cache_key_getter_track_on(self, idx, elem):
+ """Return a getter that will extend a cache key with new entries
+ from the "track_on" parameter passed to a :class:`.LambdaElement`.
+
+ """
+ if isinstance(elem, traversals.HasCacheKey):
+
+ def get(closure, kw, anon_map, bindparams):
+ return kw["track_on"][idx]._gen_cache_key(anon_map, bindparams)
+
+ else:
+
+ def get(closure, kw, anon_map, bindparams):
+ return kw["track_on"][idx]
+
+ return get
+
+ def _cache_key_getter_closure_variable(self, idx, cell_contents):
+ """Return a getter that will extend a cache key with new entries
+ from the ``__closure__`` collection of a particular lambda.
+
+ """
+
+ if isinstance(cell_contents, traversals.HasCacheKey):
+
+ def get(closure, kw, anon_map, bindparams):
+ return closure[idx].cell_contents._gen_cache_key(
+ anon_map, bindparams
+ )
+
+ elif isinstance(cell_contents, types.FunctionType):
+
+ def get(closure, kw, anon_map, bindparams):
+ return closure[idx].cell_contents.__code__
+
+ elif cell_contents.__hash__ is None:
+ # this covers dict, etc.
+ def get(closure, kw, anon_map, bindparams):
+ return ()
+
+ else:
+
+ def get(closure, kw, anon_map, bindparams):
+ return closure[idx].cell_contents
+
+ return get
+
+ def _cache_key_getter_tracked_literal(self, pytracker):
+ """Return a getter that will extend a cache key with new entries
+ from the ``__closure__`` collection of a particular lambda.
+
+ this getter differs from _cache_key_getter_closure_variable
+ in that these are detected after the function is run, and PyWrapper
+ objects have recorded that a particular literal value is in fact
+ not being interpreted as a bound parameter.
+
+ """
+
+ elem = pytracker._sa__to_evaluate
+ closure_index = pytracker._sa__closure_index
+
+ if isinstance(elem, set):
+ raise exc.ArgumentError(
+ "Can't create a cache key for lambda closure variable "
+ '"%s" because it\'s a set. try using a list'
+ % pytracker._sa__name
)
+
+ elif isinstance(elem, list):
+
+ def get(closure, kw, anon_map, bindparams):
+ return tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in closure[closure_index].cell_contents
+ )
+
+ elif elem.__hash__ is None:
+ # this covers dict, etc.
+ def get(closure, kw, anon_map, bindparams):
+ return ()
+
else:
- rec = _trackers[self.fn.__code__]
+ def get(closure, kw, anon_map, bindparams):
+ return closure[closure_index].cell_contents
+
+ return get
+
+
+class AnalyzedFunction(object):
+ __slots__ = (
+ "analyzed_code",
+ "fn",
+ "closure_pywrappers",
+ "tracker_instrumented_fn",
+ "expr",
+ "bindparam_trackers",
+ "expected_expr",
+ "is_sequence",
+ "propagate_attrs",
+ "closure_bindparams",
+ )
+
+ def __init__(
+ self, analyzed_code, lambda_element, apply_propagate_attrs, kw, fn,
+ ):
+ self.analyzed_code = analyzed_code
+ self.fn = fn
+
+ self.bindparam_trackers = analyzed_code.bindparam_trackers
+
+ self._instrument_and_run_function(lambda_element)
+
+ self._coerce_expression(lambda_element, apply_propagate_attrs, kw)
+
+ def _instrument_and_run_function(self, lambda_element):
+ analyzed_code = self.analyzed_code
+
+ fn = self.fn
+ self.closure_pywrappers = closure_pywrappers = []
+
+ build_py_wrappers = analyzed_code.build_py_wrappers
+
+ if not build_py_wrappers:
+ self.tracker_instrumented_fn = tracker_instrumented_fn = fn
+ self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
+ else:
+ track_closure_variables = analyzed_code.track_closure_variables
closure = fn.__closure__
- # check if objects referred to by the lambda have changed and
- # re-scan the lambda if so. see comments for this same section in
- # LambdaElement.
- for idx, obj in rec[_STALE_CHECK]:
- if closure[idx].cell_contents is not obj:
- rec = self._initialize_var_trackers(
- role, apply_propagate_attrs, kw
+ # will form the __closure__ of the function when we rebuild it
+ if closure:
+ new_closure = {
+ fv: cell.cell_contents
+ for fv, cell in zip(fn.__code__.co_freevars, closure)
+ }
+ else:
+ new_closure = {}
+
+ # will form the __globals__ of the function when we rebuild it
+ new_globals = fn.__globals__.copy()
+
+ for name, closure_index in build_py_wrappers:
+ if closure_index is not None:
+ value = closure[closure_index].cell_contents
+ new_closure[name] = bind = PyWrapper(
+ name, value, closure_index=closure_index
)
- break
+ if track_closure_variables:
+ closure_pywrappers.append(bind)
+ else:
+ value = fn.__globals__[name]
+ new_globals[name] = bind = PyWrapper(name, value)
+
+ # rewrite the original fn. things that look like they will
+ # become bound parameters are wrapped in a PyWrapper.
+ self.tracker_instrumented_fn = (
+ tracker_instrumented_fn
+ ) = self._rewrite_code_obj(
+ fn,
+ [new_closure[name] for name in fn.__code__.co_freevars],
+ new_globals,
+ )
- self._rec = rec
+ # now invoke the function. This will give us a new SQL
+ # expression, but all the places that there would be a bound
+ # parameter, the PyWrapper in its place will give us a bind
+ # with a predictable name we can match up later.
- self._propagate_attrs = parent_lambda._propagate_attrs
+ # additionally, each PyWrapper will log that it did in fact
+ # create a parameter, otherwise, it's some kind of Python
+ # object in the closure and we want to track that, to make
+ # sure it doesn't change to somehting else, or if it does,
+ # that we create a different tracked function with that
+ # variable.
+ self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
- self._resolved_bindparams = bindparams = []
- rec = self._rec
- while True:
- if rec[_TRACKERS]:
- for tracker in rec[_TRACKERS]:
- tracker(self.fn, bindparams)
- if self.parent_lambda is not None:
- self = self.parent_lambda
- rec = self._rec
+ def _coerce_expression(self, lambda_element, apply_propagate_attrs, kw):
+ """Run the tracker-generated expression through coercion rules.
+
+ After the user-defined lambda has been invoked to produce a statement
+ for re-use, run it through coercion rules to both check that it's the
+ correct type of object and also to coerce it to its useful form.
+
+ """
+
+ parent_lambda = lambda_element.parent_lambda
+ expr = self.expr
+
+ if parent_lambda is None:
+ if isinstance(expr, collections_abc.Sequence):
+ self.expected_expr = [
+ coercions.expect(
+ lambda_element.role,
+ sub_expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ **kw
+ )
+ for sub_expr in expr
+ ]
+ self.is_sequence = True
else:
- break
+ self.expected_expr = coercions.expect(
+ lambda_element.role,
+ expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ **kw
+ )
+ self.is_sequence = False
+ else:
+ self.expected_expr = expr
+ self.is_sequence = False
- def _invoke_user_fn(self, fn, *arg):
- return fn(self.parent_lambda._rec[_EXPR])
+ if apply_propagate_attrs is not None:
+ self.propagate_attrs = apply_propagate_attrs._propagate_attrs
+ else:
+ self.propagate_attrs = util.EMPTY_DICT
- def _gen_cache_key(self, anon_map, bindparams):
- if self._resolved_bindparams:
- bindparams.extend(self._resolved_bindparams)
+ def _rewrite_code_obj(self, f, cell_values, globals_):
+ """Return a copy of f, with a new closure and new globals
- cache_key = (self.fn.__code__, self.__class__)
+ yes it works in pypy :P
- parent = self.parent_lambda
- while parent is not None:
- cache_key = (parent.fn.__code__,) + cache_key
- parent = parent.parent_lambda
+ """
- return cache_key
+ argrange = range(len(cell_values))
+
+ code = "def make_cells():\n"
+ if cell_values:
+ code += " (%s) = (%s)\n" % (
+ ", ".join("i%d" % i for i in argrange),
+ ", ".join("o%d" % i for i in argrange),
+ )
+ code += " def closure():\n"
+ code += " return %s\n" % ", ".join("i%d" % i for i in argrange)
+ code += " return closure.__closure__"
+ vars_ = {"o%d" % i: cell_values[i] for i in argrange}
+ exec(code, vars_, vars_)
+ closure = vars_["make_cells"]()
+
+ func = type(f)(
+ f.__code__, globals_, f.__name__, f.__defaults__, closure
+ )
+ if sys.version_info >= (3,):
+ func.__annotations__ = f.__annotations__
+ func.__kwdefaults__ = f.__kwdefaults__
+ func.__doc__ = f.__doc__
+ func.__module__ = f.__module__
+
+ return func
class PyWrapper(ColumnOperators):
- def __init__(self, name, to_evaluate, getter=None):
+ """A wrapper object that is injected into the ``__globals__`` and
+ ``__closure__`` of a Python function.
+
+ When the function is instrumented with :class:`.PyWrapper` objects, it is
+ then invoked just once in order to set up the wrappers. We look through
+ all the :class:`.PyWrapper` objects we made to find the ones that generated
+ a :class:`.BindParameter` object, e.g. the expression system interpreted
+ something as a literal. Those positions in the globals/closure are then
+ ones that we will look at, each time a new lambda comes in that refers to
+ the same ``__code__`` object. In this way, we keep a single version of
+ the SQL expression that this lambda produced, without calling upon the
+ Python function that created it more than once, unless its other closure
+ variables have changed. The expression is then transformed to have the
+ new bound values embedded into it.
+
+ """
+
+ def __init__(self, name, to_evaluate, closure_index=None, getter=None):
self._name = name
self._to_evaluate = to_evaluate
self._param = None
+ self._has_param = False
self._bind_paths = {}
self._getter = getter
+ self._closure_index = closure_index
def __call__(self, *arg, **kw):
elem = object.__getattribute__(self, "_to_evaluate")
value = elem(*arg, **kw)
- if coercions._is_literal(value) and not isinstance(
+ if coercions._deep_is_literal(value) and not isinstance(
# TODO: coverage where an ORM option or similar is here
value,
traversals.HasCacheKey,
@@ -481,8 +1005,8 @@ class PyWrapper(ColumnOperators):
if param is None:
name = object.__getattribute__(self, "_name")
self._param = param = elements.BindParameter(name, unique=True)
+ self._has_param = True
param.type = type_api._resolve_value_to_type(to_evaluate)
-
return param._with_value(to_evaluate, maintain_key=True)
def __getattribute__(self, key):
@@ -497,7 +1021,15 @@ class PyWrapper(ColumnOperators):
else:
return self._sa__add_getter(key, operator.attrgetter)
+ def __iter__(self):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ return iter(elem)
+
def __getitem__(self, key):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ if not hasattr(elem, "__getitem__"):
+ raise AttributeError("__getitem__")
+
if isinstance(key, PyWrapper):
# TODO: coverage
raise exc.InvalidRequestError(
@@ -518,90 +1050,14 @@ class PyWrapper(ColumnOperators):
elem = object.__getattribute__(self, "_to_evaluate")
value = getter(elem)
- if coercions._is_literal(value):
- wrapper = PyWrapper(key, value, getter)
+ if coercions._deep_is_literal(value):
+ wrapper = PyWrapper(key, value, getter=getter)
bind_paths[bind_path_key] = wrapper
return wrapper
else:
return value
-def _roll_down_to_literal(element):
- is_clause_element = hasattr(element, "__clause_element__")
-
- if is_clause_element:
- while not isinstance(
- element, (elements.ClauseElement, schema.SchemaItem)
- ):
- try:
- element = element.__clause_element__()
- except AttributeError:
- break
-
- if not is_clause_element:
- insp = inspection.inspect(element, raiseerr=False)
- if insp is not None:
- try:
- return insp.__clause_element__()
- except AttributeError:
- return insp
-
- # TODO: should we coerce consts None/True/False here?
- return element
- else:
- return element
-
-
-def _globals_tracker(name, wrapper):
- def extract_parameter_value(current_fn, result):
- object.__getattribute__(wrapper, "_extract_bound_parameters")(
- current_fn.__globals__[name], result
- )
-
- return extract_parameter_value
-
-
-def _closure_tracker(name, wrapper, closure_index):
- def extract_parameter_value(current_fn, result):
- object.__getattribute__(wrapper, "_extract_bound_parameters")(
- current_fn.__closure__[closure_index].cell_contents, result
- )
-
- return extract_parameter_value
-
-
-def _rewrite_code_obj(f, cell_values, globals_):
- """Return a copy of f, with a new closure and new globals
-
- yes it works in pypy :P
-
- """
-
- argrange = range(len(cell_values))
-
- code = "def make_cells():\n"
- if cell_values:
- code += " (%s) = (%s)\n" % (
- ", ".join("i%d" % i for i in argrange),
- ", ".join("o%d" % i for i in argrange),
- )
- code += " def closure():\n"
- code += " return %s\n" % ", ".join("i%d" % i for i in argrange)
- code += " return closure.__closure__"
- vars_ = {"o%d" % i: cell_values[i] for i in argrange}
- exec(code, vars_, vars_)
- closure = vars_["make_cells"]()
-
- func = type(f)(f.__code__, globals_, f.__name__, f.__defaults__, closure)
- if sys.version_info >= (3,):
- func.__annotations__ = f.__annotations__
- func.__kwdefaults__ = f.__kwdefaults__
- func.__doc__ = f.__doc__
- func.__module__ = f.__module__
-
- return func
-
-
@inspection._inspects(LambdaElement)
def insp(lmb):
return inspection.inspect(lmb._resolved)