summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/baked.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/baked.py')
-rw-r--r--lib/sqlalchemy/ext/baked.py47
1 files changed, 42 insertions, 5 deletions
diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py
index c0fe963ac..8cae6e24b 100644
--- a/lib/sqlalchemy/ext/baked.py
+++ b/lib/sqlalchemy/ext/baked.py
@@ -261,12 +261,13 @@ class Result(object):
against a target :class:`.Session`, and is then invoked for results.
"""
- __slots__ = 'bq', 'session', '_params'
+ __slots__ = 'bq', 'session', '_params', '_post_criteria'
def __init__(self, bq, session):
self.bq = bq
self.session = session
self._params = {}
+ self._post_criteria = []
def params(self, *args, **kw):
"""Specify parameters to be replaced into the string SQL statement."""
@@ -280,8 +281,37 @@ class Result(object):
self._params.update(kw)
return self
+ def _using_post_criteria(self, fns):
+ if fns:
+ self._post_criteria.extend(fns)
+ return self
+
+ def with_post_criteria(self, fn):
+ """Add a criteria function that will be applied post-cache.
+
+ This adds a function that will be run against the
+ :class:`.Query` object after it is retrieved from the
+ cache. Functions here can be used to alter the query in ways
+ that **do not affect the SQL output**, such as execution options
+ and shard identifiers (when using a shard-enabled query object)
+
+ .. warning:: :meth:`.Result.with_post_criteria` functions are applied
+ to the :class:`.Query` object **after** the query's SQL statement
+ object has been retrieved from the cache. Any operations here
+ which intend to modify the SQL should ensure that
+ :meth:`.BakedQuery.spoil` was called first.
+
+ .. versionadded:: 1.2
+
+
+ """
+ return self._using_post_criteria([fn])
+
def _as_query(self):
- return self.bq._as_query(self.session).params(self._params)
+ q = self.bq._as_query(self.session).params(self._params)
+ for fn in self._post_criteria:
+ q = fn(q)
+ return q
def __str__(self):
return str(self._as_query())
@@ -304,8 +334,11 @@ class Result(object):
context.statement.use_labels = True
if context.autoflush and not context.populate_existing:
self.session._autoflush()
- return context.query.params(self._params).\
- with_session(self.session)._execute_and_instances(context)
+ q = context.query.params(self._params).with_session(self.session)
+ for fn in self._post_criteria:
+ q = fn(q)
+
+ return q._execute_and_instances(context)
def count(self):
"""return the 'count'.
@@ -348,7 +381,9 @@ class Result(object):
"""
bq = self.bq.with_criteria(lambda q: q.slice(0, 1))
- ret = list(bq.for_session(self.session).params(self._params))
+ ret = list(
+ bq.for_session(self.session).params(self._params).
+ _using_post_criteria(self._post_criteria))
if len(ret) > 0:
return ret[0]
else:
@@ -435,6 +470,8 @@ class Result(object):
_lcl_get_clause = q._adapt_clause(_lcl_get_clause, True, False)
q._criterion = _lcl_get_clause
+ for fn in self._post_criteria:
+ q = fn(q)
return q
# cache the query against a key that includes