diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-08-05 16:42:26 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-08-05 16:42:26 -0400 |
commit | cc57ea495f6460dd56daa6de57e40047ed999369 (patch) | |
tree | 837f5a84363c387d7f8fdeabc06928cd078028e1 /lib/sqlalchemy/sql/lambdas.py | |
parent | 2a946254023135eddd222974cf300ffaa5583f02 (diff) | |
download | sqlalchemy-cc57ea495f6460dd56daa6de57e40047ed999369.tar.gz |
Robustness for lambdas, lambda statements
in order to accommodate relationship loaders
with lambda caching, a lot more is needed. This is
a full refactor of the lambda system such that it
now has two levels of caching; the first level caches what
can be known from the __code__ element, then the next level
of caching is against the lambda itself and the contents
of __closure__. This allows for the elements inside
the lambdas, like columns and entities, to change and
then be part of the cache key. Lazy/selectinloads' use of
baked queries had to add distinct cache key elements,
which was attempted here but overall things needed to be
more robust than that.
This commit is broken out from the very long and sprawling
commit at Id6b5c03b1ce9ddb7b280f66792212a0ef0a1c541 .
Change-Id: I29a513c98917b1d503abfdd61e6b6e8800851aa8
Diffstat (limited to 'lib/sqlalchemy/sql/lambdas.py')
-rw-r--r-- | lib/sqlalchemy/sql/lambdas.py | 980 |
1 files changed, 718 insertions, 262 deletions
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 792411189..327003902 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -8,6 +8,7 @@ import itertools import operator import sys +import types import weakref from . import coercions @@ -17,29 +18,23 @@ from . import schema from . import traversals from . import type_api from . import visitors +from .base import _clone from .operators import ColumnOperators from .. import exc from .. import inspection from .. import util from ..util import collections_abc -_trackers = weakref.WeakKeyDictionary() +_closure_per_cache_key = util.LRUCache(1000) -_TRACKERS = 0 -_STALE_CHECK = 1 -_REAL_FN = 2 -_EXPR = 3 -_IS_SEQUENCE = 4 -_PROPAGATE_ATTRS = 5 - - -def lambda_stmt(lmb): +def lambda_stmt(lmb, **opts): """Produce a SQL statement that is cached as a lambda. - This SQL statement will only be constructed if element has not been - compiled yet. The approach is used to save on Python function overhead - when constructing statements that will be cached. + The Python code object within the lambda is scanned for both Python + literals that will become bound parameters as well as closure variables + that refer to Core or ORM constructs that may vary. The lambda itself + will be invoked only once per particular set of constructs detected. E.g.:: @@ -60,7 +55,8 @@ def lambda_stmt(lmb): """ - return coercions.expect(roles.CoerceTextStatementRole, lmb) + + return StatementLambdaElement(lmb, roles.CoerceTextStatementRole, **opts) class LambdaElement(elements.ClauseElement): @@ -87,64 +83,108 @@ class LambdaElement(elements.ClauseElement): _is_lambda_element = True - _resolved_bindparams = () - _traverse_internals = [ ("_resolved", visitors.InternalTraversal.dp_clauseelement) ] + _transforms = () + + parent_lambda = None + def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self.fn.__code__) def __init__(self, fn, role, apply_propagate_attrs=None, **kw): self.fn = fn self.role = role - self.parent_lambda = None + self.tracker_key = (fn.__code__,) if apply_propagate_attrs is None and ( role is roles.CoerceTextStatementRole ): apply_propagate_attrs = self - if fn.__code__ not in _trackers: - rec = self._initialize_var_trackers( - role, apply_propagate_attrs, kw + rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, kw) + + if apply_propagate_attrs is not None: + propagate_attrs = rec.propagate_attrs + if propagate_attrs: + apply_propagate_attrs._propagate_attrs = propagate_attrs + + def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, kw): + lambda_cache = kw.get("lambda_cache", _closure_per_cache_key) + + tracker_key = self.tracker_key + + fn = self.fn + closure = fn.__closure__ + + tracker = AnalyzedCode.get( + fn, + self, + kw, + track_bound_values=kw.get("track_bound_values", True), + enable_tracking=kw.get("enable_tracking", True), + track_on=kw.get("track_on", None), + ) + + self._resolved_bindparams = bindparams = [] + + anon_map = traversals.anon_map() + cache_key = tuple( + [ + getter(closure, kw, anon_map, bindparams) + for getter in tracker.closure_trackers + ] + ) + if self.parent_lambda is not None: + cache_key = self.parent_lambda.closure_cache_key + cache_key + + self.closure_cache_key = cache_key + + try: + rec = lambda_cache[tracker_key + cache_key] + except KeyError: + rec = None + + if rec is None: + rec = AnalyzedFunction( + tracker, self, apply_propagate_attrs, kw, fn ) + rec.closure_bindparams = bindparams + lambda_cache[tracker_key + cache_key] = rec else: - rec = _trackers[self.fn.__code__] - closure = fn.__closure__ + bindparams[:] = [ + orig_bind._with_value(new_bind.value, maintain_key=True) + for orig_bind, new_bind in zip( + rec.closure_bindparams, bindparams + ) + ] + + if self.parent_lambda is not None: + bindparams[:0] = self.parent_lambda._resolved_bindparams - # check if the objects fixed inside the lambda that we've cached - # have been changed. This can apply to things like mappers that - # were recreated in test suites. if so, re-initialize. - # - # this is a small performance hit on every use for a not very - # common situation, however it's very hard to debug if the - # condition does occur. - for idx, obj in rec[_STALE_CHECK]: - if closure[idx].cell_contents is not obj: - rec = self._initialize_var_trackers( - role, apply_propagate_attrs, kw - ) - break self._rec = rec - if apply_propagate_attrs is not None: - propagate_attrs = rec[_PROPAGATE_ATTRS] - if propagate_attrs: - apply_propagate_attrs._propagate_attrs = propagate_attrs + lambda_element = self + while lambda_element is not None: + rec = lambda_element._rec + if rec.bindparam_trackers: + tracker_instrumented_fn = rec.tracker_instrumented_fn + for tracker in rec.bindparam_trackers: + tracker( + lambda_element.fn, tracker_instrumented_fn, bindparams + ) + lambda_element = lambda_element.parent_lambda - if rec[_TRACKERS]: - self._resolved_bindparams = bindparams = [] - for tracker in rec[_TRACKERS]: - tracker(self.fn, bindparams) + return rec def __getattr__(self, key): - return getattr(self._rec[_EXPR], key) + return getattr(self._rec.expected_expr, key) @property def _is_sequence(self): - return self._rec[_IS_SEQUENCE] + return self._rec.is_sequence @property def _select_iterable(self): @@ -169,8 +209,7 @@ class LambdaElement(elements.ClauseElement): def _param_dict(self): return {b.key: b.value for b in self._resolved_bindparams} - @util.memoized_property - def _resolved(self): + def _setup_binds_for_tracked_expr(self, expr): bindparam_lookup = {b.key: b for b in self._resolved_bindparams} def replace(thing): @@ -179,17 +218,11 @@ class LambdaElement(elements.ClauseElement): and thing.key in bindparam_lookup ): bind = bindparam_lookup[thing.key] - # TODO: consider - # if we should clone the bindparam here, re-cache the new - # version, etc. also we make an assumption about "expanding" - # in this case. if thing.expanding: bind.expanding = True return bind - expr = self._rec[_EXPR] - - if self._rec[_IS_SEQUENCE]: + if self._rec.is_sequence: expr = [ visitors.replacement_traverse(sub_expr, {}, replace) for sub_expr in expr @@ -199,9 +232,39 @@ class LambdaElement(elements.ClauseElement): return expr + def _copy_internals( + self, clone=_clone, deferred_copy_internals=None, **kw + ): + # TODO: this needs A LOT of tests + self._resolved = clone( + self._resolved, + deferred_copy_internals=deferred_copy_internals, + **kw + ) + + @util.memoized_property + def _resolved(self): + expr = self._rec.expected_expr + + if self._resolved_bindparams: + expr = self._setup_binds_for_tracked_expr(expr) + + return expr + def _gen_cache_key(self, anon_map, bindparams): - cache_key = (self.fn.__code__, self.__class__) + cache_key = ( + self.fn.__code__, + self.__class__, + ) + self.closure_cache_key + + parent = self.parent_lambda + while parent is not None: + cache_key = ( + (parent.fn.__code__,) + parent.closure_cache_key + cache_key + ) + + parent = parent.parent_lambda if self._resolved_bindparams: bindparams.extend(self._resolved_bindparams) @@ -211,101 +274,51 @@ class LambdaElement(elements.ClauseElement): def _invoke_user_fn(self, fn, *arg): return fn() - def _initialize_var_trackers(self, role, apply_propagate_attrs, coerce_kw): - fn = self.fn - # track objects referenced inside of lambdas, create bindparams - # ahead of time for literal values. If bindparams are produced, - # then rewrite the function globals and closure as necessary so that - # it refers to the bindparams, then invoke the function - new_closure = {} - new_globals = fn.__globals__.copy() - tracker_collection = [] - check_closure_for_stale = [] +class DeferredLambdaElement(LambdaElement): + """A LambdaElement where the lambda accepts arguments and is + invoked within the compile phase with special context. - for name in fn.__code__.co_names: - if name not in new_globals: - continue + This lambda doesn't normally produce its real SQL expression outside of the + compile phase. It is passed a fixed set of initial arguments + so that it can generate a sample expression. - bound_value = _roll_down_to_literal(new_globals[name]) + """ - if coercions._is_literal(bound_value): - new_globals[name] = bind = PyWrapper(name, bound_value) - tracker_collection.append(_globals_tracker(name, bind)) + def __init__(self, fn, role, lambda_args=(), **kw): + self.lambda_args = lambda_args + self.coerce_kw = kw + super(DeferredLambdaElement, self).__init__(fn, role, **kw) - if fn.__closure__: - for closure_index, (fv, cell) in enumerate( - zip(fn.__code__.co_freevars, fn.__closure__) - ): + def _invoke_user_fn(self, fn, *arg): + return fn(*self.lambda_args) - bound_value = _roll_down_to_literal(cell.cell_contents) + def _resolve_with_args(self, *lambda_args): + tracker_fn = self._rec.tracker_instrumented_fn + expr = tracker_fn(*lambda_args) - if coercions._is_literal(bound_value): - new_closure[fv] = bind = PyWrapper(fv, bound_value) - tracker_collection.append( - _closure_tracker(fv, bind, closure_index) - ) - else: - new_closure[fv] = cell.cell_contents - # for normal cell contents, add them to a list that - # we can compare later when we get new lambdas. if - # any identities have changed, then we will recalculate - # the whole lambda and run it again. - check_closure_for_stale.append( - (closure_index, cell.cell_contents) - ) + expr = coercions.expect(self.role, expr, **self.coerce_kw) - if tracker_collection: - new_fn = _rewrite_code_obj( - fn, - [new_closure[name] for name in fn.__code__.co_freevars], - new_globals, - ) - expr = self._invoke_user_fn(new_fn) + if self._resolved_bindparams: + expr = self._setup_binds_for_tracked_expr(expr) - else: - new_fn = fn - expr = self._invoke_user_fn(new_fn) - tracker_collection = [] + # TODO: TEST TEST TEST, this is very out there + for deferred_copy_internals in self._transforms: + expr = deferred_copy_internals(expr) - if self.parent_lambda is None: - if isinstance(expr, collections_abc.Sequence): - expected_expr = [ - coercions.expect( - role, - sub_expr, - apply_propagate_attrs=apply_propagate_attrs, - **coerce_kw - ) - for sub_expr in expr - ] - is_sequence = True - else: - expected_expr = coercions.expect( - role, - expr, - apply_propagate_attrs=apply_propagate_attrs, - **coerce_kw - ) - is_sequence = False - else: - expected_expr = expr - is_sequence = False + return expr - if apply_propagate_attrs is not None: - propagate_attrs = apply_propagate_attrs._propagate_attrs - else: - propagate_attrs = util.immutabledict() - - rec = _trackers[self.fn.__code__] = ( - tracker_collection, - check_closure_for_stale, - new_fn, - expected_expr, - is_sequence, - propagate_attrs, + def _copy_internals( + self, clone=_clone, deferred_copy_internals=None, **kw + ): + super(DeferredLambdaElement, self)._copy_internals( + clone=clone, deferred_copy_internals=deferred_copy_internals, **kw ) - return rec + + # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know + # our expression yet. so hold onto the replacement + if deferred_copy_internals: + self._transforms += (deferred_copy_internals,) class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): @@ -334,13 +347,38 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): """ + def __init__(self, fn, parent_lambda, **kw): + self._default_kw = default_kw = {} + global_track_bound_values = kw.pop("global_track_bound_values", None) + if global_track_bound_values is not None: + default_kw["track_bound_values"] = global_track_bound_values + kw["track_bound_values"] = global_track_bound_values + + if "lambda_cache" in kw: + default_kw["lambda_cache"] = kw["lambda_cache"] + + super(StatementLambdaElement, self).__init__(fn, parent_lambda, **kw) + def __add__(self, other): - return LinkedLambdaElement(other, parent_lambda=self) + return LinkedLambdaElement( + other, parent_lambda=self, **self._default_kw + ) + + def add_criteria(self, other, **kw): + if self._default_kw: + if kw: + default_kw = self._default_kw.copy() + default_kw.update(kw) + kw = default_kw + else: + kw = self._default_kw + + return LinkedLambdaElement(other, parent_lambda=self, **kw) def _execute_on_connection( self, connection, multiparams, params, execution_options ): - if self._rec[_EXPR].supports_execution: + if self._rec.expected_expr.supports_execution: return connection._execute_clauseelement( self, multiparams, params, execution_options ) @@ -349,93 +387,579 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): @property def _with_options(self): - return self._rec[_EXPR]._with_options + return self._rec.expected_expr._with_options @property def _effective_plugin_target(self): - return self._rec[_EXPR]._effective_plugin_target + return self._rec.expected_expr._effective_plugin_target @property def _is_future(self): - return self._rec[_EXPR]._is_future + return self._rec.expected_expr._is_future @property def _execution_options(self): - return self._rec[_EXPR]._execution_options + return self._rec.expected_expr._execution_options + + def spoil(self): + """Return a new :class:`.StatementLambdaElement` that will run + all lambdas unconditionally each time. + + """ + return NullLambdaStatement(self.fn()) + + +class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement): + """Provides the :class:`.StatementLambdaElement` API but does not + cache or analyze lambdas. + + the lambdas are instead invoked immediately. + + The intended use is to isolate issues that may arise when using + lambda statements. + + """ + + __visit_name__ = "lambda_element" + + _is_lambda_element = True + + _traverse_internals = [ + ("_resolved", visitors.InternalTraversal.dp_clauseelement) + ] + + def __init__(self, statement): + self._resolved = statement + self._propagate_attrs = statement._propagate_attrs + + def __getattr__(self, key): + return getattr(self._resolved, key) + + def __add__(self, other): + statement = other(self._resolved) + + return NullLambdaStatement(statement) + + def add_criteria(self, other, **kw): + statement = other(self._resolved) + + return NullLambdaStatement(statement) + + def _execute_on_connection( + self, connection, multiparams, params, execution_options + ): + if self._resolved.supports_execution: + return connection._execute_clauseelement( + self, multiparams, params, execution_options + ) + else: + raise exc.ObjectNotExecutableError(self) class LinkedLambdaElement(StatementLambdaElement): + """Represent subsequent links of a :class:`.StatementLambdaElement`.""" + + role = None + def __init__(self, fn, parent_lambda, **kw): + self._default_kw = parent_lambda._default_kw + self.fn = fn self.parent_lambda = parent_lambda - role = None - apply_propagate_attrs = self + self.tracker_key = parent_lambda.tracker_key + (fn.__code__,) + self._retrieve_tracker_rec(fn, self, kw) + self._propagate_attrs = parent_lambda._propagate_attrs + + def _invoke_user_fn(self, fn, *arg): + return fn(self.parent_lambda._resolved) + + +class AnalyzedCode(object): + __slots__ = ( + "track_closure_variables", + "track_bound_values", + "bindparam_trackers", + "closure_trackers", + "build_py_wrappers", + ) + _fns = weakref.WeakKeyDictionary() + + @classmethod + def get(cls, fn, lambda_element, lambda_kw, **kw): + try: + # TODO: validate kw haven't changed? + return cls._fns[fn.__code__] + except KeyError: + pass + cls._fns[fn.__code__] = analyzed = AnalyzedCode( + fn, lambda_element, lambda_kw, **kw + ) + return analyzed + + def __init__( + self, + fn, + lambda_element, + lambda_kw, + track_bound_values=True, + enable_tracking=True, + track_on=None, + ): + closure = fn.__closure__ + + self.track_closure_variables = not track_on + + self.track_bound_values = track_bound_values + + # a list of callables generated from _bound_parameter_getter_* + # functions. Each of these uses a PyWrapper object to retrieve + # a parameter value + self.bindparam_trackers = [] + + # a list of callables generated from _cache_key_getter_* functions + # these callables work to generate a cache key for the lambda + # based on what's inside its closure variables. + self.closure_trackers = [] - if fn.__code__ not in _trackers: - rec = self._initialize_var_trackers( - role, apply_propagate_attrs, kw + self.build_py_wrappers = [] + + if enable_tracking: + if track_on: + self._init_track_on(track_on) + + self._init_globals(fn) + + if closure: + self._init_closure(fn) + + self._setup_additional_closure_trackers(fn, lambda_element, lambda_kw) + + def _init_track_on(self, track_on): + self.closure_trackers.extend( + self._cache_key_getter_track_on(idx, elem) + for idx, elem in enumerate(track_on) + ) + + def _init_globals(self, fn): + build_py_wrappers = self.build_py_wrappers + bindparam_trackers = self.bindparam_trackers + track_bound_values = self.track_bound_values + + for name in fn.__code__.co_names: + if name not in fn.__globals__: + continue + + _bound_value = self._roll_down_to_literal(fn.__globals__[name]) + + if coercions._deep_is_literal(_bound_value): + build_py_wrappers.append((name, None)) + if track_bound_values: + bindparam_trackers.append( + self._bound_parameter_getter_func_globals(name) + ) + + def _init_closure(self, fn): + build_py_wrappers = self.build_py_wrappers + closure = fn.__closure__ + + track_bound_values = self.track_bound_values + track_closure_variables = self.track_closure_variables + bindparam_trackers = self.bindparam_trackers + closure_trackers = self.closure_trackers + + for closure_index, (fv, cell) in enumerate( + zip(fn.__code__.co_freevars, closure) + ): + _bound_value = self._roll_down_to_literal(cell.cell_contents) + + if coercions._deep_is_literal(_bound_value): + build_py_wrappers.append((fv, closure_index)) + if track_bound_values: + bindparam_trackers.append( + self._bound_parameter_getter_func_closure( + fv, closure_index + ) + ) + else: + # for normal cell contents, add them to a list that + # we can compare later when we get new lambdas. if + # any identities have changed, then we will + # recalculate the whole lambda and run it again. + + if track_closure_variables: + closure_trackers.append( + self._cache_key_getter_closure_variable( + closure_index, cell.cell_contents + ) + ) + + def _setup_additional_closure_trackers( + self, fn, lambda_element, lambda_kw + ): + # an additional step is to actually run the function, then + # go through the PyWrapper objects that were set up to catch a bound + # parameter. then if they *didn't* make a param, oh they're another + # object in the closure we have to track for our cache key. so + # create trackers to catch those. + + analyzed_function = AnalyzedFunction( + self, lambda_element, None, lambda_kw, fn, + ) + + closure_trackers = self.closure_trackers + + for pywrapper in analyzed_function.closure_pywrappers: + if not pywrapper._sa__has_param: + closure_trackers.append( + self._cache_key_getter_tracked_literal(pywrapper) + ) + + @classmethod + def _roll_down_to_literal(cls, element): + is_clause_element = hasattr(element, "__clause_element__") + + if is_clause_element: + while not isinstance( + element, (elements.ClauseElement, schema.SchemaItem) + ): + try: + element = element.__clause_element__() + except AttributeError: + break + + if not is_clause_element: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + try: + return insp.__clause_element__() + except AttributeError: + return insp + + # TODO: should we coerce consts None/True/False here? + return element + else: + return element + + def _bound_parameter_getter_func_globals(self, name): + """Return a getter that will extend a list of bound parameters + with new entries from the ``__globals__`` collection of a particular + lambda. + + """ + + def extract_parameter_value( + current_fn, tracker_instrumented_fn, result + ): + wrapper = tracker_instrumented_fn.__globals__[name] + object.__getattribute__(wrapper, "_extract_bound_parameters")( + current_fn.__globals__[name], result + ) + + return extract_parameter_value + + def _bound_parameter_getter_func_closure(self, name, closure_index): + """Return a getter that will extend a list of bound parameters + with new entries from the ``__closure__`` collection of a particular + lambda. + + """ + + def extract_parameter_value( + current_fn, tracker_instrumented_fn, result + ): + wrapper = tracker_instrumented_fn.__closure__[ + closure_index + ].cell_contents + object.__getattribute__(wrapper, "_extract_bound_parameters")( + current_fn.__closure__[closure_index].cell_contents, result + ) + + return extract_parameter_value + + def _cache_key_getter_track_on(self, idx, elem): + """Return a getter that will extend a cache key with new entries + from the "track_on" parameter passed to a :class:`.LambdaElement`. + + """ + if isinstance(elem, traversals.HasCacheKey): + + def get(closure, kw, anon_map, bindparams): + return kw["track_on"][idx]._gen_cache_key(anon_map, bindparams) + + else: + + def get(closure, kw, anon_map, bindparams): + return kw["track_on"][idx] + + return get + + def _cache_key_getter_closure_variable(self, idx, cell_contents): + """Return a getter that will extend a cache key with new entries + from the ``__closure__`` collection of a particular lambda. + + """ + + if isinstance(cell_contents, traversals.HasCacheKey): + + def get(closure, kw, anon_map, bindparams): + return closure[idx].cell_contents._gen_cache_key( + anon_map, bindparams + ) + + elif isinstance(cell_contents, types.FunctionType): + + def get(closure, kw, anon_map, bindparams): + return closure[idx].cell_contents.__code__ + + elif cell_contents.__hash__ is None: + # this covers dict, etc. + def get(closure, kw, anon_map, bindparams): + return () + + else: + + def get(closure, kw, anon_map, bindparams): + return closure[idx].cell_contents + + return get + + def _cache_key_getter_tracked_literal(self, pytracker): + """Return a getter that will extend a cache key with new entries + from the ``__closure__`` collection of a particular lambda. + + this getter differs from _cache_key_getter_closure_variable + in that these are detected after the function is run, and PyWrapper + objects have recorded that a particular literal value is in fact + not being interpreted as a bound parameter. + + """ + + elem = pytracker._sa__to_evaluate + closure_index = pytracker._sa__closure_index + + if isinstance(elem, set): + raise exc.ArgumentError( + "Can't create a cache key for lambda closure variable " + '"%s" because it\'s a set. try using a list' + % pytracker._sa__name ) + + elif isinstance(elem, list): + + def get(closure, kw, anon_map, bindparams): + return tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in closure[closure_index].cell_contents + ) + + elif elem.__hash__ is None: + # this covers dict, etc. + def get(closure, kw, anon_map, bindparams): + return () + else: - rec = _trackers[self.fn.__code__] + def get(closure, kw, anon_map, bindparams): + return closure[closure_index].cell_contents + + return get + + +class AnalyzedFunction(object): + __slots__ = ( + "analyzed_code", + "fn", + "closure_pywrappers", + "tracker_instrumented_fn", + "expr", + "bindparam_trackers", + "expected_expr", + "is_sequence", + "propagate_attrs", + "closure_bindparams", + ) + + def __init__( + self, analyzed_code, lambda_element, apply_propagate_attrs, kw, fn, + ): + self.analyzed_code = analyzed_code + self.fn = fn + + self.bindparam_trackers = analyzed_code.bindparam_trackers + + self._instrument_and_run_function(lambda_element) + + self._coerce_expression(lambda_element, apply_propagate_attrs, kw) + + def _instrument_and_run_function(self, lambda_element): + analyzed_code = self.analyzed_code + + fn = self.fn + self.closure_pywrappers = closure_pywrappers = [] + + build_py_wrappers = analyzed_code.build_py_wrappers + + if not build_py_wrappers: + self.tracker_instrumented_fn = tracker_instrumented_fn = fn + self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn) + else: + track_closure_variables = analyzed_code.track_closure_variables closure = fn.__closure__ - # check if objects referred to by the lambda have changed and - # re-scan the lambda if so. see comments for this same section in - # LambdaElement. - for idx, obj in rec[_STALE_CHECK]: - if closure[idx].cell_contents is not obj: - rec = self._initialize_var_trackers( - role, apply_propagate_attrs, kw + # will form the __closure__ of the function when we rebuild it + if closure: + new_closure = { + fv: cell.cell_contents + for fv, cell in zip(fn.__code__.co_freevars, closure) + } + else: + new_closure = {} + + # will form the __globals__ of the function when we rebuild it + new_globals = fn.__globals__.copy() + + for name, closure_index in build_py_wrappers: + if closure_index is not None: + value = closure[closure_index].cell_contents + new_closure[name] = bind = PyWrapper( + name, value, closure_index=closure_index ) - break + if track_closure_variables: + closure_pywrappers.append(bind) + else: + value = fn.__globals__[name] + new_globals[name] = bind = PyWrapper(name, value) + + # rewrite the original fn. things that look like they will + # become bound parameters are wrapped in a PyWrapper. + self.tracker_instrumented_fn = ( + tracker_instrumented_fn + ) = self._rewrite_code_obj( + fn, + [new_closure[name] for name in fn.__code__.co_freevars], + new_globals, + ) - self._rec = rec + # now invoke the function. This will give us a new SQL + # expression, but all the places that there would be a bound + # parameter, the PyWrapper in its place will give us a bind + # with a predictable name we can match up later. - self._propagate_attrs = parent_lambda._propagate_attrs + # additionally, each PyWrapper will log that it did in fact + # create a parameter, otherwise, it's some kind of Python + # object in the closure and we want to track that, to make + # sure it doesn't change to somehting else, or if it does, + # that we create a different tracked function with that + # variable. + self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn) - self._resolved_bindparams = bindparams = [] - rec = self._rec - while True: - if rec[_TRACKERS]: - for tracker in rec[_TRACKERS]: - tracker(self.fn, bindparams) - if self.parent_lambda is not None: - self = self.parent_lambda - rec = self._rec + def _coerce_expression(self, lambda_element, apply_propagate_attrs, kw): + """Run the tracker-generated expression through coercion rules. + + After the user-defined lambda has been invoked to produce a statement + for re-use, run it through coercion rules to both check that it's the + correct type of object and also to coerce it to its useful form. + + """ + + parent_lambda = lambda_element.parent_lambda + expr = self.expr + + if parent_lambda is None: + if isinstance(expr, collections_abc.Sequence): + self.expected_expr = [ + coercions.expect( + lambda_element.role, + sub_expr, + apply_propagate_attrs=apply_propagate_attrs, + **kw + ) + for sub_expr in expr + ] + self.is_sequence = True else: - break + self.expected_expr = coercions.expect( + lambda_element.role, + expr, + apply_propagate_attrs=apply_propagate_attrs, + **kw + ) + self.is_sequence = False + else: + self.expected_expr = expr + self.is_sequence = False - def _invoke_user_fn(self, fn, *arg): - return fn(self.parent_lambda._rec[_EXPR]) + if apply_propagate_attrs is not None: + self.propagate_attrs = apply_propagate_attrs._propagate_attrs + else: + self.propagate_attrs = util.EMPTY_DICT - def _gen_cache_key(self, anon_map, bindparams): - if self._resolved_bindparams: - bindparams.extend(self._resolved_bindparams) + def _rewrite_code_obj(self, f, cell_values, globals_): + """Return a copy of f, with a new closure and new globals - cache_key = (self.fn.__code__, self.__class__) + yes it works in pypy :P - parent = self.parent_lambda - while parent is not None: - cache_key = (parent.fn.__code__,) + cache_key - parent = parent.parent_lambda + """ - return cache_key + argrange = range(len(cell_values)) + + code = "def make_cells():\n" + if cell_values: + code += " (%s) = (%s)\n" % ( + ", ".join("i%d" % i for i in argrange), + ", ".join("o%d" % i for i in argrange), + ) + code += " def closure():\n" + code += " return %s\n" % ", ".join("i%d" % i for i in argrange) + code += " return closure.__closure__" + vars_ = {"o%d" % i: cell_values[i] for i in argrange} + exec(code, vars_, vars_) + closure = vars_["make_cells"]() + + func = type(f)( + f.__code__, globals_, f.__name__, f.__defaults__, closure + ) + if sys.version_info >= (3,): + func.__annotations__ = f.__annotations__ + func.__kwdefaults__ = f.__kwdefaults__ + func.__doc__ = f.__doc__ + func.__module__ = f.__module__ + + return func class PyWrapper(ColumnOperators): - def __init__(self, name, to_evaluate, getter=None): + """A wrapper object that is injected into the ``__globals__`` and + ``__closure__`` of a Python function. + + When the function is instrumented with :class:`.PyWrapper` objects, it is + then invoked just once in order to set up the wrappers. We look through + all the :class:`.PyWrapper` objects we made to find the ones that generated + a :class:`.BindParameter` object, e.g. the expression system interpreted + something as a literal. Those positions in the globals/closure are then + ones that we will look at, each time a new lambda comes in that refers to + the same ``__code__`` object. In this way, we keep a single version of + the SQL expression that this lambda produced, without calling upon the + Python function that created it more than once, unless its other closure + variables have changed. The expression is then transformed to have the + new bound values embedded into it. + + """ + + def __init__(self, name, to_evaluate, closure_index=None, getter=None): self._name = name self._to_evaluate = to_evaluate self._param = None + self._has_param = False self._bind_paths = {} self._getter = getter + self._closure_index = closure_index def __call__(self, *arg, **kw): elem = object.__getattribute__(self, "_to_evaluate") value = elem(*arg, **kw) - if coercions._is_literal(value) and not isinstance( + if coercions._deep_is_literal(value) and not isinstance( # TODO: coverage where an ORM option or similar is here value, traversals.HasCacheKey, @@ -481,8 +1005,8 @@ class PyWrapper(ColumnOperators): if param is None: name = object.__getattribute__(self, "_name") self._param = param = elements.BindParameter(name, unique=True) + self._has_param = True param.type = type_api._resolve_value_to_type(to_evaluate) - return param._with_value(to_evaluate, maintain_key=True) def __getattribute__(self, key): @@ -497,7 +1021,15 @@ class PyWrapper(ColumnOperators): else: return self._sa__add_getter(key, operator.attrgetter) + def __iter__(self): + elem = object.__getattribute__(self, "_to_evaluate") + return iter(elem) + def __getitem__(self, key): + elem = object.__getattribute__(self, "_to_evaluate") + if not hasattr(elem, "__getitem__"): + raise AttributeError("__getitem__") + if isinstance(key, PyWrapper): # TODO: coverage raise exc.InvalidRequestError( @@ -518,90 +1050,14 @@ class PyWrapper(ColumnOperators): elem = object.__getattribute__(self, "_to_evaluate") value = getter(elem) - if coercions._is_literal(value): - wrapper = PyWrapper(key, value, getter) + if coercions._deep_is_literal(value): + wrapper = PyWrapper(key, value, getter=getter) bind_paths[bind_path_key] = wrapper return wrapper else: return value -def _roll_down_to_literal(element): - is_clause_element = hasattr(element, "__clause_element__") - - if is_clause_element: - while not isinstance( - element, (elements.ClauseElement, schema.SchemaItem) - ): - try: - element = element.__clause_element__() - except AttributeError: - break - - if not is_clause_element: - insp = inspection.inspect(element, raiseerr=False) - if insp is not None: - try: - return insp.__clause_element__() - except AttributeError: - return insp - - # TODO: should we coerce consts None/True/False here? - return element - else: - return element - - -def _globals_tracker(name, wrapper): - def extract_parameter_value(current_fn, result): - object.__getattribute__(wrapper, "_extract_bound_parameters")( - current_fn.__globals__[name], result - ) - - return extract_parameter_value - - -def _closure_tracker(name, wrapper, closure_index): - def extract_parameter_value(current_fn, result): - object.__getattribute__(wrapper, "_extract_bound_parameters")( - current_fn.__closure__[closure_index].cell_contents, result - ) - - return extract_parameter_value - - -def _rewrite_code_obj(f, cell_values, globals_): - """Return a copy of f, with a new closure and new globals - - yes it works in pypy :P - - """ - - argrange = range(len(cell_values)) - - code = "def make_cells():\n" - if cell_values: - code += " (%s) = (%s)\n" % ( - ", ".join("i%d" % i for i in argrange), - ", ".join("o%d" % i for i in argrange), - ) - code += " def closure():\n" - code += " return %s\n" % ", ".join("i%d" % i for i in argrange) - code += " return closure.__closure__" - vars_ = {"o%d" % i: cell_values[i] for i in argrange} - exec(code, vars_, vars_) - closure = vars_["make_cells"]() - - func = type(f)(f.__code__, globals_, f.__name__, f.__defaults__, closure) - if sys.version_info >= (3,): - func.__annotations__ = f.__annotations__ - func.__kwdefaults__ = f.__kwdefaults__ - func.__doc__ = f.__doc__ - func.__module__ = f.__module__ - - return func - - @inspection._inspects(LambdaElement) def insp(lmb): return inspection.inspect(lmb._resolved) |