diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-12-12 18:56:58 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-12-16 18:50:47 -0500 |
commit | 77c9534dcaf3723f7b2baf42442eda3e1d8c3332 (patch) | |
tree | c8f46c2a936c08a8de1f156807a0a3f31dd9486c /lib/sqlalchemy/sql/lambdas.py | |
parent | 09fac89debfbdcccbf2bcc433f7bec7921cf62be (diff) | |
download | sqlalchemy-77c9534dcaf3723f7b2baf42442eda3e1d8c3332.tar.gz |
Major revisals to lambdas
1. Improve coercions._deep_is_literal to check sequences
for clause elements, thus allowing a phrase like
lambda: col.in_([literal("x"), literal("y")]) to be handled
2. revise closure variable caching completely. All variables
entering must be part of a closure cache key or rejected.
only objects that can be resolved to HasCacheKey or FunctionType
are accepted; all other types are rejected. This adds a high
degree of strictness to lambdas and will make them a little more
awkward to use in some cases, however prevents several classes
of critical issues:
a. previously, a lambda that had an expression derived from
some kind of state, like "self.x", or "execution_context.session.foo"
would produce a closure cache key from "self" or "execution_context",
objects that can very well be per-execution and would therefore
cause a AnalyzedFunction objects to overflow. (memory won't leak
as it looks like an LRUCache is already used for these)
b. a lambda, such as one used within DeferredLamdaElement, that
produces different SQL expressions based on the arguments
(which is in fact what it's supposed to do), however it would
through the use of conditionals produce different bound parameter
combinations, leading to literal parameters not tracked properly.
These are now rejected as uncacheable whereas previously they would
again be part of the closure cache key, causing an overflow of
AnalyizedFunction objects.
3. Ensure non-mapped mixins are handled correctly by
with_loader_criteria().
4. Fixed bug in lambda SQL system where we are not supposed to allow a Python
function to be embedded in the lambda, since we can't predict a bound value
from it. While there was an error condition added for this, it was not
tested and wasn't working; an informative error is now raised.
5. new docs for lambdas
6. consolidated changelog for all of these
Fixes: #5760
Fixes: #5765
Fixes: #5766
Fixes: #5768
Fixes: #5770
Change-Id: Iedaa636c3225fad496df23b612c516c8ab247ab7
Diffstat (limited to 'lib/sqlalchemy/sql/lambdas.py')
-rw-r--r-- | lib/sqlalchemy/sql/lambdas.py | 456 |
1 files changed, 312 insertions, 144 deletions
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index aafdda4ce..3f0ca477e 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -19,6 +19,7 @@ from . import traversals from . import type_api from . import visitors from .base import _clone +from .base import Options from .operators import ColumnOperators from .. import exc from .. import inspection @@ -28,7 +29,24 @@ from ..util import collections_abc _closure_per_cache_key = util.LRUCache(1000) -def lambda_stmt(lmb, **opts): +class LambdaOptions(Options): + enable_tracking = True + track_closure_variables = True + track_on = None + global_track_bound_values = True + track_bound_values = True + lambda_cache = None + + +def lambda_stmt( + lmb, + enable_tracking=True, + track_closure_variables=True, + track_on=None, + global_track_bound_values=True, + track_bound_values=True, + lambda_cache=None, +): """Produce a SQL statement that is cached as a lambda. The Python code object within the lambda is scanned for both Python @@ -49,6 +67,29 @@ def lambda_stmt(lmb, **opts): .. versionadded:: 1.4 + :param lmb: a Python function, typically a lambda, which takes no arguments + and returns a SQL expression construct + :param enable_tracking: when False, all scanning of the given lambda for + changes in closure variables or bound parameters is disabled. Use for + a lambda that produces the identical results in all cases with no + parameterization. + :param track_closure_variables: when False, changes in closure variables + within the lambda will not be scanned. Use for a lambda where the + state of its closure variables will never change the SQL structure + returned by the lambda. + :param track_bound_values: when False, bound parameter tracking will + be disabled for the given lambda. Use for a lambda that either does + not produce any bound values, or where the initial bound values never + change. + :param global_track_bound_values: when False, bound parameter tracking + will be disabled for the entire statement including additional links + added via the :meth:`_sql.StatementLambdaElement.add_criteria` method. + :param lambda_cache: a dictionary or other mapping-like object where + information about the lambda's Python code as well as the tracked closure + variables in the lambda itself will be stored. Defaults + to a global LRU cache. This cache is independent of the "compiled_cache" + used by the :class:`_engine.Connection` object. + .. seealso:: :ref:`engine_lambda_caching` @@ -56,7 +97,18 @@ def lambda_stmt(lmb, **opts): """ - return StatementLambdaElement(lmb, roles.CoerceTextStatementRole, **opts) + return StatementLambdaElement( + lmb, + roles.CoerceTextStatementRole, + LambdaOptions( + enable_tracking=enable_tracking, + track_on=track_on, + track_closure_variables=track_closure_variables, + global_track_bound_values=global_track_bound_values, + track_bound_values=track_bound_values, + lambda_cache=lambda_cache, + ), + ) class LambdaElement(elements.ClauseElement): @@ -94,38 +146,39 @@ class LambdaElement(elements.ClauseElement): def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self.fn.__code__) - def __init__(self, fn, role, apply_propagate_attrs=None, **kw): + def __init__( + self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None + ): self.fn = fn self.role = role self.tracker_key = (fn.__code__,) + self.opts = opts if apply_propagate_attrs is None and ( role is roles.CoerceTextStatementRole ): apply_propagate_attrs = self - rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, kw) + rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts) if apply_propagate_attrs is not None: propagate_attrs = rec.propagate_attrs if propagate_attrs: apply_propagate_attrs._propagate_attrs = propagate_attrs - def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, kw): - lambda_cache = kw.get("lambda_cache", _closure_per_cache_key) + def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts): + lambda_cache = opts.lambda_cache + if lambda_cache is None: + lambda_cache = _closure_per_cache_key tracker_key = self.tracker_key fn = self.fn closure = fn.__closure__ - tracker = AnalyzedCode.get( fn, self, - kw, - track_bound_values=kw.get("track_bound_values", True), - enable_tracking=kw.get("enable_tracking", True), - track_on=kw.get("track_on", None), + opts, ) self._resolved_bindparams = bindparams = [] @@ -133,10 +186,11 @@ class LambdaElement(elements.ClauseElement): anon_map = traversals.anon_map() cache_key = tuple( [ - getter(closure, kw, anon_map, bindparams) + getter(closure, opts, anon_map, bindparams) for getter in tracker.closure_trackers ] ) + if self.parent_lambda is not None: cache_key = self.parent_lambda.closure_cache_key + cache_key @@ -148,9 +202,7 @@ class LambdaElement(elements.ClauseElement): rec = None if rec is None: - rec = AnalyzedFunction( - tracker, self, apply_propagate_attrs, kw, fn - ) + rec = AnalyzedFunction(tracker, self, apply_propagate_attrs, fn) rec.closure_bindparams = bindparams lambda_cache[tracker_key + cache_key] = rec else: @@ -213,14 +265,13 @@ class LambdaElement(elements.ClauseElement): bindparam_lookup = {b.key: b for b in self._resolved_bindparams} def replace(thing): - if ( - isinstance(thing, elements.BindParameter) - and thing.key in bindparam_lookup - ): - bind = bindparam_lookup[thing.key] - if thing.expanding: - bind.expanding = True - return bind + if isinstance(thing, elements.BindParameter): + + if thing.key in bindparam_lookup: + bind = bindparam_lookup[thing.key] + if thing.expanding: + bind.expanding = True + return bind if self._rec.is_sequence: expr = [ @@ -268,7 +319,6 @@ class LambdaElement(elements.ClauseElement): if self._resolved_bindparams: bindparams.extend(self._resolved_bindparams) - return cache_key def _invoke_user_fn(self, fn, *arg): @@ -285,10 +335,9 @@ class DeferredLambdaElement(LambdaElement): """ - def __init__(self, fn, role, lambda_args=(), **kw): + def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()): self.lambda_args = lambda_args - self.coerce_kw = kw - super(DeferredLambdaElement, self).__init__(fn, role, **kw) + super(DeferredLambdaElement, self).__init__(fn, role, opts) def _invoke_user_fn(self, fn, *arg): return fn(*self.lambda_args) @@ -297,10 +346,30 @@ class DeferredLambdaElement(LambdaElement): tracker_fn = self._rec.tracker_instrumented_fn expr = tracker_fn(*lambda_args) - expr = coercions.expect(self.role, expr, **self.coerce_kw) - - if self._resolved_bindparams: - expr = self._setup_binds_for_tracked_expr(expr) + expr = coercions.expect(self.role, expr) + + expr = self._setup_binds_for_tracked_expr(expr) + + # this validation is getting very close, but not quite, to achieving + # #5767. The problem is if the base lambda uses an unnamed column + # as is very common with mixins, the parameter name is different + # and it produces a false positive; that is, for the documented case + # that is exactly what people will be doing, it doesn't work, so + # I'm not really sure how to handle this right now. + # expected_binds = [ + # b._orig_key + # for b in self._rec.expr._generate_cache_key()[1] + # if b.required + # ] + # got_binds = [ + # b._orig_key for b in expr._generate_cache_key()[1] if b.required + # ] + # if expected_binds != got_binds: + # raise exc.InvalidRequestError( + # "Lambda callable at %s produced a different set of bound " + # "parameters than its original run: %s" + # % (self.fn.__code__, ", ".join(got_binds)) + # ) # TODO: TEST TEST TEST, this is very out there for deferred_copy_internals in self._transforms: @@ -312,7 +381,9 @@ class DeferredLambdaElement(LambdaElement): self, clone=_clone, deferred_copy_internals=None, **kw ): super(DeferredLambdaElement, self)._copy_internals( - clone=clone, deferred_copy_internals=deferred_copy_internals, **kw + clone=clone, + deferred_copy_internals=deferred_copy_internals, # **kw + opts=kw, ) # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know @@ -347,33 +418,60 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): """ - def __init__(self, fn, parent_lambda, **kw): - self._default_kw = default_kw = {} - global_track_bound_values = kw.pop("global_track_bound_values", None) - if global_track_bound_values is not None: - default_kw["track_bound_values"] = global_track_bound_values - kw["track_bound_values"] = global_track_bound_values + def __add__(self, other): + return self.add_criteria(other) - if "lambda_cache" in kw: - default_kw["lambda_cache"] = kw["lambda_cache"] + def add_criteria( + self, + other, + enable_tracking=True, + track_on=None, + track_closure_variables=True, + track_bound_values=True, + ): + """Add new criteria to this :class:`_sql.StatementLambdaElement`. + + E.g.:: + + >>> def my_stmt(parameter): + ... stmt = lambda_stmt( + ... lambda: select(table.c.x, table.c.y), + ... ) + ... stmt = stmt.add_criteria( + ... lambda: table.c.x > parameter + ... ) + ... return stmt + + The :meth:`_sql.StatementLambdaElement.add_criteria` method is + equivalent to using the Python addition operator to add a new + lambda, except that additional arguments may be added including + ``track_closure_values`` and ``track_on``:: + + >>> def my_stmt(self, foo): + ... stmt = lambda_stmt( + ... lambda: select(func.max(foo.x, foo.y)), + ... track_closure_variables=False + ... ) + ... stmt = stmt.add_criteria( + ... lambda: self.where_criteria, + ... track_on=[self] + ... ) + ... return stmt + + See :func:`_sql.lambda_stmt` for a description of the parameters + accepted. - super(StatementLambdaElement, self).__init__(fn, parent_lambda, **kw) + """ - def __add__(self, other): - return LinkedLambdaElement( - other, parent_lambda=self, **self._default_kw + opts = self.opts + dict( + enable_tracking=enable_tracking, + track_closure_variables=track_closure_variables, + global_track_bound_values=self.opts.global_track_bound_values, + track_on=track_on, + track_bound_values=track_bound_values, ) - def add_criteria(self, other, **kw): - if self._default_kw: - if kw: - default_kw = self._default_kw.copy() - default_kw.update(kw) - kw = default_kw - else: - kw = self._default_kw - - return LinkedLambdaElement(other, parent_lambda=self, **kw) + return LinkedLambdaElement(other, parent_lambda=self, opts=opts) def _execute_on_connection( self, connection, multiparams, params, execution_options @@ -461,14 +559,13 @@ class LinkedLambdaElement(StatementLambdaElement): role = None - def __init__(self, fn, parent_lambda, **kw): - self._default_kw = parent_lambda._default_kw - + def __init__(self, fn, parent_lambda, opts): + self.opts = opts self.fn = fn self.parent_lambda = parent_lambda self.tracker_key = parent_lambda.tracker_key + (fn.__code__,) - self._retrieve_tracker_rec(fn, self, kw) + self._retrieve_tracker_rec(fn, self, opts) self._propagate_attrs = parent_lambda._propagate_attrs def _invoke_user_fn(self, fn, *arg): @@ -497,20 +594,17 @@ class AnalyzedCode(object): ) return analyzed - def __init__( - self, - fn, - lambda_element, - lambda_kw, - track_bound_values=True, - enable_tracking=True, - track_on=None, - ): + def __init__(self, fn, lambda_element, opts): closure = fn.__closure__ - self.track_closure_variables = not track_on + self.track_bound_values = ( + opts.track_bound_values and opts.global_track_bound_values + ) + enable_tracking = opts.enable_tracking + track_on = opts.track_on + track_closure_variables = opts.track_closure_variables - self.track_bound_values = track_bound_values + self.track_closure_variables = track_closure_variables and not track_on # a list of callables generated from _bound_parameter_getter_* # functions. Each of these uses a PyWrapper object to retrieve @@ -533,7 +627,7 @@ class AnalyzedCode(object): if closure: self._init_closure(fn) - self._setup_additional_closure_trackers(fn, lambda_element, lambda_kw) + self._setup_additional_closure_trackers(fn, lambda_element, opts) def _init_track_on(self, track_on): self.closure_trackers.extend( @@ -590,13 +684,11 @@ class AnalyzedCode(object): if track_closure_variables: closure_trackers.append( self._cache_key_getter_closure_variable( - closure_index, cell.cell_contents + fn, fv, closure_index, cell.cell_contents ) ) - def _setup_additional_closure_trackers( - self, fn, lambda_element, lambda_kw - ): + def _setup_additional_closure_trackers(self, fn, lambda_element, opts): # an additional step is to actually run the function, then # go through the PyWrapper objects that were set up to catch a bound # parameter. then if they *didn't* make a param, oh they're another @@ -607,7 +699,6 @@ class AnalyzedCode(object): self, lambda_element, None, - lambda_kw, fn, ) @@ -616,7 +707,7 @@ class AnalyzedCode(object): for pywrapper in analyzed_function.closure_pywrappers: if not pywrapper._sa__has_param: closure_trackers.append( - self._cache_key_getter_tracked_literal(pywrapper) + self._cache_key_getter_tracked_literal(fn, pywrapper) ) @classmethod @@ -625,7 +716,7 @@ class AnalyzedCode(object): if is_clause_element: while not isinstance( - element, (elements.ClauseElement, schema.SchemaItem) + element, (elements.ClauseElement, schema.SchemaItem, type) ): try: element = element.__clause_element__() @@ -688,17 +779,25 @@ class AnalyzedCode(object): """ if isinstance(elem, traversals.HasCacheKey): - def get(closure, kw, anon_map, bindparams): - return kw["track_on"][idx]._gen_cache_key(anon_map, bindparams) + def get(closure, opts, anon_map, bindparams): + return opts.track_on[idx]._gen_cache_key(anon_map, bindparams) else: - def get(closure, kw, anon_map, bindparams): - return kw["track_on"][idx] + def get(closure, opts, anon_map, bindparams): + return opts.track_on[idx] return get - def _cache_key_getter_closure_variable(self, idx, cell_contents): + def _cache_key_getter_closure_variable( + self, + fn, + variable_name, + idx, + cell_contents, + use_clause_element=False, + use_inspect=False, + ): """Return a getter that will extend a cache key with new entries from the ``__closure__`` collection of a particular lambda. @@ -706,29 +805,90 @@ class AnalyzedCode(object): if isinstance(cell_contents, traversals.HasCacheKey): - def get(closure, kw, anon_map, bindparams): - return closure[idx].cell_contents._gen_cache_key( - anon_map, bindparams - ) + def get(closure, opts, anon_map, bindparams): + + obj = closure[idx].cell_contents + if use_inspect: + obj = inspection.inspect(obj) + elif use_clause_element: + while hasattr(obj, "__clause_element__"): + if not getattr(obj, "is_clause_element", False): + obj = obj.__clause_element__() + + return obj._gen_cache_key(anon_map, bindparams) elif isinstance(cell_contents, types.FunctionType): - def get(closure, kw, anon_map, bindparams): + def get(closure, opts, anon_map, bindparams): return closure[idx].cell_contents.__code__ - elif cell_contents.__hash__ is None: - # this covers dict, etc. - def get(closure, kw, anon_map, bindparams): - return () + elif isinstance(cell_contents, collections_abc.Sequence): + + def get(closure, opts, anon_map, bindparams): + contents = closure[idx].cell_contents + + try: + return tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in contents + ) + except AttributeError as ae: + self._raise_for_uncacheable_closure_variable( + variable_name, fn, from_=ae + ) else: + # if the object is a mapped class or aliased class, or some + # other object in the ORM realm of things like that, imitate + # the logic used in coercions.expect() to roll it down to the + # SQL element + element = cell_contents + is_clause_element = False + while hasattr(element, "__clause_element__"): + is_clause_element = True + if not getattr(element, "is_clause_element", False): + element = element.__clause_element__() + else: + break - def get(closure, kw, anon_map, bindparams): - return closure[idx].cell_contents + if not is_clause_element: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + return self._cache_key_getter_closure_variable( + fn, variable_name, idx, insp, use_inspect=True + ) + else: + return self._cache_key_getter_closure_variable( + fn, variable_name, idx, element, use_clause_element=True + ) + + self._raise_for_uncacheable_closure_variable(variable_name, fn) return get - def _cache_key_getter_tracked_literal(self, pytracker): + def _raise_for_uncacheable_closure_variable( + self, variable_name, fn, from_=None + ): + util.raise_( + exc.InvalidRequestError( + "Closure variable named '%s' inside of lambda callable %s " + "does not refer to a cachable SQL element, and also does not " + "appear to be serving as a SQL literal bound value based on " + "the default " + "SQL expression returned by the function. This variable " + "needs to remain outside the scope of a SQL-generating lambda " + "so that a proper cache key may be generated from the " + "lambda's state. Evaluate this variable outside of the " + "lambda, set track_on=[<elements>] to explicitly select " + "closure elements to track, or set " + "track_closure_variables=False to exclude " + "closure variables from being part of the cache key." + % (variable_name, fn.__code__), + ), + from_=from_, + ) + + def _cache_key_getter_tracked_literal(self, fn, pytracker): """Return a getter that will extend a cache key with new entries from the ``__closure__`` collection of a particular lambda. @@ -741,33 +901,11 @@ class AnalyzedCode(object): elem = pytracker._sa__to_evaluate closure_index = pytracker._sa__closure_index + variable_name = pytracker._sa__name - if isinstance(elem, set): - raise exc.ArgumentError( - "Can't create a cache key for lambda closure variable " - '"%s" because it\'s a set. try using a list' - % pytracker._sa__name - ) - - elif isinstance(elem, list): - - def get(closure, kw, anon_map, bindparams): - return tuple( - elem._gen_cache_key(anon_map, bindparams) - for elem in closure[closure_index].cell_contents - ) - - elif elem.__hash__ is None: - # this covers dict, etc. - def get(closure, kw, anon_map, bindparams): - return () - - else: - - def get(closure, kw, anon_map, bindparams): - return closure[closure_index].cell_contents - - return get + return self._cache_key_getter_closure_variable( + fn, variable_name, closure_index, elem + ) class AnalyzedFunction(object): @@ -789,7 +927,6 @@ class AnalyzedFunction(object): analyzed_code, lambda_element, apply_propagate_attrs, - kw, fn, ): self.analyzed_code = analyzed_code @@ -799,7 +936,7 @@ class AnalyzedFunction(object): self._instrument_and_run_function(lambda_element) - self._coerce_expression(lambda_element, apply_propagate_attrs, kw) + self._coerce_expression(lambda_element, apply_propagate_attrs) def _instrument_and_run_function(self, lambda_element): analyzed_code = self.analyzed_code @@ -832,13 +969,19 @@ class AnalyzedFunction(object): if closure_index is not None: value = closure[closure_index].cell_contents new_closure[name] = bind = PyWrapper( - name, value, closure_index=closure_index + fn, + name, + value, + closure_index=closure_index, + track_bound_values=( + self.analyzed_code.track_bound_values + ), ) if track_closure_variables: closure_pywrappers.append(bind) else: value = fn.__globals__[name] - new_globals[name] = bind = PyWrapper(name, value) + new_globals[name] = bind = PyWrapper(fn, name, value) # rewrite the original fn. things that look like they will # become bound parameters are wrapped in a PyWrapper. @@ -863,7 +1006,7 @@ class AnalyzedFunction(object): # variable. self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn) - def _coerce_expression(self, lambda_element, apply_propagate_attrs, kw): + def _coerce_expression(self, lambda_element, apply_propagate_attrs): """Run the tracker-generated expression through coercion rules. After the user-defined lambda has been invoked to produce a statement @@ -882,7 +1025,6 @@ class AnalyzedFunction(object): lambda_element.role, sub_expr, apply_propagate_attrs=apply_propagate_attrs, - **kw ) for sub_expr in expr ] @@ -892,7 +1034,6 @@ class AnalyzedFunction(object): lambda_element.role, expr, apply_propagate_attrs=apply_propagate_attrs, - **kw ) self.is_sequence = False else: @@ -956,7 +1097,16 @@ class PyWrapper(ColumnOperators): """ - def __init__(self, name, to_evaluate, closure_index=None, getter=None): + def __init__( + self, + fn, + name, + to_evaluate, + closure_index=None, + getter=None, + track_bound_values=True, + ): + self.fn = fn self._name = name self._to_evaluate = to_evaluate self._param = None @@ -964,28 +1114,35 @@ class PyWrapper(ColumnOperators): self._bind_paths = {} self._getter = getter self._closure_index = closure_index + self.track_bound_values = track_bound_values def __call__(self, *arg, **kw): elem = object.__getattribute__(self, "_to_evaluate") value = elem(*arg, **kw) - if coercions._deep_is_literal(value) and not isinstance( - # TODO: coverage where an ORM option or similar is here - value, - traversals.HasCacheKey, + if ( + self._sa_track_bound_values + and coercions._deep_is_literal(value) + and not isinstance( + # TODO: coverage where an ORM option or similar is here + value, + traversals.HasCacheKey, + ) ): - # TODO: we can instead scan the arguments and make sure they - # are all Python literals - - # TODO: coverage name = object.__getattribute__(self, "_name") raise exc.InvalidRequestError( "Can't invoke Python callable %s() inside of lambda " - "expression argument; lambda cache keys should not call " - "regular functions since the caching " - "system does not track the values of the arguments passed " - "to the functions. Call the function outside of the lambda " - "and assign to a local variable that is used in the lambda." - % (name) + "expression argument at %s; lambda SQL constructs should " + "not invoke functions from closure variables to produce " + "literal values since the " + "lambda SQL system normally extracts bound values without " + "actually " + "invoking the lambda or any functions within it. Call the " + "function outside of the " + "lambda and assign to a local variable that is used in the " + "lambda as a closure variable, or set " + "track_bound_values=False if the return value of this " + "function is used in some other way other than a SQL bound " + "value." % (name, self._sa_fn.__code__) ) else: return value @@ -1018,6 +1175,14 @@ class PyWrapper(ColumnOperators): param.type = type_api._resolve_value_to_type(to_evaluate) return param._with_value(to_evaluate, maintain_key=True) + def __bool__(self): + to_evaluate = object.__getattribute__(self, "_to_evaluate") + return bool(to_evaluate) + + def __nonzero__(self): + to_evaluate = object.__getattribute__(self, "_to_evaluate") + return bool(to_evaluate) + def __getattribute__(self, key): if key.startswith("_sa_"): return object.__getattribute__(self, key[4:]) @@ -1026,6 +1191,7 @@ class PyWrapper(ColumnOperators): "operate", "reverse_operate", "__class__", + "__dict__", ): return object.__getattribute__(self, key) @@ -1064,8 +1230,10 @@ class PyWrapper(ColumnOperators): elem = object.__getattribute__(self, "_to_evaluate") value = getter(elem) - if coercions._deep_is_literal(value): - wrapper = PyWrapper(key, value, getter=getter) + rolled_down_value = AnalyzedCode._roll_down_to_literal(value) + + if coercions._deep_is_literal(rolled_down_value): + wrapper = PyWrapper(self._sa_fn, key, value, getter=getter) bind_paths[bind_path_key] = wrapper return wrapper else: |