summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/lambdas.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/lambdas.py')
-rw-r--r--lib/sqlalchemy/sql/lambdas.py456
1 files changed, 312 insertions, 144 deletions
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
index aafdda4ce..3f0ca477e 100644
--- a/lib/sqlalchemy/sql/lambdas.py
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -19,6 +19,7 @@ from . import traversals
from . import type_api
from . import visitors
from .base import _clone
+from .base import Options
from .operators import ColumnOperators
from .. import exc
from .. import inspection
@@ -28,7 +29,24 @@ from ..util import collections_abc
_closure_per_cache_key = util.LRUCache(1000)
-def lambda_stmt(lmb, **opts):
+class LambdaOptions(Options):
+ enable_tracking = True
+ track_closure_variables = True
+ track_on = None
+ global_track_bound_values = True
+ track_bound_values = True
+ lambda_cache = None
+
+
+def lambda_stmt(
+ lmb,
+ enable_tracking=True,
+ track_closure_variables=True,
+ track_on=None,
+ global_track_bound_values=True,
+ track_bound_values=True,
+ lambda_cache=None,
+):
"""Produce a SQL statement that is cached as a lambda.
The Python code object within the lambda is scanned for both Python
@@ -49,6 +67,29 @@ def lambda_stmt(lmb, **opts):
.. versionadded:: 1.4
+ :param lmb: a Python function, typically a lambda, which takes no arguments
+ and returns a SQL expression construct
+ :param enable_tracking: when False, all scanning of the given lambda for
+ changes in closure variables or bound parameters is disabled. Use for
+ a lambda that produces the identical results in all cases with no
+ parameterization.
+ :param track_closure_variables: when False, changes in closure variables
+ within the lambda will not be scanned. Use for a lambda where the
+ state of its closure variables will never change the SQL structure
+ returned by the lambda.
+ :param track_bound_values: when False, bound parameter tracking will
+ be disabled for the given lambda. Use for a lambda that either does
+ not produce any bound values, or where the initial bound values never
+ change.
+ :param global_track_bound_values: when False, bound parameter tracking
+ will be disabled for the entire statement including additional links
+ added via the :meth:`_sql.StatementLambdaElement.add_criteria` method.
+ :param lambda_cache: a dictionary or other mapping-like object where
+ information about the lambda's Python code as well as the tracked closure
+ variables in the lambda itself will be stored. Defaults
+ to a global LRU cache. This cache is independent of the "compiled_cache"
+ used by the :class:`_engine.Connection` object.
+
.. seealso::
:ref:`engine_lambda_caching`
@@ -56,7 +97,18 @@ def lambda_stmt(lmb, **opts):
"""
- return StatementLambdaElement(lmb, roles.CoerceTextStatementRole, **opts)
+ return StatementLambdaElement(
+ lmb,
+ roles.CoerceTextStatementRole,
+ LambdaOptions(
+ enable_tracking=enable_tracking,
+ track_on=track_on,
+ track_closure_variables=track_closure_variables,
+ global_track_bound_values=global_track_bound_values,
+ track_bound_values=track_bound_values,
+ lambda_cache=lambda_cache,
+ ),
+ )
class LambdaElement(elements.ClauseElement):
@@ -94,38 +146,39 @@ class LambdaElement(elements.ClauseElement):
def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)
- def __init__(self, fn, role, apply_propagate_attrs=None, **kw):
+ def __init__(
+ self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None
+ ):
self.fn = fn
self.role = role
self.tracker_key = (fn.__code__,)
+ self.opts = opts
if apply_propagate_attrs is None and (
role is roles.CoerceTextStatementRole
):
apply_propagate_attrs = self
- rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, kw)
+ rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)
if apply_propagate_attrs is not None:
propagate_attrs = rec.propagate_attrs
if propagate_attrs:
apply_propagate_attrs._propagate_attrs = propagate_attrs
- def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, kw):
- lambda_cache = kw.get("lambda_cache", _closure_per_cache_key)
+ def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
+ lambda_cache = opts.lambda_cache
+ if lambda_cache is None:
+ lambda_cache = _closure_per_cache_key
tracker_key = self.tracker_key
fn = self.fn
closure = fn.__closure__
-
tracker = AnalyzedCode.get(
fn,
self,
- kw,
- track_bound_values=kw.get("track_bound_values", True),
- enable_tracking=kw.get("enable_tracking", True),
- track_on=kw.get("track_on", None),
+ opts,
)
self._resolved_bindparams = bindparams = []
@@ -133,10 +186,11 @@ class LambdaElement(elements.ClauseElement):
anon_map = traversals.anon_map()
cache_key = tuple(
[
- getter(closure, kw, anon_map, bindparams)
+ getter(closure, opts, anon_map, bindparams)
for getter in tracker.closure_trackers
]
)
+
if self.parent_lambda is not None:
cache_key = self.parent_lambda.closure_cache_key + cache_key
@@ -148,9 +202,7 @@ class LambdaElement(elements.ClauseElement):
rec = None
if rec is None:
- rec = AnalyzedFunction(
- tracker, self, apply_propagate_attrs, kw, fn
- )
+ rec = AnalyzedFunction(tracker, self, apply_propagate_attrs, fn)
rec.closure_bindparams = bindparams
lambda_cache[tracker_key + cache_key] = rec
else:
@@ -213,14 +265,13 @@ class LambdaElement(elements.ClauseElement):
bindparam_lookup = {b.key: b for b in self._resolved_bindparams}
def replace(thing):
- if (
- isinstance(thing, elements.BindParameter)
- and thing.key in bindparam_lookup
- ):
- bind = bindparam_lookup[thing.key]
- if thing.expanding:
- bind.expanding = True
- return bind
+ if isinstance(thing, elements.BindParameter):
+
+ if thing.key in bindparam_lookup:
+ bind = bindparam_lookup[thing.key]
+ if thing.expanding:
+ bind.expanding = True
+ return bind
if self._rec.is_sequence:
expr = [
@@ -268,7 +319,6 @@ class LambdaElement(elements.ClauseElement):
if self._resolved_bindparams:
bindparams.extend(self._resolved_bindparams)
-
return cache_key
def _invoke_user_fn(self, fn, *arg):
@@ -285,10 +335,9 @@ class DeferredLambdaElement(LambdaElement):
"""
- def __init__(self, fn, role, lambda_args=(), **kw):
+ def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()):
self.lambda_args = lambda_args
- self.coerce_kw = kw
- super(DeferredLambdaElement, self).__init__(fn, role, **kw)
+ super(DeferredLambdaElement, self).__init__(fn, role, opts)
def _invoke_user_fn(self, fn, *arg):
return fn(*self.lambda_args)
@@ -297,10 +346,30 @@ class DeferredLambdaElement(LambdaElement):
tracker_fn = self._rec.tracker_instrumented_fn
expr = tracker_fn(*lambda_args)
- expr = coercions.expect(self.role, expr, **self.coerce_kw)
-
- if self._resolved_bindparams:
- expr = self._setup_binds_for_tracked_expr(expr)
+ expr = coercions.expect(self.role, expr)
+
+ expr = self._setup_binds_for_tracked_expr(expr)
+
+ # this validation is getting very close, but not quite, to achieving
+ # #5767. The problem is if the base lambda uses an unnamed column
+ # as is very common with mixins, the parameter name is different
+ # and it produces a false positive; that is, for the documented case
+ # that is exactly what people will be doing, it doesn't work, so
+ # I'm not really sure how to handle this right now.
+ # expected_binds = [
+ # b._orig_key
+ # for b in self._rec.expr._generate_cache_key()[1]
+ # if b.required
+ # ]
+ # got_binds = [
+ # b._orig_key for b in expr._generate_cache_key()[1] if b.required
+ # ]
+ # if expected_binds != got_binds:
+ # raise exc.InvalidRequestError(
+ # "Lambda callable at %s produced a different set of bound "
+ # "parameters than its original run: %s"
+ # % (self.fn.__code__, ", ".join(got_binds))
+ # )
# TODO: TEST TEST TEST, this is very out there
for deferred_copy_internals in self._transforms:
@@ -312,7 +381,9 @@ class DeferredLambdaElement(LambdaElement):
self, clone=_clone, deferred_copy_internals=None, **kw
):
super(DeferredLambdaElement, self)._copy_internals(
- clone=clone, deferred_copy_internals=deferred_copy_internals, **kw
+ clone=clone,
+ deferred_copy_internals=deferred_copy_internals, # **kw
+ opts=kw,
)
# TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know
@@ -347,33 +418,60 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
"""
- def __init__(self, fn, parent_lambda, **kw):
- self._default_kw = default_kw = {}
- global_track_bound_values = kw.pop("global_track_bound_values", None)
- if global_track_bound_values is not None:
- default_kw["track_bound_values"] = global_track_bound_values
- kw["track_bound_values"] = global_track_bound_values
+ def __add__(self, other):
+ return self.add_criteria(other)
- if "lambda_cache" in kw:
- default_kw["lambda_cache"] = kw["lambda_cache"]
+ def add_criteria(
+ self,
+ other,
+ enable_tracking=True,
+ track_on=None,
+ track_closure_variables=True,
+ track_bound_values=True,
+ ):
+ """Add new criteria to this :class:`_sql.StatementLambdaElement`.
+
+ E.g.::
+
+ >>> def my_stmt(parameter):
+ ... stmt = lambda_stmt(
+ ... lambda: select(table.c.x, table.c.y),
+ ... )
+ ... stmt = stmt.add_criteria(
+ ... lambda: table.c.x > parameter
+ ... )
+ ... return stmt
+
+ The :meth:`_sql.StatementLambdaElement.add_criteria` method is
+ equivalent to using the Python addition operator to add a new
+ lambda, except that additional arguments may be added including
+ ``track_closure_values`` and ``track_on``::
+
+ >>> def my_stmt(self, foo):
+ ... stmt = lambda_stmt(
+ ... lambda: select(func.max(foo.x, foo.y)),
+ ... track_closure_variables=False
+ ... )
+ ... stmt = stmt.add_criteria(
+ ... lambda: self.where_criteria,
+ ... track_on=[self]
+ ... )
+ ... return stmt
+
+ See :func:`_sql.lambda_stmt` for a description of the parameters
+ accepted.
- super(StatementLambdaElement, self).__init__(fn, parent_lambda, **kw)
+ """
- def __add__(self, other):
- return LinkedLambdaElement(
- other, parent_lambda=self, **self._default_kw
+ opts = self.opts + dict(
+ enable_tracking=enable_tracking,
+ track_closure_variables=track_closure_variables,
+ global_track_bound_values=self.opts.global_track_bound_values,
+ track_on=track_on,
+ track_bound_values=track_bound_values,
)
- def add_criteria(self, other, **kw):
- if self._default_kw:
- if kw:
- default_kw = self._default_kw.copy()
- default_kw.update(kw)
- kw = default_kw
- else:
- kw = self._default_kw
-
- return LinkedLambdaElement(other, parent_lambda=self, **kw)
+ return LinkedLambdaElement(other, parent_lambda=self, opts=opts)
def _execute_on_connection(
self, connection, multiparams, params, execution_options
@@ -461,14 +559,13 @@ class LinkedLambdaElement(StatementLambdaElement):
role = None
- def __init__(self, fn, parent_lambda, **kw):
- self._default_kw = parent_lambda._default_kw
-
+ def __init__(self, fn, parent_lambda, opts):
+ self.opts = opts
self.fn = fn
self.parent_lambda = parent_lambda
self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)
- self._retrieve_tracker_rec(fn, self, kw)
+ self._retrieve_tracker_rec(fn, self, opts)
self._propagate_attrs = parent_lambda._propagate_attrs
def _invoke_user_fn(self, fn, *arg):
@@ -497,20 +594,17 @@ class AnalyzedCode(object):
)
return analyzed
- def __init__(
- self,
- fn,
- lambda_element,
- lambda_kw,
- track_bound_values=True,
- enable_tracking=True,
- track_on=None,
- ):
+ def __init__(self, fn, lambda_element, opts):
closure = fn.__closure__
- self.track_closure_variables = not track_on
+ self.track_bound_values = (
+ opts.track_bound_values and opts.global_track_bound_values
+ )
+ enable_tracking = opts.enable_tracking
+ track_on = opts.track_on
+ track_closure_variables = opts.track_closure_variables
- self.track_bound_values = track_bound_values
+ self.track_closure_variables = track_closure_variables and not track_on
# a list of callables generated from _bound_parameter_getter_*
# functions. Each of these uses a PyWrapper object to retrieve
@@ -533,7 +627,7 @@ class AnalyzedCode(object):
if closure:
self._init_closure(fn)
- self._setup_additional_closure_trackers(fn, lambda_element, lambda_kw)
+ self._setup_additional_closure_trackers(fn, lambda_element, opts)
def _init_track_on(self, track_on):
self.closure_trackers.extend(
@@ -590,13 +684,11 @@ class AnalyzedCode(object):
if track_closure_variables:
closure_trackers.append(
self._cache_key_getter_closure_variable(
- closure_index, cell.cell_contents
+ fn, fv, closure_index, cell.cell_contents
)
)
- def _setup_additional_closure_trackers(
- self, fn, lambda_element, lambda_kw
- ):
+ def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
# an additional step is to actually run the function, then
# go through the PyWrapper objects that were set up to catch a bound
# parameter. then if they *didn't* make a param, oh they're another
@@ -607,7 +699,6 @@ class AnalyzedCode(object):
self,
lambda_element,
None,
- lambda_kw,
fn,
)
@@ -616,7 +707,7 @@ class AnalyzedCode(object):
for pywrapper in analyzed_function.closure_pywrappers:
if not pywrapper._sa__has_param:
closure_trackers.append(
- self._cache_key_getter_tracked_literal(pywrapper)
+ self._cache_key_getter_tracked_literal(fn, pywrapper)
)
@classmethod
@@ -625,7 +716,7 @@ class AnalyzedCode(object):
if is_clause_element:
while not isinstance(
- element, (elements.ClauseElement, schema.SchemaItem)
+ element, (elements.ClauseElement, schema.SchemaItem, type)
):
try:
element = element.__clause_element__()
@@ -688,17 +779,25 @@ class AnalyzedCode(object):
"""
if isinstance(elem, traversals.HasCacheKey):
- def get(closure, kw, anon_map, bindparams):
- return kw["track_on"][idx]._gen_cache_key(anon_map, bindparams)
+ def get(closure, opts, anon_map, bindparams):
+ return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)
else:
- def get(closure, kw, anon_map, bindparams):
- return kw["track_on"][idx]
+ def get(closure, opts, anon_map, bindparams):
+ return opts.track_on[idx]
return get
- def _cache_key_getter_closure_variable(self, idx, cell_contents):
+ def _cache_key_getter_closure_variable(
+ self,
+ fn,
+ variable_name,
+ idx,
+ cell_contents,
+ use_clause_element=False,
+ use_inspect=False,
+ ):
"""Return a getter that will extend a cache key with new entries
from the ``__closure__`` collection of a particular lambda.
@@ -706,29 +805,90 @@ class AnalyzedCode(object):
if isinstance(cell_contents, traversals.HasCacheKey):
- def get(closure, kw, anon_map, bindparams):
- return closure[idx].cell_contents._gen_cache_key(
- anon_map, bindparams
- )
+ def get(closure, opts, anon_map, bindparams):
+
+ obj = closure[idx].cell_contents
+ if use_inspect:
+ obj = inspection.inspect(obj)
+ elif use_clause_element:
+ while hasattr(obj, "__clause_element__"):
+ if not getattr(obj, "is_clause_element", False):
+ obj = obj.__clause_element__()
+
+ return obj._gen_cache_key(anon_map, bindparams)
elif isinstance(cell_contents, types.FunctionType):
- def get(closure, kw, anon_map, bindparams):
+ def get(closure, opts, anon_map, bindparams):
return closure[idx].cell_contents.__code__
- elif cell_contents.__hash__ is None:
- # this covers dict, etc.
- def get(closure, kw, anon_map, bindparams):
- return ()
+ elif isinstance(cell_contents, collections_abc.Sequence):
+
+ def get(closure, opts, anon_map, bindparams):
+ contents = closure[idx].cell_contents
+
+ try:
+ return tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in contents
+ )
+ except AttributeError as ae:
+ self._raise_for_uncacheable_closure_variable(
+ variable_name, fn, from_=ae
+ )
else:
+ # if the object is a mapped class or aliased class, or some
+ # other object in the ORM realm of things like that, imitate
+ # the logic used in coercions.expect() to roll it down to the
+ # SQL element
+ element = cell_contents
+ is_clause_element = False
+ while hasattr(element, "__clause_element__"):
+ is_clause_element = True
+ if not getattr(element, "is_clause_element", False):
+ element = element.__clause_element__()
+ else:
+ break
- def get(closure, kw, anon_map, bindparams):
- return closure[idx].cell_contents
+ if not is_clause_element:
+ insp = inspection.inspect(element, raiseerr=False)
+ if insp is not None:
+ return self._cache_key_getter_closure_variable(
+ fn, variable_name, idx, insp, use_inspect=True
+ )
+ else:
+ return self._cache_key_getter_closure_variable(
+ fn, variable_name, idx, element, use_clause_element=True
+ )
+
+ self._raise_for_uncacheable_closure_variable(variable_name, fn)
return get
- def _cache_key_getter_tracked_literal(self, pytracker):
+ def _raise_for_uncacheable_closure_variable(
+ self, variable_name, fn, from_=None
+ ):
+ util.raise_(
+ exc.InvalidRequestError(
+ "Closure variable named '%s' inside of lambda callable %s "
+ "does not refer to a cachable SQL element, and also does not "
+ "appear to be serving as a SQL literal bound value based on "
+ "the default "
+ "SQL expression returned by the function. This variable "
+ "needs to remain outside the scope of a SQL-generating lambda "
+ "so that a proper cache key may be generated from the "
+ "lambda's state. Evaluate this variable outside of the "
+ "lambda, set track_on=[<elements>] to explicitly select "
+ "closure elements to track, or set "
+ "track_closure_variables=False to exclude "
+ "closure variables from being part of the cache key."
+ % (variable_name, fn.__code__),
+ ),
+ from_=from_,
+ )
+
+ def _cache_key_getter_tracked_literal(self, fn, pytracker):
"""Return a getter that will extend a cache key with new entries
from the ``__closure__`` collection of a particular lambda.
@@ -741,33 +901,11 @@ class AnalyzedCode(object):
elem = pytracker._sa__to_evaluate
closure_index = pytracker._sa__closure_index
+ variable_name = pytracker._sa__name
- if isinstance(elem, set):
- raise exc.ArgumentError(
- "Can't create a cache key for lambda closure variable "
- '"%s" because it\'s a set. try using a list'
- % pytracker._sa__name
- )
-
- elif isinstance(elem, list):
-
- def get(closure, kw, anon_map, bindparams):
- return tuple(
- elem._gen_cache_key(anon_map, bindparams)
- for elem in closure[closure_index].cell_contents
- )
-
- elif elem.__hash__ is None:
- # this covers dict, etc.
- def get(closure, kw, anon_map, bindparams):
- return ()
-
- else:
-
- def get(closure, kw, anon_map, bindparams):
- return closure[closure_index].cell_contents
-
- return get
+ return self._cache_key_getter_closure_variable(
+ fn, variable_name, closure_index, elem
+ )
class AnalyzedFunction(object):
@@ -789,7 +927,6 @@ class AnalyzedFunction(object):
analyzed_code,
lambda_element,
apply_propagate_attrs,
- kw,
fn,
):
self.analyzed_code = analyzed_code
@@ -799,7 +936,7 @@ class AnalyzedFunction(object):
self._instrument_and_run_function(lambda_element)
- self._coerce_expression(lambda_element, apply_propagate_attrs, kw)
+ self._coerce_expression(lambda_element, apply_propagate_attrs)
def _instrument_and_run_function(self, lambda_element):
analyzed_code = self.analyzed_code
@@ -832,13 +969,19 @@ class AnalyzedFunction(object):
if closure_index is not None:
value = closure[closure_index].cell_contents
new_closure[name] = bind = PyWrapper(
- name, value, closure_index=closure_index
+ fn,
+ name,
+ value,
+ closure_index=closure_index,
+ track_bound_values=(
+ self.analyzed_code.track_bound_values
+ ),
)
if track_closure_variables:
closure_pywrappers.append(bind)
else:
value = fn.__globals__[name]
- new_globals[name] = bind = PyWrapper(name, value)
+ new_globals[name] = bind = PyWrapper(fn, name, value)
# rewrite the original fn. things that look like they will
# become bound parameters are wrapped in a PyWrapper.
@@ -863,7 +1006,7 @@ class AnalyzedFunction(object):
# variable.
self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
- def _coerce_expression(self, lambda_element, apply_propagate_attrs, kw):
+ def _coerce_expression(self, lambda_element, apply_propagate_attrs):
"""Run the tracker-generated expression through coercion rules.
After the user-defined lambda has been invoked to produce a statement
@@ -882,7 +1025,6 @@ class AnalyzedFunction(object):
lambda_element.role,
sub_expr,
apply_propagate_attrs=apply_propagate_attrs,
- **kw
)
for sub_expr in expr
]
@@ -892,7 +1034,6 @@ class AnalyzedFunction(object):
lambda_element.role,
expr,
apply_propagate_attrs=apply_propagate_attrs,
- **kw
)
self.is_sequence = False
else:
@@ -956,7 +1097,16 @@ class PyWrapper(ColumnOperators):
"""
- def __init__(self, name, to_evaluate, closure_index=None, getter=None):
+ def __init__(
+ self,
+ fn,
+ name,
+ to_evaluate,
+ closure_index=None,
+ getter=None,
+ track_bound_values=True,
+ ):
+ self.fn = fn
self._name = name
self._to_evaluate = to_evaluate
self._param = None
@@ -964,28 +1114,35 @@ class PyWrapper(ColumnOperators):
self._bind_paths = {}
self._getter = getter
self._closure_index = closure_index
+ self.track_bound_values = track_bound_values
def __call__(self, *arg, **kw):
elem = object.__getattribute__(self, "_to_evaluate")
value = elem(*arg, **kw)
- if coercions._deep_is_literal(value) and not isinstance(
- # TODO: coverage where an ORM option or similar is here
- value,
- traversals.HasCacheKey,
+ if (
+ self._sa_track_bound_values
+ and coercions._deep_is_literal(value)
+ and not isinstance(
+ # TODO: coverage where an ORM option or similar is here
+ value,
+ traversals.HasCacheKey,
+ )
):
- # TODO: we can instead scan the arguments and make sure they
- # are all Python literals
-
- # TODO: coverage
name = object.__getattribute__(self, "_name")
raise exc.InvalidRequestError(
"Can't invoke Python callable %s() inside of lambda "
- "expression argument; lambda cache keys should not call "
- "regular functions since the caching "
- "system does not track the values of the arguments passed "
- "to the functions. Call the function outside of the lambda "
- "and assign to a local variable that is used in the lambda."
- % (name)
+ "expression argument at %s; lambda SQL constructs should "
+ "not invoke functions from closure variables to produce "
+ "literal values since the "
+ "lambda SQL system normally extracts bound values without "
+ "actually "
+ "invoking the lambda or any functions within it. Call the "
+ "function outside of the "
+ "lambda and assign to a local variable that is used in the "
+ "lambda as a closure variable, or set "
+ "track_bound_values=False if the return value of this "
+ "function is used in some other way other than a SQL bound "
+ "value." % (name, self._sa_fn.__code__)
)
else:
return value
@@ -1018,6 +1175,14 @@ class PyWrapper(ColumnOperators):
param.type = type_api._resolve_value_to_type(to_evaluate)
return param._with_value(to_evaluate, maintain_key=True)
+ def __bool__(self):
+ to_evaluate = object.__getattribute__(self, "_to_evaluate")
+ return bool(to_evaluate)
+
+ def __nonzero__(self):
+ to_evaluate = object.__getattribute__(self, "_to_evaluate")
+ return bool(to_evaluate)
+
def __getattribute__(self, key):
if key.startswith("_sa_"):
return object.__getattribute__(self, key[4:])
@@ -1026,6 +1191,7 @@ class PyWrapper(ColumnOperators):
"operate",
"reverse_operate",
"__class__",
+ "__dict__",
):
return object.__getattribute__(self, key)
@@ -1064,8 +1230,10 @@ class PyWrapper(ColumnOperators):
elem = object.__getattribute__(self, "_to_evaluate")
value = getter(elem)
- if coercions._deep_is_literal(value):
- wrapper = PyWrapper(key, value, getter=getter)
+ rolled_down_value = AnalyzedCode._roll_down_to_literal(value)
+
+ if coercions._deep_is_literal(rolled_down_value):
+ wrapper = PyWrapper(self._sa_fn, key, value, getter=getter)
bind_paths[bind_path_key] = wrapper
return wrapper
else: