summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py164
1 files changed, 128 insertions, 36 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 8d2b5ecfd..cb8359048 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -62,6 +62,65 @@ def find_join_source(clauses, join_to):
else:
return None, None
+
+def visit_binary_product(fn, expr):
+ """Produce a traversal of the given expression, delivering
+ column comparisons to the given function.
+
+ The function is of the form::
+
+ def my_fn(binary, left, right)
+
+ For each binary expression located which has a
+ comparison operator, the product of "left" and
+ "right" will be delivered to that function,
+ in terms of that binary.
+
+ Hence an expression like::
+
+ and_(
+ (a + b) == q + func.sum(e + f),
+ j == r
+ )
+
+ would have the traversal::
+
+ a <eq> q
+ a <eq> e
+ a <eq> f
+ b <eq> q
+ b <eq> e
+ b <eq> f
+ j <eq> r
+
+ That is, every combination of "left" and
+ "right" that doesn't further contain
+ a binary comparison is passed as pairs.
+
+ """
+ stack = []
+ def visit(element):
+ if isinstance(element, (expression._ScalarSelect)):
+ # we dont want to dig into correlated subqueries,
+ # those are just column elements by themselves
+ yield element
+ elif element.__visit_name__ == 'binary' and \
+ operators.is_comparison(element.operator):
+ stack.insert(0, element)
+ for l in visit(element.left):
+ for r in visit(element.right):
+ fn(stack[0], l, r)
+ stack.pop(0)
+ for elem in element.get_children():
+ visit(elem)
+ else:
+ if isinstance(element, expression.ColumnClause):
+ yield element
+ for elem in element.get_children():
+ for e in visit(elem):
+ yield e
+ list(visit(expr))
+
def find_tables(clause, check_columns=False,
include_aliases=False, include_joins=False,
include_selects=False, include_crud=False):
@@ -225,7 +284,10 @@ def adapt_criterion_to_null(crit, nulls):
return visitors.cloned_traverse(crit, {}, {'binary':visit_binary})
-def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None):
+
+def join_condition(a, b, ignore_nonexistent_tables=False,
+ a_subset=None,
+ consider_as_foreign_keys=None):
"""create a join condition between two tables or selectables.
e.g.::
@@ -261,6 +323,9 @@ def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None):
for fk in sorted(
b.foreign_keys,
key=lambda fk:fk.parent._creation_order):
+ if consider_as_foreign_keys is not None and \
+ fk.parent not in consider_as_foreign_keys:
+ continue
try:
col = fk.get_referent(left)
except exc.NoReferenceError, nrte:
@@ -276,6 +341,9 @@ def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None):
for fk in sorted(
left.foreign_keys,
key=lambda fk:fk.parent._creation_order):
+ if consider_as_foreign_keys is not None and \
+ fk.parent not in consider_as_foreign_keys:
+ continue
try:
col = fk.get_referent(b)
except exc.NoReferenceError, nrte:
@@ -298,11 +366,11 @@ def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None):
"subquery using alias()?"
else:
hint = ""
- raise exc.ArgumentError(
+ raise exc.NoForeignKeysError(
"Can't find any foreign key relationships "
"between '%s' and '%s'.%s" % (a.description, b.description, hint))
elif len(constraints) > 1:
- raise exc.ArgumentError(
+ raise exc.AmbiguousForeignKeysError(
"Can't determine join between '%s' and '%s'; "
"tables have more than one foreign key "
"constraint relationship between them. "
@@ -356,13 +424,22 @@ class Annotated(object):
def _annotate(self, values):
_values = self._annotations.copy()
_values.update(values)
+ return self._with_annotations(_values)
+
+ def _with_annotations(self, values):
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
- clone._annotations = _values
+ clone._annotations = values
return clone
- def _deannotate(self):
- return self.__element
+ def _deannotate(self, values=None, clone=True):
+ if values is None:
+ return self.__element
+ else:
+ _values = self._annotations.copy()
+ for v in values:
+ _values.pop(v, None)
+ return self._with_annotations(_values)
def _compiler_dispatch(self, visitor, **kw):
return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
@@ -410,14 +487,8 @@ def _deep_annotate(element, annotations, exclude=None):
Elements within the exclude collection will be cloned but not annotated.
"""
- cloned = util.column_dict()
-
def clone(elem):
- # check if element is present in the exclude list.
- # take into account proxying relationships.
- if elem in cloned:
- return cloned[elem]
- elif exclude and \
+ if exclude and \
hasattr(elem, 'proxy_set') and \
elem.proxy_set.intersection(exclude):
newelem = elem._clone()
@@ -426,24 +497,32 @@ def _deep_annotate(element, annotations, exclude=None):
else:
newelem = elem
newelem._copy_internals(clone=clone)
- cloned[elem] = newelem
return newelem
if element is not None:
element = clone(element)
return element
-def _deep_deannotate(element):
- """Deep copy the given element, removing all annotations."""
+def _deep_deannotate(element, values=None):
+ """Deep copy the given element, removing annotations."""
cloned = util.column_dict()
def clone(elem):
- if elem not in cloned:
- newelem = elem._deannotate()
+ # if a values dict is given,
+ # the elem must be cloned each time it appears,
+ # as there may be different annotations in source
+ # elements that are remaining. if totally
+ # removing all annotations, can assume the same
+ # slate...
+ if values or elem not in cloned:
+ newelem = elem._deannotate(values=values, clone=True)
newelem._copy_internals(clone=clone)
- cloned[elem] = newelem
- return cloned[elem]
+ if not values:
+ cloned[elem] = newelem
+ return newelem
+ else:
+ return cloned[elem]
if element is not None:
element = clone(element)
@@ -547,6 +626,10 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
"'consider_as_foreign_keys' or "
"'consider_as_referenced_keys'")
+ def col_is(a, b):
+ #return a is b
+ return a.compare(b)
+
def visit_binary(binary):
if not any_operator and binary.operator is not operators.eq:
return
@@ -556,20 +639,20 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
if consider_as_foreign_keys:
if binary.left in consider_as_foreign_keys and \
- (binary.right is binary.left or
+ (col_is(binary.right, binary.left) or
binary.right not in consider_as_foreign_keys):
pairs.append((binary.right, binary.left))
elif binary.right in consider_as_foreign_keys and \
- (binary.left is binary.right or
+ (col_is(binary.left, binary.right) or
binary.left not in consider_as_foreign_keys):
pairs.append((binary.left, binary.right))
elif consider_as_referenced_keys:
if binary.left in consider_as_referenced_keys and \
- (binary.right is binary.left or
+ (col_is(binary.right, binary.left) or
binary.right not in consider_as_referenced_keys):
pairs.append((binary.left, binary.right))
elif binary.right in consider_as_referenced_keys and \
- (binary.left is binary.right or
+ (col_is(binary.left, binary.right) or
binary.left not in consider_as_referenced_keys):
pairs.append((binary.right, binary.left))
else:
@@ -681,11 +764,22 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
s.c.col1 == table2.c.col1
"""
- def __init__(self, selectable, equivalents=None, include=None, exclude=None, adapt_on_names=False):
+ def __init__(self, selectable, equivalents=None,
+ include=None, exclude=None,
+ include_fn=None, exclude_fn=None,
+ adapt_on_names=False):
self.__traverse_options__ = {'stop_on':[selectable]}
self.selectable = selectable
- self.include = include
- self.exclude = exclude
+ if include:
+ assert not include_fn
+ self.include_fn = lambda e: e in include
+ else:
+ self.include_fn = include_fn
+ if exclude:
+ assert not exclude_fn
+ self.exclude_fn = lambda e: e in exclude
+ else:
+ self.exclude_fn = exclude_fn
self.equivalents = util.column_dict(equivalents or {})
self.adapt_on_names = adapt_on_names
@@ -705,19 +799,17 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
return newcol
def replace(self, col):
- if isinstance(col, expression.FromClause):
- if self.selectable.is_derived_from(col):
+ if isinstance(col, expression.FromClause) and \
+ self.selectable.is_derived_from(col):
return self.selectable
-
- if not isinstance(col, expression.ColumnElement):
+ elif not isinstance(col, expression.ColumnElement):
return None
-
- if self.include and col not in self.include:
+ elif self.include_fn and not self.include_fn(col):
return None
- elif self.exclude and col in self.exclude:
+ elif self.exclude_fn and self.exclude_fn(col):
return None
-
- return self._corresponding_column(col, True)
+ else:
+ return self._corresponding_column(col, True)
class ColumnAdapter(ClauseAdapter):
"""Extends ClauseAdapter with extra utility functions.