summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-02-09 21:16:53 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2012-02-09 21:16:53 -0500
commitbc45fa350a02da5f24d866078abed471cd98f15b (patch)
tree2607af2197e003fdc735c020207d4f234d718fee /lib/sqlalchemy/sql/util.py
parent91f4109dc3ec49686ba2393eb6b7bd9bb5b95fb3 (diff)
downloadsqlalchemy-bc45fa350a02da5f24d866078abed471cd98f15b.tar.gz
- got m2m, local_remote_pairs, etc. working
- using new traversal that returns the product of both sides of a binary, starting to work with (a+b) == (c+d) types of joins. primaryjoins on functions working - annotations working, including reversing local/remote when doing backref
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py76
1 files changed, 70 insertions, 6 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index f0509c16f..9a45a5777 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -62,6 +62,61 @@ def find_join_source(clauses, join_to):
else:
return None, None
+
+def visit_binary_product(fn, expr):
+ """Produce a traversal of the given expression, delivering
+ column comparisons to the given function.
+
+ The function is of the form::
+
+ def my_fn(binary, left, right)
+
+ For each binary expression located which has a
+ comparison operator, the product of "left" and
+ "right" will be delivered to that function,
+ in terms of that binary.
+
+ Hence an expression like::
+
+ and_(
+ (a + b) == q + func.sum(e + f),
+ j == r
+ )
+
+ would have the traversal::
+
+ a <eq> q
+ a <eq> e
+ a <eq> f
+ b <eq> q
+ b <eq> e
+ b <eq> f
+ j <eq> r
+
+ That is, every combination of "left" and
+ "right" that doesn't further contain
+ a binary comparison is passed as pairs.
+
+ """
+ stack = []
+ def visit(element):
+ if element.__visit_name__ == 'binary' and \
+ operators.is_comparison(element.operator):
+ stack.insert(0, element)
+ for l in visit(element.left):
+ for r in visit(element.right):
+ fn(stack[0], l, r)
+ stack.pop(0)
+ for elem in element.get_children():
+ visit(elem)
+ else:
+ if isinstance(element, expression.ColumnClause):
+ yield element
+ for elem in element.get_children():
+ for e in visit(elem):
+ yield e
+ list(visit(expr))
+
def find_tables(clause, check_columns=False,
include_aliases=False, include_joins=False,
include_selects=False, include_crud=False):
@@ -357,13 +412,22 @@ class Annotated(object):
def _annotate(self, values):
_values = self._annotations.copy()
_values.update(values)
+ return self._with_annotations(_values)
+
+ def _with_annotations(self, values):
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
- clone._annotations = _values
+ clone._annotations = values
return clone
- def _deannotate(self):
- return self.__element
+ def _deannotate(self, values=None):
+ if values is None:
+ return self.__element
+ else:
+ _values = self._annotations.copy()
+ for v in values:
+ _values.pop(v, None)
+ return self._with_annotations(_values)
def _compiler_dispatch(self, visitor, **kw):
return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
@@ -426,11 +490,11 @@ def _deep_annotate(element, annotations, exclude=None):
element = clone(element)
return element
-def _deep_deannotate(element):
- """Deep copy the given element, removing all annotations."""
+def _deep_deannotate(element, values=None):
+ """Deep copy the given element, removing annotations."""
def clone(elem):
- elem = elem._deannotate()
+ elem = elem._deannotate(values=values)
elem._copy_internals(clone=clone)
return elem