diff options
Diffstat (limited to 'lib/sqlalchemy/sql/lambdas.py')
-rw-r--r-- | lib/sqlalchemy/sql/lambdas.py | 456 |
1 files changed, 312 insertions, 144 deletions
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index aafdda4ce..3f0ca477e 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -19,6 +19,7 @@ from . import traversals from . import type_api from . import visitors from .base import _clone +from .base import Options from .operators import ColumnOperators from .. import exc from .. import inspection @@ -28,7 +29,24 @@ from ..util import collections_abc _closure_per_cache_key = util.LRUCache(1000) -def lambda_stmt(lmb, **opts): +class LambdaOptions(Options): + enable_tracking = True + track_closure_variables = True + track_on = None + global_track_bound_values = True + track_bound_values = True + lambda_cache = None + + +def lambda_stmt( + lmb, + enable_tracking=True, + track_closure_variables=True, + track_on=None, + global_track_bound_values=True, + track_bound_values=True, + lambda_cache=None, +): """Produce a SQL statement that is cached as a lambda. The Python code object within the lambda is scanned for both Python @@ -49,6 +67,29 @@ def lambda_stmt(lmb, **opts): .. versionadded:: 1.4 + :param lmb: a Python function, typically a lambda, which takes no arguments + and returns a SQL expression construct + :param enable_tracking: when False, all scanning of the given lambda for + changes in closure variables or bound parameters is disabled. Use for + a lambda that produces the identical results in all cases with no + parameterization. + :param track_closure_variables: when False, changes in closure variables + within the lambda will not be scanned. Use for a lambda where the + state of its closure variables will never change the SQL structure + returned by the lambda. + :param track_bound_values: when False, bound parameter tracking will + be disabled for the given lambda. Use for a lambda that either does + not produce any bound values, or where the initial bound values never + change. + :param global_track_bound_values: when False, bound parameter tracking + will be disabled for the entire statement including additional links + added via the :meth:`_sql.StatementLambdaElement.add_criteria` method. + :param lambda_cache: a dictionary or other mapping-like object where + information about the lambda's Python code as well as the tracked closure + variables in the lambda itself will be stored. Defaults + to a global LRU cache. This cache is independent of the "compiled_cache" + used by the :class:`_engine.Connection` object. + .. seealso:: :ref:`engine_lambda_caching` @@ -56,7 +97,18 @@ def lambda_stmt(lmb, **opts): """ - return StatementLambdaElement(lmb, roles.CoerceTextStatementRole, **opts) + return StatementLambdaElement( + lmb, + roles.CoerceTextStatementRole, + LambdaOptions( + enable_tracking=enable_tracking, + track_on=track_on, + track_closure_variables=track_closure_variables, + global_track_bound_values=global_track_bound_values, + track_bound_values=track_bound_values, + lambda_cache=lambda_cache, + ), + ) class LambdaElement(elements.ClauseElement): @@ -94,38 +146,39 @@ class LambdaElement(elements.ClauseElement): def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self.fn.__code__) - def __init__(self, fn, role, apply_propagate_attrs=None, **kw): + def __init__( + self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None + ): self.fn = fn self.role = role self.tracker_key = (fn.__code__,) + self.opts = opts if apply_propagate_attrs is None and ( role is roles.CoerceTextStatementRole ): apply_propagate_attrs = self - rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, kw) + rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts) if apply_propagate_attrs is not None: propagate_attrs = rec.propagate_attrs if propagate_attrs: apply_propagate_attrs._propagate_attrs = propagate_attrs - def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, kw): - lambda_cache = kw.get("lambda_cache", _closure_per_cache_key) + def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts): + lambda_cache = opts.lambda_cache + if lambda_cache is None: + lambda_cache = _closure_per_cache_key tracker_key = self.tracker_key fn = self.fn closure = fn.__closure__ - tracker = AnalyzedCode.get( fn, self, - kw, - track_bound_values=kw.get("track_bound_values", True), - enable_tracking=kw.get("enable_tracking", True), - track_on=kw.get("track_on", None), + opts, ) self._resolved_bindparams = bindparams = [] @@ -133,10 +186,11 @@ class LambdaElement(elements.ClauseElement): anon_map = traversals.anon_map() cache_key = tuple( [ - getter(closure, kw, anon_map, bindparams) + getter(closure, opts, anon_map, bindparams) for getter in tracker.closure_trackers ] ) + if self.parent_lambda is not None: cache_key = self.parent_lambda.closure_cache_key + cache_key @@ -148,9 +202,7 @@ class LambdaElement(elements.ClauseElement): rec = None if rec is None: - rec = AnalyzedFunction( - tracker, self, apply_propagate_attrs, kw, fn - ) + rec = AnalyzedFunction(tracker, self, apply_propagate_attrs, fn) rec.closure_bindparams = bindparams lambda_cache[tracker_key + cache_key] = rec else: @@ -213,14 +265,13 @@ class LambdaElement(elements.ClauseElement): bindparam_lookup = {b.key: b for b in self._resolved_bindparams} def replace(thing): - if ( - isinstance(thing, elements.BindParameter) - and thing.key in bindparam_lookup - ): - bind = bindparam_lookup[thing.key] - if thing.expanding: - bind.expanding = True - return bind + if isinstance(thing, elements.BindParameter): + + if thing.key in bindparam_lookup: + bind = bindparam_lookup[thing.key] + if thing.expanding: + bind.expanding = True + return bind if self._rec.is_sequence: expr = [ @@ -268,7 +319,6 @@ class LambdaElement(elements.ClauseElement): if self._resolved_bindparams: bindparams.extend(self._resolved_bindparams) - return cache_key def _invoke_user_fn(self, fn, *arg): @@ -285,10 +335,9 @@ class DeferredLambdaElement(LambdaElement): """ - def __init__(self, fn, role, lambda_args=(), **kw): + def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()): self.lambda_args = lambda_args - self.coerce_kw = kw - super(DeferredLambdaElement, self).__init__(fn, role, **kw) + super(DeferredLambdaElement, self).__init__(fn, role, opts) def _invoke_user_fn(self, fn, *arg): return fn(*self.lambda_args) @@ -297,10 +346,30 @@ class DeferredLambdaElement(LambdaElement): tracker_fn = self._rec.tracker_instrumented_fn expr = tracker_fn(*lambda_args) - expr = coercions.expect(self.role, expr, **self.coerce_kw) - - if self._resolved_bindparams: - expr = self._setup_binds_for_tracked_expr(expr) + expr = coercions.expect(self.role, expr) + + expr = self._setup_binds_for_tracked_expr(expr) + + # this validation is getting very close, but not quite, to achieving + # #5767. The problem is if the base lambda uses an unnamed column + # as is very common with mixins, the parameter name is different + # and it produces a false positive; that is, for the documented case + # that is exactly what people will be doing, it doesn't work, so + # I'm not really sure how to handle this right now. + # expected_binds = [ + # b._orig_key + # for b in self._rec.expr._generate_cache_key()[1] + # if b.required + # ] + # got_binds = [ + # b._orig_key for b in expr._generate_cache_key()[1] if b.required + # ] + # if expected_binds != got_binds: + # raise exc.InvalidRequestError( + # "Lambda callable at %s produced a different set of bound " + # "parameters than its original run: %s" + # % (self.fn.__code__, ", ".join(got_binds)) + # ) # TODO: TEST TEST TEST, this is very out there for deferred_copy_internals in self._transforms: @@ -312,7 +381,9 @@ class DeferredLambdaElement(LambdaElement): self, clone=_clone, deferred_copy_internals=None, **kw ): super(DeferredLambdaElement, self)._copy_internals( - clone=clone, deferred_copy_internals=deferred_copy_internals, **kw + clone=clone, + deferred_copy_internals=deferred_copy_internals, # **kw + opts=kw, ) # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know @@ -347,33 +418,60 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): """ - def __init__(self, fn, parent_lambda, **kw): - self._default_kw = default_kw = {} - global_track_bound_values = kw.pop("global_track_bound_values", None) - if global_track_bound_values is not None: - default_kw["track_bound_values"] = global_track_bound_values - kw["track_bound_values"] = global_track_bound_values + def __add__(self, other): + return self.add_criteria(other) - if "lambda_cache" in kw: - default_kw["lambda_cache"] = kw["lambda_cache"] + def add_criteria( + self, + other, + enable_tracking=True, + track_on=None, + track_closure_variables=True, + track_bound_values=True, + ): + """Add new criteria to this :class:`_sql.StatementLambdaElement`. + + E.g.:: + + >>> def my_stmt(parameter): + ... stmt = lambda_stmt( + ... lambda: select(table.c.x, table.c.y), + ... ) + ... stmt = stmt.add_criteria( + ... lambda: table.c.x > parameter + ... ) + ... return stmt + + The :meth:`_sql.StatementLambdaElement.add_criteria` method is + equivalent to using the Python addition operator to add a new + lambda, except that additional arguments may be added including + ``track_closure_values`` and ``track_on``:: + + >>> def my_stmt(self, foo): + ... stmt = lambda_stmt( + ... lambda: select(func.max(foo.x, foo.y)), + ... track_closure_variables=False + ... ) + ... stmt = stmt.add_criteria( + ... lambda: self.where_criteria, + ... track_on=[self] + ... ) + ... return stmt + + See :func:`_sql.lambda_stmt` for a description of the parameters + accepted. - super(StatementLambdaElement, self).__init__(fn, parent_lambda, **kw) + """ - def __add__(self, other): - return LinkedLambdaElement( - other, parent_lambda=self, **self._default_kw + opts = self.opts + dict( + enable_tracking=enable_tracking, + track_closure_variables=track_closure_variables, + global_track_bound_values=self.opts.global_track_bound_values, + track_on=track_on, + track_bound_values=track_bound_values, ) - def add_criteria(self, other, **kw): - if self._default_kw: - if kw: - default_kw = self._default_kw.copy() - default_kw.update(kw) - kw = default_kw - else: - kw = self._default_kw - - return LinkedLambdaElement(other, parent_lambda=self, **kw) + return LinkedLambdaElement(other, parent_lambda=self, opts=opts) def _execute_on_connection( self, connection, multiparams, params, execution_options @@ -461,14 +559,13 @@ class LinkedLambdaElement(StatementLambdaElement): role = None - def __init__(self, fn, parent_lambda, **kw): - self._default_kw = parent_lambda._default_kw - + def __init__(self, fn, parent_lambda, opts): + self.opts = opts self.fn = fn self.parent_lambda = parent_lambda self.tracker_key = parent_lambda.tracker_key + (fn.__code__,) - self._retrieve_tracker_rec(fn, self, kw) + self._retrieve_tracker_rec(fn, self, opts) self._propagate_attrs = parent_lambda._propagate_attrs def _invoke_user_fn(self, fn, *arg): @@ -497,20 +594,17 @@ class AnalyzedCode(object): ) return analyzed - def __init__( - self, - fn, - lambda_element, - lambda_kw, - track_bound_values=True, - enable_tracking=True, - track_on=None, - ): + def __init__(self, fn, lambda_element, opts): closure = fn.__closure__ - self.track_closure_variables = not track_on + self.track_bound_values = ( + opts.track_bound_values and opts.global_track_bound_values + ) + enable_tracking = opts.enable_tracking + track_on = opts.track_on + track_closure_variables = opts.track_closure_variables - self.track_bound_values = track_bound_values + self.track_closure_variables = track_closure_variables and not track_on # a list of callables generated from _bound_parameter_getter_* # functions. Each of these uses a PyWrapper object to retrieve @@ -533,7 +627,7 @@ class AnalyzedCode(object): if closure: self._init_closure(fn) - self._setup_additional_closure_trackers(fn, lambda_element, lambda_kw) + self._setup_additional_closure_trackers(fn, lambda_element, opts) def _init_track_on(self, track_on): self.closure_trackers.extend( @@ -590,13 +684,11 @@ class AnalyzedCode(object): if track_closure_variables: closure_trackers.append( self._cache_key_getter_closure_variable( - closure_index, cell.cell_contents + fn, fv, closure_index, cell.cell_contents ) ) - def _setup_additional_closure_trackers( - self, fn, lambda_element, lambda_kw - ): + def _setup_additional_closure_trackers(self, fn, lambda_element, opts): # an additional step is to actually run the function, then # go through the PyWrapper objects that were set up to catch a bound # parameter. then if they *didn't* make a param, oh they're another @@ -607,7 +699,6 @@ class AnalyzedCode(object): self, lambda_element, None, - lambda_kw, fn, ) @@ -616,7 +707,7 @@ class AnalyzedCode(object): for pywrapper in analyzed_function.closure_pywrappers: if not pywrapper._sa__has_param: closure_trackers.append( - self._cache_key_getter_tracked_literal(pywrapper) + self._cache_key_getter_tracked_literal(fn, pywrapper) ) @classmethod @@ -625,7 +716,7 @@ class AnalyzedCode(object): if is_clause_element: while not isinstance( - element, (elements.ClauseElement, schema.SchemaItem) + element, (elements.ClauseElement, schema.SchemaItem, type) ): try: element = element.__clause_element__() @@ -688,17 +779,25 @@ class AnalyzedCode(object): """ if isinstance(elem, traversals.HasCacheKey): - def get(closure, kw, anon_map, bindparams): - return kw["track_on"][idx]._gen_cache_key(anon_map, bindparams) + def get(closure, opts, anon_map, bindparams): + return opts.track_on[idx]._gen_cache_key(anon_map, bindparams) else: - def get(closure, kw, anon_map, bindparams): - return kw["track_on"][idx] + def get(closure, opts, anon_map, bindparams): + return opts.track_on[idx] return get - def _cache_key_getter_closure_variable(self, idx, cell_contents): + def _cache_key_getter_closure_variable( + self, + fn, + variable_name, + idx, + cell_contents, + use_clause_element=False, + use_inspect=False, + ): """Return a getter that will extend a cache key with new entries from the ``__closure__`` collection of a particular lambda. @@ -706,29 +805,90 @@ class AnalyzedCode(object): if isinstance(cell_contents, traversals.HasCacheKey): - def get(closure, kw, anon_map, bindparams): - return closure[idx].cell_contents._gen_cache_key( - anon_map, bindparams - ) + def get(closure, opts, anon_map, bindparams): + + obj = closure[idx].cell_contents + if use_inspect: + obj = inspection.inspect(obj) + elif use_clause_element: + while hasattr(obj, "__clause_element__"): + if not getattr(obj, "is_clause_element", False): + obj = obj.__clause_element__() + + return obj._gen_cache_key(anon_map, bindparams) elif isinstance(cell_contents, types.FunctionType): - def get(closure, kw, anon_map, bindparams): + def get(closure, opts, anon_map, bindparams): return closure[idx].cell_contents.__code__ - elif cell_contents.__hash__ is None: - # this covers dict, etc. - def get(closure, kw, anon_map, bindparams): - return () + elif isinstance(cell_contents, collections_abc.Sequence): + + def get(closure, opts, anon_map, bindparams): + contents = closure[idx].cell_contents + + try: + return tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in contents + ) + except AttributeError as ae: + self._raise_for_uncacheable_closure_variable( + variable_name, fn, from_=ae + ) else: + # if the object is a mapped class or aliased class, or some + # other object in the ORM realm of things like that, imitate + # the logic used in coercions.expect() to roll it down to the + # SQL element + element = cell_contents + is_clause_element = False + while hasattr(element, "__clause_element__"): + is_clause_element = True + if not getattr(element, "is_clause_element", False): + element = element.__clause_element__() + else: + break - def get(closure, kw, anon_map, bindparams): - return closure[idx].cell_contents + if not is_clause_element: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + return self._cache_key_getter_closure_variable( + fn, variable_name, idx, insp, use_inspect=True + ) + else: + return self._cache_key_getter_closure_variable( + fn, variable_name, idx, element, use_clause_element=True + ) + + self._raise_for_uncacheable_closure_variable(variable_name, fn) return get - def _cache_key_getter_tracked_literal(self, pytracker): + def _raise_for_uncacheable_closure_variable( + self, variable_name, fn, from_=None + ): + util.raise_( + exc.InvalidRequestError( + "Closure variable named '%s' inside of lambda callable %s " + "does not refer to a cachable SQL element, and also does not " + "appear to be serving as a SQL literal bound value based on " + "the default " + "SQL expression returned by the function. This variable " + "needs to remain outside the scope of a SQL-generating lambda " + "so that a proper cache key may be generated from the " + "lambda's state. Evaluate this variable outside of the " + "lambda, set track_on=[<elements>] to explicitly select " + "closure elements to track, or set " + "track_closure_variables=False to exclude " + "closure variables from being part of the cache key." + % (variable_name, fn.__code__), + ), + from_=from_, + ) + + def _cache_key_getter_tracked_literal(self, fn, pytracker): """Return a getter that will extend a cache key with new entries from the ``__closure__`` collection of a particular lambda. @@ -741,33 +901,11 @@ class AnalyzedCode(object): elem = pytracker._sa__to_evaluate closure_index = pytracker._sa__closure_index + variable_name = pytracker._sa__name - if isinstance(elem, set): - raise exc.ArgumentError( - "Can't create a cache key for lambda closure variable " - '"%s" because it\'s a set. try using a list' - % pytracker._sa__name - ) - - elif isinstance(elem, list): - - def get(closure, kw, anon_map, bindparams): - return tuple( - elem._gen_cache_key(anon_map, bindparams) - for elem in closure[closure_index].cell_contents - ) - - elif elem.__hash__ is None: - # this covers dict, etc. - def get(closure, kw, anon_map, bindparams): - return () - - else: - - def get(closure, kw, anon_map, bindparams): - return closure[closure_index].cell_contents - - return get + return self._cache_key_getter_closure_variable( + fn, variable_name, closure_index, elem + ) class AnalyzedFunction(object): @@ -789,7 +927,6 @@ class AnalyzedFunction(object): analyzed_code, lambda_element, apply_propagate_attrs, - kw, fn, ): self.analyzed_code = analyzed_code @@ -799,7 +936,7 @@ class AnalyzedFunction(object): self._instrument_and_run_function(lambda_element) - self._coerce_expression(lambda_element, apply_propagate_attrs, kw) + self._coerce_expression(lambda_element, apply_propagate_attrs) def _instrument_and_run_function(self, lambda_element): analyzed_code = self.analyzed_code @@ -832,13 +969,19 @@ class AnalyzedFunction(object): if closure_index is not None: value = closure[closure_index].cell_contents new_closure[name] = bind = PyWrapper( - name, value, closure_index=closure_index + fn, + name, + value, + closure_index=closure_index, + track_bound_values=( + self.analyzed_code.track_bound_values + ), ) if track_closure_variables: closure_pywrappers.append(bind) else: value = fn.__globals__[name] - new_globals[name] = bind = PyWrapper(name, value) + new_globals[name] = bind = PyWrapper(fn, name, value) # rewrite the original fn. things that look like they will # become bound parameters are wrapped in a PyWrapper. @@ -863,7 +1006,7 @@ class AnalyzedFunction(object): # variable. self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn) - def _coerce_expression(self, lambda_element, apply_propagate_attrs, kw): + def _coerce_expression(self, lambda_element, apply_propagate_attrs): """Run the tracker-generated expression through coercion rules. After the user-defined lambda has been invoked to produce a statement @@ -882,7 +1025,6 @@ class AnalyzedFunction(object): lambda_element.role, sub_expr, apply_propagate_attrs=apply_propagate_attrs, - **kw ) for sub_expr in expr ] @@ -892,7 +1034,6 @@ class AnalyzedFunction(object): lambda_element.role, expr, apply_propagate_attrs=apply_propagate_attrs, - **kw ) self.is_sequence = False else: @@ -956,7 +1097,16 @@ class PyWrapper(ColumnOperators): """ - def __init__(self, name, to_evaluate, closure_index=None, getter=None): + def __init__( + self, + fn, + name, + to_evaluate, + closure_index=None, + getter=None, + track_bound_values=True, + ): + self.fn = fn self._name = name self._to_evaluate = to_evaluate self._param = None @@ -964,28 +1114,35 @@ class PyWrapper(ColumnOperators): self._bind_paths = {} self._getter = getter self._closure_index = closure_index + self.track_bound_values = track_bound_values def __call__(self, *arg, **kw): elem = object.__getattribute__(self, "_to_evaluate") value = elem(*arg, **kw) - if coercions._deep_is_literal(value) and not isinstance( - # TODO: coverage where an ORM option or similar is here - value, - traversals.HasCacheKey, + if ( + self._sa_track_bound_values + and coercions._deep_is_literal(value) + and not isinstance( + # TODO: coverage where an ORM option or similar is here + value, + traversals.HasCacheKey, + ) ): - # TODO: we can instead scan the arguments and make sure they - # are all Python literals - - # TODO: coverage name = object.__getattribute__(self, "_name") raise exc.InvalidRequestError( "Can't invoke Python callable %s() inside of lambda " - "expression argument; lambda cache keys should not call " - "regular functions since the caching " - "system does not track the values of the arguments passed " - "to the functions. Call the function outside of the lambda " - "and assign to a local variable that is used in the lambda." - % (name) + "expression argument at %s; lambda SQL constructs should " + "not invoke functions from closure variables to produce " + "literal values since the " + "lambda SQL system normally extracts bound values without " + "actually " + "invoking the lambda or any functions within it. Call the " + "function outside of the " + "lambda and assign to a local variable that is used in the " + "lambda as a closure variable, or set " + "track_bound_values=False if the return value of this " + "function is used in some other way other than a SQL bound " + "value." % (name, self._sa_fn.__code__) ) else: return value @@ -1018,6 +1175,14 @@ class PyWrapper(ColumnOperators): param.type = type_api._resolve_value_to_type(to_evaluate) return param._with_value(to_evaluate, maintain_key=True) + def __bool__(self): + to_evaluate = object.__getattribute__(self, "_to_evaluate") + return bool(to_evaluate) + + def __nonzero__(self): + to_evaluate = object.__getattribute__(self, "_to_evaluate") + return bool(to_evaluate) + def __getattribute__(self, key): if key.startswith("_sa_"): return object.__getattribute__(self, key[4:]) @@ -1026,6 +1191,7 @@ class PyWrapper(ColumnOperators): "operate", "reverse_operate", "__class__", + "__dict__", ): return object.__getattribute__(self, key) @@ -1064,8 +1230,10 @@ class PyWrapper(ColumnOperators): elem = object.__getattribute__(self, "_to_evaluate") value = getter(elem) - if coercions._deep_is_literal(value): - wrapper = PyWrapper(key, value, getter=getter) + rolled_down_value = AnalyzedCode._roll_down_to_literal(value) + + if coercions._deep_is_literal(rolled_down_value): + wrapper = PyWrapper(self._sa_fn, key, value, getter=getter) bind_paths[bind_path_key] = wrapper return wrapper else: |