diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-02-09 21:16:53 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-02-09 21:16:53 -0500 |
commit | bc45fa350a02da5f24d866078abed471cd98f15b (patch) | |
tree | 2607af2197e003fdc735c020207d4f234d718fee /lib/sqlalchemy/sql/util.py | |
parent | 91f4109dc3ec49686ba2393eb6b7bd9bb5b95fb3 (diff) | |
download | sqlalchemy-bc45fa350a02da5f24d866078abed471cd98f15b.tar.gz |
- got m2m, local_remote_pairs, etc. working
- using new traversal that returns the product of both sides
of a binary, starting to work with (a+b) == (c+d) types of joins.
primaryjoins on functions working
- annotations working, including reversing local/remote when
doing backref
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 76 |
1 files changed, 70 insertions, 6 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index f0509c16f..9a45a5777 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -62,6 +62,61 @@ def find_join_source(clauses, join_to): else: return None, None + +def visit_binary_product(fn, expr): + """Produce a traversal of the given expression, delivering + column comparisons to the given function. + + The function is of the form:: + + def my_fn(binary, left, right) + + For each binary expression located which has a + comparison operator, the product of "left" and + "right" will be delivered to that function, + in terms of that binary. + + Hence an expression like:: + + and_( + (a + b) == q + func.sum(e + f), + j == r + ) + + would have the traversal:: + + a <eq> q + a <eq> e + a <eq> f + b <eq> q + b <eq> e + b <eq> f + j <eq> r + + That is, every combination of "left" and + "right" that doesn't further contain + a binary comparison is passed as pairs. + + """ + stack = [] + def visit(element): + if element.__visit_name__ == 'binary' and \ + operators.is_comparison(element.operator): + stack.insert(0, element) + for l in visit(element.left): + for r in visit(element.right): + fn(stack[0], l, r) + stack.pop(0) + for elem in element.get_children(): + visit(elem) + else: + if isinstance(element, expression.ColumnClause): + yield element + for elem in element.get_children(): + for e in visit(elem): + yield e + list(visit(expr)) + def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False, include_crud=False): @@ -357,13 +412,22 @@ class Annotated(object): def _annotate(self, values): _values = self._annotations.copy() _values.update(values) + return self._with_annotations(_values) + + def _with_annotations(self, values): clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() - clone._annotations = _values + clone._annotations = values return clone - def _deannotate(self): - return self.__element + def _deannotate(self, values=None): + if values is None: + return self.__element + else: + _values = self._annotations.copy() + for v in values: + _values.pop(v, None) + return self._with_annotations(_values) def _compiler_dispatch(self, visitor, **kw): return self.__element.__class__._compiler_dispatch(self, visitor, **kw) @@ -426,11 +490,11 @@ def _deep_annotate(element, annotations, exclude=None): element = clone(element) return element -def _deep_deannotate(element): - """Deep copy the given element, removing all annotations.""" +def _deep_deannotate(element, values=None): + """Deep copy the given element, removing annotations.""" def clone(elem): - elem = elem._deannotate() + elem = elem._deannotate(values=values) elem._copy_internals(clone=clone) return elem |