summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py342
1 files changed, 342 insertions, 0 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
new file mode 100644
index 000000000..2c7294e66
--- /dev/null
+++ b/lib/sqlalchemy/sql/util.py
@@ -0,0 +1,342 @@
+from sqlalchemy import util, schema, topological
+from sqlalchemy.sql import expression, visitors
+
+"""Utility functions that build upon SQL and Schema constructs."""
+
+class ClauseParameters(object):
+ """Represent a dictionary/iterator of bind parameter key names/values.
+
+ Tracks the original [sqlalchemy.sql#_BindParamClause] objects as well as the
+ keys/position of each parameter, and can return parameters as a
+ dictionary or a list. Will process parameter values according to
+ the ``TypeEngine`` objects present in the ``_BindParamClause`` instances.
+ """
+
+ def __init__(self, dialect, positional=None):
+ self.dialect = dialect
+ self.__binds = {}
+ self.positional = positional or []
+
+ def get_parameter(self, key):
+ return self.__binds[key]
+
+ def set_parameter(self, bindparam, value, name):
+ self.__binds[name] = [bindparam, name, value]
+
+ def get_original(self, key):
+ return self.__binds[key][2]
+
+ def get_type(self, key):
+ return self.__binds[key][0].type
+
+ def get_processors(self):
+ """return a dictionary of bind 'processing' functions"""
+ return dict([
+ (key, value) for key, value in
+ [(
+ key,
+ self.__binds[key][0].bind_processor(self.dialect)
+ ) for key in self.__binds]
+ if value is not None
+ ])
+
+ def get_processed(self, key, processors):
+ return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2]
+
+ def keys(self):
+ return self.__binds.keys()
+
+ def __iter__(self):
+ return iter(self.keys())
+
+ def __getitem__(self, key):
+ (bind, name, value) = self.__binds[key]
+ processor = bind.bind_processor(self.dialect)
+ return processor is not None and processor(value) or value
+
+ def __contains__(self, key):
+ return key in self.__binds
+
+ def set_value(self, key, value):
+ self.__binds[key][2] = value
+
+ def get_original_dict(self):
+ return dict([(name, value) for (b, name, value) in self.__binds.values()])
+
+ def __get_processed(self, key, processors):
+ if key in processors:
+ return processors[key](self.__binds[key][2])
+ else:
+ return self.__binds[key][2]
+
+ def get_raw_list(self, processors):
+ return [self.__get_processed(key, processors) for key in self.positional]
+
+ def get_raw_dict(self, processors, encode_keys=False):
+ if encode_keys:
+ return dict([
+ (
+ key.encode(self.dialect.encoding),
+ self.__get_processed(key, processors)
+ )
+ for key in self.keys()
+ ])
+ else:
+ return dict([
+ (
+ key,
+ self.__get_processed(key, processors)
+ )
+ for key in self.keys()
+ ])
+
+ def __repr__(self):
+ return self.__class__.__name__ + ":" + repr(self.get_original_dict())
+
+
+
+class TableCollection(object):
+ def __init__(self, tables=None):
+ self.tables = tables or []
+
+ def __len__(self):
+ return len(self.tables)
+
+ def __getitem__(self, i):
+ return self.tables[i]
+
+ def __iter__(self):
+ return iter(self.tables)
+
+ def __contains__(self, obj):
+ return obj in self.tables
+
+ def __add__(self, obj):
+ return self.tables + list(obj)
+
+ def add(self, table):
+ self.tables.append(table)
+ if hasattr(self, '_sorted'):
+ del self._sorted
+
+ def sort(self, reverse=False):
+ try:
+ sorted = self._sorted
+ except AttributeError, e:
+ self._sorted = self._do_sort()
+ sorted = self._sorted
+ if reverse:
+ x = sorted[:]
+ x.reverse()
+ return x
+ else:
+ return sorted
+
+ def _do_sort(self):
+ tuples = []
+ class TVisitor(schema.SchemaVisitor):
+ def visit_foreign_key(_self, fkey):
+ if fkey.use_alter:
+ return
+ parent_table = fkey.column.table
+ if parent_table in self:
+ child_table = fkey.parent.table
+ tuples.append( ( parent_table, child_table ) )
+ vis = TVisitor()
+ for table in self.tables:
+ vis.traverse(table)
+ sorter = topological.QueueDependencySorter( tuples, self.tables )
+ head = sorter.sort()
+ sequence = []
+ def to_sequence( node, seq=sequence):
+ seq.append( node.item )
+ for child in node.children:
+ to_sequence( child )
+ if head is not None:
+ to_sequence( head )
+ return sequence
+
+
+class TableFinder(TableCollection, visitors.NoColumnVisitor):
+ """locate all Tables within a clause."""
+
+ def __init__(self, clause, check_columns=False, include_aliases=False):
+ TableCollection.__init__(self)
+ self.check_columns = check_columns
+ self.include_aliases = include_aliases
+ for clause in util.to_list(clause):
+ self.traverse(clause)
+
+ def visit_alias(self, alias):
+ if self.include_aliases:
+ self.tables.append(alias)
+
+ def visit_table(self, table):
+ self.tables.append(table)
+
+ def visit_column(self, column):
+ if self.check_columns:
+ self.tables.append(column.table)
+
+class ColumnFinder(visitors.ClauseVisitor):
+ def __init__(self):
+ self.columns = util.Set()
+
+ def visit_column(self, c):
+ self.columns.add(c)
+
+ def __iter__(self):
+ return iter(self.columns)
+
+class ColumnsInClause(visitors.ClauseVisitor):
+ """Given a selectable, visit clauses and determine if any columns
+ from the clause are in the selectable.
+ """
+
+ def __init__(self, selectable):
+ self.selectable = selectable
+ self.result = False
+
+ def visit_column(self, column):
+ if self.selectable.c.get(column.key) is column:
+ self.result = True
+
+class AbstractClauseProcessor(visitors.NoColumnVisitor):
+ """Traverse a clause and attempt to convert the contents of container elements
+ to a converted element.
+
+ The conversion operation is defined by subclasses.
+ """
+
+ def convert_element(self, elem):
+ """Define the *conversion* method for this ``AbstractClauseProcessor``."""
+
+ raise NotImplementedError()
+
+ def copy_and_process(self, list_):
+ """Copy the container elements in the given list to a new list and
+ process the new list.
+ """
+
+ list_ = list(list_)
+ self.process_list(list_)
+ return list_
+
+ def process_list(self, list_):
+ """Process all elements of the given list in-place."""
+
+ for i in range(0, len(list_)):
+ elem = self.convert_element(list_[i])
+ if elem is not None:
+ list_[i] = elem
+ else:
+ list_[i] = self.traverse(list_[i], clone=True)
+
+ def visit_grouping(self, grouping):
+ elem = self.convert_element(grouping.elem)
+ if elem is not None:
+ grouping.elem = elem
+
+ def visit_clauselist(self, clist):
+ for i in range(0, len(clist.clauses)):
+ n = self.convert_element(clist.clauses[i])
+ if n is not None:
+ clist.clauses[i] = n
+
+ def visit_unary(self, unary):
+ elem = self.convert_element(unary.element)
+ if elem is not None:
+ unary.element = elem
+
+ def visit_binary(self, binary):
+ elem = self.convert_element(binary.left)
+ if elem is not None:
+ binary.left = elem
+ elem = self.convert_element(binary.right)
+ if elem is not None:
+ binary.right = elem
+
+ def visit_join(self, join):
+ elem = self.convert_element(join.left)
+ if elem is not None:
+ join.left = elem
+ elem = self.convert_element(join.right)
+ if elem is not None:
+ join.right = elem
+ join._init_primary_key()
+
+ def visit_select(self, select):
+ fr = util.OrderedSet()
+ for elem in select._froms:
+ n = self.convert_element(elem)
+ if n is not None:
+ fr.add((elem, n))
+ select._recorrelate_froms(fr)
+
+ col = []
+ for elem in select._raw_columns:
+ n = self.convert_element(elem)
+ if n is None:
+ col.append(elem)
+ else:
+ col.append(n)
+ select._raw_columns = col
+
+class ClauseAdapter(AbstractClauseProcessor):
+ """Given a clause (like as in a WHERE criterion), locate columns
+ which are embedded within a given selectable, and changes those
+ columns to be that of the selectable.
+
+ E.g.::
+
+ table1 = Table('sometable', metadata,
+ Column('col1', Integer),
+ Column('col2', Integer)
+ )
+ table2 = Table('someothertable', metadata,
+ Column('col1', Integer),
+ Column('col2', Integer)
+ )
+
+ condition = table1.c.col1 == table2.c.col1
+
+ and make an alias of table1::
+
+ s = table1.alias('foo')
+
+ calling ``ClauseAdapter(s).traverse(condition)`` converts
+ condition to read::
+
+ s.c.col1 == table2.c.col1
+ """
+
+ def __init__(self, selectable, include=None, exclude=None, equivalents=None):
+ self.selectable = selectable
+ self.include = include
+ self.exclude = exclude
+ self.equivalents = equivalents
+
+ def convert_element(self, col):
+ if isinstance(col, expression.FromClause):
+ if self.selectable.is_derived_from(col):
+ return self.selectable
+ if not isinstance(col, expression.ColumnElement):
+ return None
+ if self.include is not None:
+ if col not in self.include:
+ return None
+ if self.exclude is not None:
+ if col in self.exclude:
+ return None
+ newcol = self.selectable.corresponding_column(col, raiseerr=False, require_embedded=True, keys_ok=False)
+ if newcol is None and self.equivalents is not None and col in self.equivalents:
+ for equiv in self.equivalents[col]:
+ newcol = self.selectable.corresponding_column(equiv, raiseerr=False, require_embedded=True, keys_ok=False)
+ if newcol:
+ return newcol
+ #if newcol is None:
+ # self.traverse(col)
+ # return col
+ return newcol
+
+