summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/util.py')
-rw-r--r--lib/sqlalchemy/orm/util.py187
1 files changed, 184 insertions, 3 deletions
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 3b3b9b7ed..d248c0dd0 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -4,7 +4,8 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, util, exceptions
+from sqlalchemy import sql, util, exceptions, sql_util
+from sqlalchemy.orm.interfaces import MapperExtension, EXT_PASS
all_cascades = util.Set(["delete", "delete-orphan", "all", "merge",
"expunge", "save-update", "refresh-expire", "none"])
@@ -89,8 +90,6 @@ class TranslatingDict(dict):
def __translate_col(self, col):
ourcol = self.selectable.corresponding_column(col, keys_ok=False, raiseerr=False)
-# if col is not ourcol and ourcol is not None:
-# print "TD TRANSLATING ", col, "TO", ourcol
if ourcol is None:
return col
else:
@@ -111,6 +110,56 @@ class TranslatingDict(dict):
def setdefault(self, col, value):
return super(TranslatingDict, self).setdefault(self.__translate_col(col), value)
+class ExtensionCarrier(MapperExtension):
+ def __init__(self, _elements=None):
+ self.__elements = _elements or []
+
+ def copy(self):
+ return ExtensionCarrier(list(self.__elements))
+
+ def __iter__(self):
+ return iter(self.__elements)
+
+ def insert(self, extension):
+ """Insert a MapperExtension at the beginning of this ExtensionCarrier's list."""
+
+ self.__elements.insert(0, extension)
+
+ def append(self, extension):
+ """Append a MapperExtension at the end of this ExtensionCarrier's list."""
+
+ self.__elements.append(extension)
+
+ def _create_do(funcname):
+ def _do(self, *args, **kwargs):
+ for elem in self.__elements:
+ ret = getattr(elem, funcname)(*args, **kwargs)
+ if ret is not EXT_PASS:
+ return ret
+ else:
+ return EXT_PASS
+ return _do
+
+ init_instance = _create_do('init_instance')
+ init_failed = _create_do('init_failed')
+ dispose_class = _create_do('dispose_class')
+ get_session = _create_do('get_session')
+ load = _create_do('load')
+ get = _create_do('get')
+ get_by = _create_do('get_by')
+ select_by = _create_do('select_by')
+ select = _create_do('select')
+ translate_row = _create_do('translate_row')
+ create_instance = _create_do('create_instance')
+ append_result = _create_do('append_result')
+ populate_instance = _create_do('populate_instance')
+ before_insert = _create_do('before_insert')
+ before_update = _create_do('before_update')
+ after_update = _create_do('after_update')
+ after_insert = _create_do('after_insert')
+ before_delete = _create_do('before_delete')
+ after_delete = _create_do('after_delete')
+
class BinaryVisitor(sql.ClauseVisitor):
def __init__(self, func):
self.func = func
@@ -118,6 +167,138 @@ class BinaryVisitor(sql.ClauseVisitor):
def visit_binary(self, binary):
self.func(binary)
+class AliasedClauses(object):
+ """Creates aliases of a mapped tables for usage in ORM queries.
+ """
+
+ def __init__(self, mapped_table, alias=None):
+ if alias:
+ self.alias = alias
+ else:
+ self.alias = mapped_table.alias()
+ self.mapped_table = mapped_table
+ self.extra_cols = {}
+ self.row_decorator = self._create_row_adapter()
+
+ def aliased_column(self, column):
+ """return the aliased version of the given column, creating a new label for it if not already
+ present in this AliasedClauses."""
+
+ conv = self.alias.corresponding_column(column, raiseerr=False)
+ if conv:
+ return conv
+
+ if column in self.extra_cols:
+ return self.extra_cols[column]
+
+ aliased_column = column
+ # for column-level subqueries, swap out its selectable with our
+ # eager version as appropriate, and manually build the
+ # "correlation" list of the subquery.
+ class ModifySubquery(sql.ClauseVisitor):
+ def visit_select(s, select):
+ select._should_correlate = False
+ select.append_correlation(self.alias)
+ aliased_column = sql_util.ClauseAdapter(self.alias).chain(ModifySubquery()).traverse(aliased_column, clone=True)
+ aliased_column = aliased_column.label(None)
+ self.row_decorator.map[column] = aliased_column
+ # TODO: this is a little hacky
+ for attr in ('name', '_label'):
+ if hasattr(column, attr):
+ self.row_decorator.map[getattr(column, attr)] = aliased_column
+ self.extra_cols[column] = aliased_column
+ return aliased_column
+
+ def adapt_clause(self, clause):
+ return self.aliased_column(clause)
+# return sql_util.ClauseAdapter(self.alias).traverse(clause, clone=True)
+
+ def _create_row_adapter(self):
+ """Return a callable which,
+ when passed a RowProxy, will return a new dict-like object
+ that translates Column objects to that of this object's Alias before calling upon the row.
+
+ This allows a regular Table to be used to target columns in a row that was in reality generated from an alias
+ of that table, in such a way that the row can be passed to logic which knows nothing about the aliased form
+ of the table.
+ """
+ class AliasedRowAdapter(object):
+ def __init__(self, row):
+ self.row = row
+ def __contains__(self, key):
+ return key in map or key in self.row
+ def has_key(self, key):
+ return key in self
+ def __getitem__(self, key):
+ if key in map:
+ key = map[key]
+ return self.row[key]
+ def keys(self):
+ return map.keys()
+ map = {}
+ for c in self.alias.c:
+ parent = self.mapped_table.corresponding_column(c)
+ map[parent] = c
+ map[parent._label] = c
+ map[parent.name] = c
+ for c in self.extra_cols:
+ map[c] = self.extra_cols[c]
+ # TODO: this is a little hacky
+ for attr in ('name', '_label'):
+ if hasattr(c, attr):
+ map[getattr(c, attr)] = self.extra_cols[c]
+
+ AliasedRowAdapter.map = map
+ return AliasedRowAdapter
+
+
+class PropertyAliasedClauses(AliasedClauses):
+ """extends AliasedClauses to add support for primary/secondary joins on a relation()."""
+
+ def __init__(self, prop, primaryjoin, secondaryjoin, parentclauses=None):
+ super(PropertyAliasedClauses, self).__init__(prop.select_table)
+
+ self.parentclauses = parentclauses
+ if parentclauses is not None:
+ self.path = parentclauses.path + (prop.parent, prop.key)
+ else:
+ self.path = (prop.parent, prop.key)
+
+ self.prop = prop
+
+ if prop.secondary:
+ self.secondary = prop.secondary.alias()
+ if parentclauses is not None:
+ aliasizer = sql_util.ClauseAdapter(self.alias).\
+ chain(sql_util.ClauseAdapter(self.secondary)).\
+ chain(sql_util.ClauseAdapter(parentclauses.alias))
+ else:
+ aliasizer = sql_util.ClauseAdapter(self.alias).\
+ chain(sql_util.ClauseAdapter(self.secondary))
+ self.secondaryjoin = aliasizer.traverse(secondaryjoin, clone=True)
+ self.primaryjoin = aliasizer.traverse(primaryjoin, clone=True)
+ else:
+ if parentclauses is not None:
+ aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side)
+ aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side))
+ else:
+ aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side)
+ self.primaryjoin = aliasizer.traverse(primaryjoin, clone=True)
+ self.secondary = None
+ self.secondaryjoin = None
+
+ if prop.order_by:
+ self.order_by = sql_util.ClauseAdapter(self.alias).copy_and_process(util.to_list(prop.order_by))
+ else:
+ self.order_by = None
+
+ mapper = property(lambda self:self.prop.mapper)
+ table = property(lambda self:self.prop.select_table)
+
+ def __str__(self):
+ return "->".join([str(s) for s in self.path])
+
+
def instance_str(instance):
"""Return a string describing an instance."""