summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/engine/base.py4
-rw-r--r--lib/sqlalchemy/orm/__init__.py56
-rw-r--r--lib/sqlalchemy/orm/dependency.py1
-rw-r--r--lib/sqlalchemy/orm/interfaces.py35
-rw-r--r--lib/sqlalchemy/orm/mapper.py325
-rw-r--r--lib/sqlalchemy/orm/properties.py152
-rw-r--r--lib/sqlalchemy/orm/query.py939
-rw-r--r--lib/sqlalchemy/orm/strategies.py42
-rw-r--r--lib/sqlalchemy/orm/util.py70
-rw-r--r--lib/sqlalchemy/sql/expression.py9
-rw-r--r--lib/sqlalchemy/sql/util.py83
-rw-r--r--lib/sqlalchemy/sql/visitors.py37
-rw-r--r--lib/sqlalchemy/util.py58
13 files changed, 923 insertions, 888 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 606cddc0a..5e8455b15 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1479,11 +1479,13 @@ class ResultProxy(object):
if isinstance(key, basestring):
key = key.lower()
-
try:
rec = props[key]
except KeyError:
# fallback for targeting a ColumnElement to a textual expression
+ # it would be nice to get rid of this but we make use of it in the case where
+ # you say something like query.options(contains_alias('fooalias')) - the matching
+ # is done on strings
if isinstance(key, expression.ColumnElement):
if key._label.lower() in props:
return props[key._label.lower()]
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 011b6e360..39eb40daa 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -481,9 +481,13 @@ def mapper(class_, local_table=None, *args, **params):
polymorphic_on
Used with mappers in an inheritance relationship, a ``Column``
which will identify the class/mapper combination to be used
- with a particular row. requires the polymorphic_identity
+ with a particular row. Requires the ``polymorphic_identity``
value to be set for all mappers in the inheritance
- hierarchy.
+ hierarchy. The column specified by ``polymorphic_on`` is
+ usually a column that resides directly within the base
+ mapper's mapped table; alternatively, it may be a column
+ that is only present within the <selectable> portion
+ of the ``with_polymorphic`` argument.
_polymorphic_map
Used internally to propagate the full map of polymorphic
@@ -497,7 +501,7 @@ def mapper(class_, local_table=None, *args, **params):
polymorphic_fetch
specifies how subclasses mapped through joined-table
inheritance will be fetched. options are 'union',
- 'select', and 'deferred'. if the select_table argument
+ 'select', and 'deferred'. if the 'with_polymorphic' argument
is present, defaults to 'union', otherwise defaults to
'select'.
@@ -529,12 +533,26 @@ def mapper(class_, local_table=None, *args, **params):
to be used against this mapper's selectable unit. This is
normally simply the primary key of the `local_table`, but
can be overridden here.
-
+
+ with_polymorphic
+ A tuple in the form ``(<classes>, <selectable>)`` indicating the
+ default style of "polymorphic" loading, that is, which tables
+ are queried at once. <classes> is any single or list of mappers
+ and/or classes indicating the inherited classes that should be
+ loaded at once. The special value ``'*'`` may be used to indicate
+ all descending classes should be loaded immediately. The second
+ tuple argument <selectable> indicates a selectable that will be
+ used to query for multiple classes. Normally, it is left as
+ None, in which case this mapper will form an outer join from
+ the base mapper's table to that of all desired sub-mappers.
+ When specified, it provides the selectable to be used for
+ polymorphic loading. When with_polymorphic includes mappers
+ which load from a "concrete" inheriting table, the <selectable>
+ argument is required, since it usually requires more complex
+ UNION queries.
+
select_table
- A [sqlalchemy.schema#Table] or any [sqlalchemy.sql#Selectable]
- which will be used to select instances of this mapper's class.
- usually used to provide polymorphic loading among several
- classes in an inheritance hierarchy.
+ Deprecated. Synonymous with ``with_polymorphic=('*', <selectable>)`.
version_id_col
A ``Column`` which must have an integer type that will be
@@ -691,9 +709,6 @@ def lazyload(name, mapper=None):
return strategies.EagerLazyOption(name, lazy=True, mapper=mapper)
-def fetchmode(name, type):
- return strategies.FetchModeOption(name, type)
-
def noload(name):
"""Return a ``MapperOption`` that will convert the property of the
given name into a non-load.
@@ -715,21 +730,14 @@ def contains_alias(alias):
def __init__(self, alias):
self.alias = alias
if isinstance(self.alias, basestring):
- self.selectable = None
+ self.translator = None
else:
- self.selectable = alias
- self._row_translators = {}
- def get_selectable(self, mapper):
- if self.selectable is None:
- self.selectable = mapper.mapped_table.alias(self.alias)
- return self.selectable
+ self.translator = create_row_adapter(alias)
+
def translate_row(self, mapper, context, row):
- if mapper in self._row_translators:
- return self._row_translators[mapper](row)
- else:
- translator = create_row_adapter(self.get_selectable(mapper), mapper.mapped_table)
- self._row_translators[mapper] = translator
- return translator(row)
+ if not self.translator:
+ self.translator = create_row_adapter(mapper.mapped_table.alias(self.alias))
+ return self.translator(row)
return ExtensionOption(AliasedRow(alias))
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
index b6b6ce940..8519d2260 100644
--- a/lib/sqlalchemy/orm/dependency.py
+++ b/lib/sqlalchemy/orm/dependency.py
@@ -34,7 +34,6 @@ class DependencyProcessor(object):
self.cascade = prop.cascade
self.mapper = prop.mapper
self.parent = prop.parent
- self.association = prop.association
self.secondary = prop.secondary
self.direction = prop.direction
self.is_backref = prop.is_backref
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index ef263da4c..010a8002a 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -417,7 +417,12 @@ class MapperProperty(object):
return operator(self.comparator, value)
class PropComparator(expression.ColumnOperators):
- """defines comparison operations for MapperProperty objects"""
+ """defines comparison operations for MapperProperty objects.
+
+ PropComparator instances should also define an accessor 'property'
+ which returns the MapperProperty associated with this
+ PropComparator.
+ """
def expression_element(self):
return self.clause_element()
@@ -492,6 +497,7 @@ class PropComparator(expression.ColumnOperators):
return self.operate(PropComparator.has_op, criterion, **kwargs)
+
class StrategizedProperty(MapperProperty):
"""A MapperProperty which uses selectable strategies to affect
loading behavior.
@@ -618,20 +624,37 @@ class PropertyOption(MapperOption):
mapper = self.mapper
if isinstance(self.mapper, type):
mapper = class_mapper(mapper)
- if mapper is not query.mapper and mapper not in [q[0] for q in query._entities]:
- raise exceptions.ArgumentError("Can't find entity %s in Query. Current list: %r" % (str(mapper), [str(m) for m in [query.mapper] + query._entities]))
+ if mapper is not query.mapper and mapper not in [q.mapper for q in query._entities]:
+ raise exceptions.ArgumentError("Can't find entity %s in Query. Current list: %r" % (str(mapper), [str(m) for m in query._entities]))
else:
mapper = query.mapper
- for token in self.key.split('.'):
+ if isinstance(self.key, basestring):
+ tokens = self.key.split('.')
+ else:
+ tokens = util.to_list(self.key)
+
+ for token in tokens:
+ if isinstance(token, basestring):
+ prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
+ elif isinstance(token, PropComparator):
+ prop = token.property
+ token = prop.key
+
+ else:
+ raise exceptions.ArgumentError("mapper option expects string key or list of attributes")
+
if current_path and token == current_path[1]:
current_path = current_path[2:]
continue
- prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
+
if prop is None:
return []
path = build_path(mapper, prop.key, path)
l.append(path)
- mapper = getattr(prop, 'mapper', None)
+ if getattr(token, '_of_type', None):
+ mapper = token._of_type
+ else:
+ mapper = getattr(prop, 'mapper', None)
return l
PropertyOption.logger = logging.class_logger(PropertyOption)
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 7e24c27c2..fab7000ed 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -71,6 +71,7 @@ class Mapper(object):
polymorphic_fetch=None,
concrete=False,
select_table=None,
+ with_polymorphic=None,
allow_null_pks=False,
batch=True,
column_prefix=None,
@@ -81,20 +82,9 @@ class Mapper(object):
Mappers are normally constructed via the [sqlalchemy.orm#mapper()]
function. See for details.
+
"""
- if not issubclass(class_, object):
- raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
-
- for table in (local_table, select_table):
- if table and isinstance(table, expression._SelectBaseMixin):
- # some db's, noteably postgres, dont want to select from a select
- # without an alias. also if we make our own alias internally, then
- # the configured properties on the mapper are not matched against the alias
- # we make, theres workarounds but it starts to get really crazy (its crazy enough
- # the SQL that gets generated) so just require an alias
- raise exceptions.ArgumentError("Mapping against a Select object requires that it has a name. Use an alias to give it a name, i.e. s = select(...).alias('myselect')")
-
self.class_ = class_
self.entity_name = entity_name
self.primary_key_argument = primary_key
@@ -105,7 +95,6 @@ class Mapper(object):
self.concrete = concrete
self.single = False
self.inherits = inherits
- self.select_table = select_table
self.local_table = local_table
self.inherit_condition = inherit_condition
self.inherit_foreign_keys = inherit_foreign_keys
@@ -119,9 +108,37 @@ class Mapper(object):
self.column_prefix = column_prefix
self.polymorphic_on = polymorphic_on
self._eager_loaders = util.Set()
- self._row_translators = {}
self._dependency_processors = []
self._clause_adapter = None
+ self._requires_row_aliasing = False
+
+ if not issubclass(class_, object):
+ raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
+
+ self.select_table = select_table
+ if select_table:
+ if with_polymorphic:
+ raise exceptions.ArgumentError("select_table can't be used with with_polymorphic (they define conflicting settings)")
+ self.with_polymorphic = ('*', select_table)
+ else:
+ if with_polymorphic == '*':
+ self.with_polymorphic = ('*', None)
+ elif isinstance(with_polymorphic, (tuple, list)):
+ if isinstance(with_polymorphic[0], (basestring, tuple, list)):
+ self.with_polymorphic = with_polymorphic
+ else:
+ self.with_polymorphic = (with_polymorphic, None)
+ elif with_polymorphic is not None:
+ raise exceptions.ArgumentError("Invalid setting for with_polymorphic")
+ else:
+ self.with_polymorphic = None
+
+ check_tables = [self.local_table]
+ if self.with_polymorphic:
+ check_tables.append(self.with_polymorphic[1])
+ for table in check_tables:
+ if table and isinstance(table, expression._SelectBaseMixin):
+ raise exceptions.ArgumentError("Mapping against a Select object requires that it has a name. Use an alias to give it a name, i.e. s = select(...).alias('myselect')")
# our 'polymorphic identity', a string name that when located in a result set row
# indicates this Mapper should be used to construct the object instance for that row.
@@ -130,7 +147,7 @@ class Mapper(object):
if polymorphic_fetch not in (None, 'union', 'select', 'deferred'):
raise exceptions.ArgumentError("Invalid option for 'polymorphic_fetch': '%s'" % polymorphic_fetch)
if polymorphic_fetch is None:
- self.polymorphic_fetch = (self.select_table is None) and 'select' or 'union'
+ self.polymorphic_fetch = (self.with_polymorphic is None) and 'select' or 'union'
else:
self.polymorphic_fetch = polymorphic_fetch
@@ -149,10 +166,6 @@ class Mapper(object):
# a set of all mappers which inherit from this one.
self._inheriting_mappers = util.Set()
- # a second mapper that is used for selecting, if the "select_table" argument
- # was sent to this mapper.
- self.__surrogate_mapper = None
-
self.__props_init = False
self.__should_log_info = logging.is_info_enabled(self.logger)
@@ -161,10 +174,8 @@ class Mapper(object):
self._compile_class()
self._compile_inheritance()
self._compile_extensions()
- self._compile_tables()
self._compile_properties()
self._compile_pks()
- self._compile_selectable()
global __new_mappers
__new_mappers = True
self.__log("constructed")
@@ -178,10 +189,13 @@ class Mapper(object):
self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (not self.non_primary and "|non-primary" or "") + ") " + msg)
def _is_orphan(self, obj):
- for (key,klass) in self.delete_orphans:
- if attributes.has_parent(klass, obj, key, optimistic=has_identity(obj)):
- return False
- return bool(self.delete_orphans)
+ o = False
+ for mapper in self.iterate_to_root():
+ for (key,klass) in mapper.delete_orphans:
+ if attributes.has_parent(klass, obj, key, optimistic=has_identity(obj)):
+ return False
+ o = o or bool(mapper.delete_orphans)
+ return o
def get_property(self, key, resolve_synonyms=False, raiseerr=True):
"""return a MapperProperty associated with the given key."""
@@ -205,11 +219,90 @@ class Mapper(object):
return self.__props.itervalues()
iterate_properties = property(iterate_properties, doc="returns an iterator of all MapperProperty objects.")
+ def __adjust_wp_selectable(self, spec=None, selectable=False):
+ """given a with_polymorphic() argument, resolve it against this mapper's with_polymorphic setting"""
+
+ isdefault = False
+ if self.with_polymorphic:
+ isdefault = not spec and selectable is False
+
+ if not spec:
+ spec = self.with_polymorphic[0]
+ if selectable is False:
+ selectable = self.with_polymorphic[1]
+
+ return spec, selectable, isdefault
+
+ def __mappers_from_spec(self, spec, selectable):
+ """given a with_polymorphic() argument, return the set of mappers it represents.
+
+ Trims the list of mappers to just those represented within the given selectable, if present.
+ This helps some more legacy-ish mappings.
+
+ """
+ if spec == '*':
+ mappers = list(self.polymorphic_iterator())
+ elif spec:
+ mappers = [_class_to_mapper(m) for m in util.to_list(spec)]
+ else:
+ mappers = []
+
+ if selectable:
+ tables = util.Set(sqlutil.find_tables(selectable))
+ mappers = [m for m in mappers if m.local_table in tables]
+
+ return mappers
+ __mappers_from_spec = util.conditional_cache_decorator(__mappers_from_spec)
+
+ def __selectable_from_mappers(self, mappers):
+ """given a list of mappers (assumed to be within this mapper's inheritance hierarchy),
+ construct an outerjoin amongst those mapper's mapped tables.
+
+ """
+ from_obj = self.mapped_table
+ for m in mappers:
+ if m is self:
+ continue
+ if m.concrete:
+ raise exceptions.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.")
+ elif not m.single:
+ from_obj = from_obj.outerjoin(m.local_table, m.inherit_condition)
+
+ return from_obj
+ __selectable_from_mappers = util.conditional_cache_decorator(__selectable_from_mappers)
+
+ def _with_polymorphic_mappers(self, spec=None, selectable=False):
+ spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable)
+ return self.__mappers_from_spec(spec, selectable, cache=isdefault)
+
+ def _with_polymorphic_selectable(self, spec=None, selectable=False):
+ spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable)
+ if selectable:
+ return selectable
+ else:
+ return self.__selectable_from_mappers(self.__mappers_from_spec(spec, selectable, cache=isdefault), cache=isdefault)
+
+ def _with_polymorphic_args(self, spec=None, selectable=False):
+ spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable)
+ mappers = self.__mappers_from_spec(spec, selectable, cache=isdefault)
+ if selectable:
+ return mappers, selectable
+ else:
+ return mappers, self.__selectable_from_mappers(mappers, cache=isdefault)
+
+ def _iterate_polymorphic_properties(self, spec=None, selectable=False):
+ return iter(util.OrderedSet(
+ chain(*[list(mapper.iterate_properties) for mapper in [self] + self._with_polymorphic_mappers(spec, selectable)])
+ ))
+
def properties(self):
raise NotImplementedError("Public collection of MapperProperty objects is provided by the get_property() and iterate_properties accessors.")
properties = property(properties)
- compiled = property(lambda self:self.__props_init, doc="return True if this mapper is compiled")
+ def compiled(self):
+ """return True if this mapper is compiled"""
+ return self.__props_init
+ compiled = property(compiled)
def dispose(self):
# disaable any attribute-based compilation
@@ -224,7 +317,11 @@ class Mapper(object):
attributes.unregister_class(self.class_)
def compile(self):
- """Compile this mapper into its final internal format.
+ """Compile this mapper and all other non-compiled mappers.
+
+ This method checks the local compiled status as well as for
+ any new mappers that have been defined, and is safe to call
+ repeatedly.
"""
global __new_mappers
@@ -250,10 +347,9 @@ class Mapper(object):
def __initialize_properties(self):
"""Call the ``init()`` method on all ``MapperProperties``
attached to this mapper.
-
- This happens after all mappers have completed compiling
- everything else up until this point, so that all dependencies
- are fully available.
+
+ This is a deferred configuration step which is intended
+ to execute once all mappers have been constructed.
"""
self.__log("_initialize_properties() started")
@@ -298,14 +394,7 @@ class Mapper(object):
self.extension.append(ext)
def _compile_inheritance(self):
- """Determine if this Mapper inherits from another mapper, and
- if so calculates the mapped_table for this Mapper taking the
- inherited mapper into account.
-
- For joined table inheritance, creates a ``SyncRule`` that will
- synchronize column values between the joined tables. also
- initializes polymorphic variables used in polymorphic loads.
- """
+ """Configure settings related to inherting and/or inherited mappers being present."""
if self.inherits:
if isinstance(self.inherits, type):
@@ -323,8 +412,11 @@ class Mapper(object):
self.single = True
if not self.local_table is self.inherits.local_table:
if self.concrete:
- self._synchronizer= None
+ self._synchronizer = None
self.mapped_table = self.local_table
+ for mapper in self.iterate_to_root():
+ if mapper.polymorphic_on:
+ mapper._requires_row_aliasing = True
else:
if self.inherit_condition is None:
# figure out inherit condition from our table to the immediate table
@@ -366,8 +458,7 @@ class Mapper(object):
self.version_id_col = self.inherits.version_id_col
for mapper in self.iterate_to_root():
- if hasattr(mapper, '_genned_equivalent_columns'):
- del mapper._genned_equivalent_columns
+ util.reset_cached(mapper, '_equivalent_columns')
if self.order_by is False:
self.order_by = self.inherits.order_by
@@ -386,32 +477,17 @@ class Mapper(object):
raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
self.polymorphic_map[self.polymorphic_identity] = self
self._identity_class = self.class_
-
+
if self.mapped_table is None:
raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? It no longer has a return value.)" % str(self))
- def _compile_tables(self):
- # summary of the various Selectable units:
- # mapped_table - the Selectable that represents a join of the underlying Tables to be saved (or just the Table)
- # local_table - the Selectable that was passed to this Mapper's constructor, if any
- # select_table - the Selectable that will be used during queries. if this is specified
- # as a constructor keyword argument, it takes precendence over mapped_table, otherwise its mapped_table
- # this is either select_table if it was given explicitly, or in the case of a mapper that inherits
- # its local_table
- # tables - a collection of underlying Table objects pulled from mapped_table
-
- if self.select_table is None:
- self.select_table = self.mapped_table
-
- # locate all tables contained within the "table" passed in, which
- # may be a join or other construct
+ def _compile_pks(self):
+
self.tables = sqlutil.find_tables(self.mapped_table)
if not self.tables:
raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))
- def _compile_pks(self):
-
self._pks_by_table = {}
self._cols_by_table = {}
@@ -439,7 +515,6 @@ class Mapper(object):
if self.inherits and not self.concrete and not self.primary_key_argument:
# if inheriting, the "primary key" for this mapper is that of the inheriting (unless concrete or explicit)
self.primary_key = self.inherits.primary_key
- self._get_clause = self.inherits._get_clause
else:
# determine primary key from argument or mapped_table pks - reduce to the minimal set of columns
if self.primary_key_argument:
@@ -453,18 +528,17 @@ class Mapper(object):
self.primary_key = primary_key
self.__log("Identified primary key columns: " + str(primary_key))
- # create a "get clause" based on the primary key. this is used
- # by query.get() and many-to-one lazyloads to load this item
- # by primary key.
- _get_clause = sql.and_()
- _get_params = {}
- for primary_key in self.primary_key:
- bind = sql.bindparam(None, type_=primary_key.type)
- _get_params[primary_key] = bind
- _get_clause.clauses.append(primary_key == bind)
- self._get_clause = (_get_clause, _get_params)
-
- def __get_equivalent_columns(self):
+ def _get_clause(self):
+ """create a "get clause" based on the primary key. this is used
+ by query.get() and many-to-one lazyloads to load this item
+ by primary key.
+
+ """
+ params = dict([(primary_key, sql.bindparam(None, type_=primary_key.type)) for primary_key in self.primary_key])
+ return sql.and_(*[k==v for (k, v) in params.iteritems()]), params
+ _get_clause = property(util.cache_decorator(_get_clause))
+
+ def _equivalent_columns(self):
"""Create a map of all *equivalent* columns, based on
the determination of column pairs that are equated to
one another either by an established foreign key relationship
@@ -529,13 +603,7 @@ class Mapper(object):
equivs(col, util.Set(), col)
return result
- def _equivalent_columns(self):
- if hasattr(self, '_genned_equivalent_columns'):
- return self._genned_equivalent_columns
- else:
- self._genned_equivalent_columns = self.__get_equivalent_columns()
- return self._genned_equivalent_columns
- _equivalent_columns = property(_equivalent_columns)
+ _equivalent_columns = property(util.cache_decorator(_equivalent_columns))
class _CompileOnAttr(PropComparator):
"""A placeholder descriptor which triggers compilation on access."""
@@ -606,7 +674,13 @@ class Mapper(object):
column_key = (self.column_prefix or '') + column.key
self._compile_property(column_key, column, init=False, setparent=True)
-
+
+ # do a special check for the "discriminiator" column, as it may only be present
+ # in the 'with_polymorphic' selectable but we need it for the base mapper
+ if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty:
+ col = self.mapped_table.corresponding_column(self.polymorphic_on) or self.polymorphic_on
+ self._compile_property(col.key, ColumnProperty(col), init=False, setparent=True)
+
def _adapt_inherited_property(self, key, prop):
if not self.concrete:
self._compile_property(key, prop, init=False, setparent=False)
@@ -692,44 +766,10 @@ class Mapper(object):
if init:
prop.init(key, self)
-
+
for mapper in self._inheriting_mappers:
mapper._adapt_inherited_property(key, prop)
- def _compile_selectable(self):
- """If the 'select_table' keyword argument was specified, set
- up a second *surrogate mapper* that will be used for select
- operations.
-
- The columns of `select_table` should encompass all the columns
- of the `mapped_table` either directly or through proxying
- relationships. Currently, non-column properties are **not**
- copied. This implies that a polymorphic mapper can't do any
- eager loading right now.
- """
-
- if self.select_table is not self.mapped_table:
- # turn a straight join into an aliased selectable
- if isinstance(self.select_table, sql.Join):
- self.select_table = self.select_table.select(use_labels=True).alias()
-
- self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, _polymorphic_map=self.polymorphic_map, polymorphic_on=_corresponding_column_or_error(self.select_table, self.polymorphic_on), primary_key=self.primary_key_argument)
- adapter = sqlutil.ClauseAdapter(self.select_table, equivalents=self.__surrogate_mapper._equivalent_columns)
-
- if self.order_by:
- order_by = [expression._literal_as_text(o) for o in util.to_list(self.order_by) or []]
- order_by = adapter.copy_and_process(order_by)
- self.__surrogate_mapper.order_by=order_by
-
- if self._init_properties:
- for key, prop in self._init_properties.iteritems():
- if expression.is_column(prop):
- self.__surrogate_mapper.add_property(key, _corresponding_column_or_error(self.select_table, prop))
- elif (isinstance(prop, list) and expression.is_column(prop[0])):
- self.__surrogate_mapper.add_property(key, [_corresponding_column_or_error(self.select_table, c) for c in prop])
-
- self.__surrogate_mapper._clause_adapter = adapter
-
def _compile_class(self):
"""If this mapper is to be a primary mapper (i.e. the
non_primary flag is not set), associate this Mapper with the
@@ -1280,16 +1320,7 @@ class Mapper(object):
for (c, m) in prop.cascade_iterator(type, state, recursive, halt_on=halt_on):
yield (c, m)
- def get_select_mapper(self):
- """Return the mapper used for issuing selects.
-
- This mapper is the same mapper as `self` unless the
- select_table argument was specified for this mapper.
- """
-
- return self.__surrogate_mapper or self
-
- def _instance(self, context, row, result=None, skip_polymorphic=False, extension=None, only_load_props=None, refresh_instance=None):
+ def _instance(self, context, row, result=None, polymorphic_from=None, extension=None, only_load_props=None, refresh_instance=None):
if not extension:
extension = self.extension
@@ -1298,15 +1329,25 @@ class Mapper(object):
if ret is not EXT_CONTINUE:
row = ret
- if not refresh_instance and not skip_polymorphic and self.polymorphic_on:
+ if polymorphic_from:
+ # if we are called from a base mapper doing a polymorphic load, figure out what tables,
+ # if any, will need to be "post-fetched" based on the tables present in the row,
+ # or from the options set up on the query
+ if ('polymorphic_fetch', self) not in context.attributes:
+ if self in context.query._with_polymorphic:
+ context.attributes[('polymorphic_fetch', self)] = (polymorphic_from, [])
+ else:
+ context.attributes[('polymorphic_fetch', self)] = (polymorphic_from, [t for t in self.tables if t not in polymorphic_from.tables])
+
+ elif not refresh_instance and self.polymorphic_on:
discriminator = row[self.polymorphic_on]
- if discriminator:
- mapper = self.polymorphic_map[discriminator]
+ if discriminator is not None:
+ try:
+ mapper = self.polymorphic_map[discriminator]
+ except KeyError:
+ raise exceptions.AssertionError("No such polymorphic_identity %r is defined" % discriminator)
if mapper is not self:
- if ('polymorphic_fetch', mapper) not in context.attributes:
- context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables])
- row = self.translate_row(mapper, row)
- return mapper._instance(context, row, result=result, skip_polymorphic=True)
+ return mapper._instance(context, row, result=result, polymorphic_from=self)
# determine identity key
if refresh_instance:
@@ -1319,7 +1360,7 @@ class Mapper(object):
identitykey = self._identity_key_from_state(refresh_instance)
else:
identitykey = self.identity_key_from_row(row)
-
+
session_identity_map = context.session.identity_map
if identitykey in session_identity_map:
@@ -1404,22 +1445,6 @@ class Mapper(object):
return instance
- def translate_row(self, tomapper, row):
- """Translate the column keys of a row into a new or proxied
- row that can be understood by another mapper.
-
- This can be used in conjunction with populate_instance to
- populate an instance using an alternate mapper.
- """
-
- if tomapper in self._row_translators:
- # row translators are cached based on target mapper
- return self._row_translators[tomapper](row)
- else:
- translator = create_row_adapter(self.mapped_table, tomapper.mapped_table, equivalent_columns=self._equivalent_columns)
- self._row_translators[tomapper] = translator
- return translator(row)
-
def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, only_load_props=None, **flags):
"""populate an instance from a result row."""
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 13ed83bc9..15546d7e8 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -233,7 +233,6 @@ class PropertyLoader(StrategizedProperty):
self.passive_updates = passive_updates
self.remote_side = util.to_set(remote_side)
self.enable_typechecks = enable_typechecks
- self.__parent_join_cache = weakref.WeakKeyDictionary()
self.comparator = PropertyLoader.Comparator(self)
self.join_depth = join_depth
self.strategy_class = strategy_class
@@ -304,13 +303,12 @@ class PropertyLoader(StrategizedProperty):
if getattr(self, '_of_type', None):
target_mapper = self._of_type
- to_selectable = target_mapper.select_table
+ to_selectable = target_mapper.mapped_table
adapt_against = to_selectable
else:
target_mapper = self.prop.mapper
to_selectable = None
- if target_mapper.select_table is not target_mapper.mapped_table:
- adapt_against = target_mapper.select_table
+ adapt_against = None
if self.prop._is_self_referential():
pj = self.prop.primary_join_against(self.prop.parent, None)
@@ -471,8 +469,8 @@ class PropertyLoader(StrategizedProperty):
def _get_target_class(self):
"""Return the target class of the relation, even if the
property has not been initialized yet.
- """
+ """
if isinstance(self.argument, type):
return self.argument
else:
@@ -484,7 +482,6 @@ class PropertyLoader(StrategizedProperty):
self._determine_fks()
self._determine_direction()
self._determine_remote_side()
- self._create_polymorphic_joins()
self._post_init()
def _determine_targets(self):
@@ -498,9 +495,6 @@ class PropertyLoader(StrategizedProperty):
else:
raise exceptions.ArgumentError("relation '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument)))
- # ensure the "select_mapper", if different from the regular target mapper, is compiled.
- self.mapper.get_select_mapper()
-
if not self.parent.concrete:
for inheriting in self.parent.iterate_to_root():
if inheriting is not self.parent and inheriting._get_property(self.key, raiseerr=False):
@@ -510,14 +504,8 @@ class PropertyLoader(StrategizedProperty):
"can cause dependency issues during flush") %
(self.key, self.parent, inheriting))
- if self.association is not None:
- if isinstance(self.association, type):
- self.association = mapper.class_mapper(self.association, entity_name=self.entity_name, compile=False)
-
self.target = self.mapper.mapped_table
- self.select_mapper = self.mapper.get_select_mapper()
- self.select_table = self.mapper.select_table
- self.loads_polymorphic = self.target is not self.select_table
+ self.table = self.mapper.mapped_table
if self.cascade.delete_orphan:
if self.parent.class_ is self.mapper.class_:
@@ -551,17 +539,6 @@ class PropertyLoader(StrategizedProperty):
except exceptions.ArgumentError, e:
raise exceptions.ArgumentError("""Error determining primary and/or secondary join for relationship '%s'. If the underlying error cannot be corrected, you should specify the 'primaryjoin' (and 'secondaryjoin', if there is an association table present) keyword arguments to the relation() function (or for backrefs, by specifying the backref using the backref() function with keyword arguments) to explicitly specify the join conditions. Nested error is \"%s\"""" % (str(self), str(e)))
- # if using polymorphic mapping, the join conditions must be agasint the base tables of the mappers,
- # as the loader strategies expect to be working with those now (they will adapt the join conditions
- # to the "polymorphic" selectable as needed). since this is an API change, put an explicit check/
- # error message in case its the "old" way.
- if self.loads_polymorphic:
- vis = ColumnsInClause(self.mapper.select_table)
- vis.traverse(self.primaryjoin)
- if self.secondaryjoin:
- vis.traverse(self.secondaryjoin)
- if vis.result:
- raise exceptions.ArgumentError("In relationship '%s', primary and secondary join conditions must not include columns from the polymorphic 'select_table' argument as of SA release 0.3.4. Construct join conditions using the base tables of the related mappers." % (str(self)))
def _col_is_part_of_mappings(self, column):
if self.secondary is None:
@@ -576,8 +553,9 @@ class PropertyLoader(StrategizedProperty):
if self._legacy_foreignkey and not self._refers_to_parent_table():
self.foreign_keys = self._legacy_foreignkey
+ self._opposite_side = util.Set()
+
if self.foreign_keys:
- self._opposite_side = util.Set()
def visit_binary(binary):
if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
@@ -585,12 +563,8 @@ class PropertyLoader(StrategizedProperty):
self._opposite_side.add(binary.right)
if binary.right in self.foreign_keys:
self._opposite_side.add(binary.left)
- visitors.traverse(self.primaryjoin, visit_binary=visit_binary)
- if self.secondaryjoin is not None:
- visitors.traverse(self.secondaryjoin, visit_binary=visit_binary)
else:
self.foreign_keys = util.Set()
- self._opposite_side = util.Set()
def visit_binary(binary):
if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
@@ -609,16 +583,18 @@ class PropertyLoader(StrategizedProperty):
if f.references(binary.left.table):
self.foreign_keys.add(binary.right)
self._opposite_side.add(binary.left)
- visitors.traverse(self.primaryjoin, visit_binary=visit_binary)
- if len(self.foreign_keys) == 0:
- raise exceptions.ArgumentError(
- "Can't locate any foreign key columns in primary join "
- "condition '%s' for relationship '%s'. Specify "
- "'foreign_keys' argument to indicate which columns in "
- "the join condition are foreign." %(str(self.primaryjoin), str(self)))
- if self.secondaryjoin is not None:
- visitors.traverse(self.secondaryjoin, visit_binary=visit_binary)
+ visitors.traverse(self.primaryjoin, visit_binary=visit_binary)
+
+ if not self.foreign_keys:
+ raise exceptions.ArgumentError(
+ "Can't locate any foreign key columns in primary join "
+ "condition '%s' for relationship '%s'. Specify "
+ "'foreign_keys' argument to indicate which columns in "
+ "the join condition are foreign." %(str(self.primaryjoin), str(self)))
+
+ if self.secondaryjoin is not None:
+ visitors.traverse(self.secondaryjoin, visit_binary=visit_binary)
def _determine_direction(self):
@@ -650,8 +626,8 @@ class PropertyLoader(StrategizedProperty):
self.direction = sync.ONETOMANY
else:
for mappedtable, parenttable in [(self.mapper.mapped_table, self.parent.mapped_table), (self.mapper.local_table, self.parent.local_table)]:
- onetomany = len([c for c in self.foreign_keys if mappedtable.c.contains_column(c)])
- manytoone = len([c for c in self.foreign_keys if parenttable.c.contains_column(c)])
+ onetomany = [c for c in self.foreign_keys if mappedtable.c.contains_column(c)]
+ manytoone = [c for c in self.foreign_keys if parenttable.c.contains_column(c)]
if not onetomany and not manytoone:
raise exceptions.ArgumentError(
@@ -682,52 +658,10 @@ class PropertyLoader(StrategizedProperty):
self.local_side = util.Set(self._opposite_side).union(util.Set(self.foreign_keys)).difference(self.remote_side)
- def _create_polymorphic_joins(self):
- # get ready to create "polymorphic" primary/secondary join clauses.
- # these clauses represent the same join between parent/child tables that the primary
- # and secondary join clauses represent, except they reference ColumnElements that are specifically
- # in the "polymorphic" selectables. these are used to construct joins for both Query as well as
- # eager loading, and also are used to calculate "lazy loading" clauses.
-
- if self.loads_polymorphic:
-
- # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out,
- # first create maps of all the "equivalent" columns, since polymorphic selectables will often munge
- # several "equivalent" columns (such as parent/child fk cols) into just one column.
- target_equivalents = self.mapper._equivalent_columns
-
- if self.secondaryjoin:
- self.polymorphic_secondaryjoin = ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True)
- self.polymorphic_primaryjoin = self.primaryjoin
- else:
- if self.direction is sync.ONETOMANY:
- self.polymorphic_primaryjoin = ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
- elif self.direction is sync.MANYTOONE:
- self.polymorphic_primaryjoin = ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
- self.polymorphic_secondaryjoin = None
-
- # load "polymorphic" versions of the columns present in "remote_side" - this is
- # important for lazy-clause generation which goes off the polymorphic target selectable
- for c in list(self.remote_side):
- if self.secondary and self.secondary.columns.contains_column(c):
- continue
- for equiv in [c] + (c in target_equivalents and list(target_equivalents[c]) or []):
- corr = self.mapper.select_table.corresponding_column(equiv)
- if corr:
- self.remote_side.add(corr)
- break
- else:
- raise exceptions.AssertionError(str(self) + ": Could not find corresponding column for " + str(c) + " in selectable " + str(self.mapper.select_table))
- else:
- self.polymorphic_primaryjoin = self.primaryjoin
- self.polymorphic_secondaryjoin = self.secondaryjoin
-
def _post_init(self):
if logging.is_info_enabled(self.logger):
self.logger.info(str(self) + " setup primary join " + str(self.primaryjoin))
- self.logger.info(str(self) + " setup polymorphic primary join " + str(self.polymorphic_primaryjoin))
self.logger.info(str(self) + " setup secondary join " + str(self.secondaryjoin))
- self.logger.info(str(self) + " setup polymorphic secondary join " + str(self.polymorphic_secondaryjoin))
self.logger.info(str(self) + " foreign keys " + str([str(c) for c in self.foreign_keys]))
self.logger.info(str(self) + " remote columns " + str([str(c) for c in self.remote_side]))
self.logger.info(str(self) + " relation direction " + (self.direction is sync.ONETOMANY and "one-to-many" or (self.direction is sync.MANYTOONE and "many-to-one" or "many-to-many")))
@@ -756,41 +690,29 @@ class PropertyLoader(StrategizedProperty):
super(PropertyLoader, self).do_init()
def _refers_to_parent_table(self):
- return self.parent.mapped_table is self.target or self.parent.select_table is self.target
+ return self.parent.mapped_table is self.target or self.parent.mapped_table is self.target
def _is_self_referential(self):
return self.mapper.common_parent(self.parent)
def primary_join_against(self, mapper, selectable=None, toselectable=None):
- return self.__cached_join_against(mapper, selectable, toselectable, True, False)
+ return self.__join_against(mapper, selectable, toselectable, True, False)
def secondary_join_against(self, mapper, toselectable=None):
- return self.__cached_join_against(mapper, None, toselectable, False, True)
+ return self.__join_against(mapper, None, toselectable, False, True)
def full_join_against(self, mapper, selectable=None, toselectable=None):
- return self.__cached_join_against(mapper, selectable, toselectable, True, True)
+ return self.__join_against(mapper, selectable, toselectable, True, True)
- def __cached_join_against(self, frommapper, fromselectable, toselectable, primary, secondary):
+ def __join_against(self, frommapper, fromselectable, toselectable, primary, secondary):
if fromselectable is None:
fromselectable = frommapper.local_table
- try:
- rec = self.__parent_join_cache[fromselectable]
- except KeyError:
- self.__parent_join_cache[fromselectable] = rec = {}
-
- key = (frommapper, primary, secondary, toselectable)
- if key in rec:
- return rec[key]
-
parent_equivalents = frommapper._equivalent_columns
if primary:
- if toselectable:
- primaryjoin = self.primaryjoin
- else:
- primaryjoin = self.polymorphic_primaryjoin
-
+ primaryjoin = self.primaryjoin
+
if fromselectable is not frommapper.local_table:
if self.direction is sync.ONETOMANY:
primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
@@ -800,22 +722,12 @@ class PropertyLoader(StrategizedProperty):
primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
if secondary:
- if toselectable:
- secondaryjoin = self.secondaryjoin
- else:
- secondaryjoin = self.polymorphic_secondaryjoin
- rec[key] = ret = primaryjoin & secondaryjoin
+ secondaryjoin = self.secondaryjoin
+ return primaryjoin & secondaryjoin
else:
- rec[key] = ret = primaryjoin
- return ret
-
+ return primaryjoin
elif secondary:
- if toselectable:
- rec[key] = ret = self.secondaryjoin
- else:
- rec[key] = ret = self.polymorphic_secondaryjoin
- return ret
-
+ return self.secondaryjoin
else:
raise AssertionError("illegal condition")
@@ -823,9 +735,9 @@ class PropertyLoader(StrategizedProperty):
"""deprecated. use primary_join_against(), secondary_join_against(), full_join_against()"""
if primary and secondary:
- return self.full_join_against(parent, parent.select_table)
+ return self.full_join_against(parent, parent.mapped_table)
elif primary:
- return self.primary_join_against(parent, parent.select_table)
+ return self.primary_join_against(parent, parent.mapped_table)
elif secondary:
return self.secondary_join_against(parent)
else:
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index ad8ec33f8..f39ef87c8 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -28,14 +28,13 @@ from sqlalchemy.orm import interfaces
__all__ = ['Query', 'QueryContext']
-
+
class Query(object):
"""Encapsulates the object-fetching operations provided by Mappers."""
def __init__(self, class_or_mapper, session=None, entity_name=None):
- self._init_mapper(_class_to_mapper(class_or_mapper, entity_name=entity_name))
self._session = session
-
+
self._with_options = []
self._lockmode = None
@@ -52,8 +51,6 @@ class Query(object):
self._joinable_tables = None
self._having = None
self._column_aggregate = None
- self._aliases = None
- self._alias_ids = {}
self._populate_existing = False
self._version_check = False
self._autoflush = True
@@ -62,49 +59,89 @@ class Query(object):
self._current_path = ()
self._only_load_props = None
self._refresh_instance = None
-
- def _init_mapper(self, mapper, select_mapper=None):
+
+ self._init_mapper(_class_to_mapper(class_or_mapper, entity_name=entity_name))
+
+ def _init_mapper(self, mapper):
"""populate all instance variables derived from this Query's mapper."""
self.mapper = mapper
- self.select_mapper = select_mapper or self.mapper.get_select_mapper().compile()
- self.table = self._from_obj = self.select_mapper.mapped_table
+ self.table = self._from_obj = self.mapper.mapped_table
self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]]))
self._extension = self.mapper.extension
- self._adapter = self.select_mapper._clause_adapter
+ self._aliases_head = self._aliases_tail = None
+ self._alias_ids = {}
self._joinpoint = self.mapper
- self._with_polymorphic = []
-
+ self._entities.append(_PrimaryMapperEntity(self.mapper))
+ if self.mapper.with_polymorphic:
+ self._set_with_polymorphic(*self.mapper.with_polymorphic)
+ else:
+ self._with_polymorphic = []
+
+ def _generate_alias_ids(self):
+ self._alias_ids = dict([
+ (k, list(v)) for k, v in self._alias_ids.iteritems()
+ ])
+
def _no_criterion(self, meth):
return self._conditional_clone(meth, [self._no_criterion_condition])
def _no_statement(self, meth):
return self._conditional_clone(meth, [self._no_statement_condition])
-
- def _new_base_mapper(self, mapper, meth):
+
+ def _reset_all(self, mapper, meth):
q = self._conditional_clone(meth, [self._no_criterion_condition])
q._init_mapper(mapper, mapper)
return q
+
+ def _set_select_from(self, from_obj):
+ if isinstance(from_obj, expression._SelectBaseMixin):
+ # alias SELECTs and unions
+ from_obj = from_obj.alias()
+
+ self._from_obj = from_obj
+ self._alias_ids = {}
+ if self.table not in self._get_joinable_tables():
+ self._aliases_head = self._aliases_tail = mapperutil.AliasedClauses(self._from_obj, equivalents=self.mapper._equivalent_columns)
+ self._alias_ids.setdefault(self.table, []).append(self._aliases_head)
+ else:
+ self._aliases_head = self._aliases_tail = None
+
+ def _set_with_polymorphic(self, cls_or_mappers, selectable=None):
+ mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable)
+ self._with_polymorphic = mappers
+ self._set_select_from(from_obj)
+
def _no_criterion_condition(self, q, meth):
- if q._criterion or q._statement or q._from_obj is not self.table:
+ if q._criterion or q._statement:
util.warn(
("Query.%s() being called on a Query with existing criterion; "
"criterion is being ignored.") % meth)
- q._from_obj = self.table
- q._adapter = self.select_mapper._clause_adapter
- q._alias_ids = {}
q._joinpoint = self.mapper
- q._statement = q._aliases = q._criterion = None
+ q._statement = q._criterion = None
q._order_by = q._group_by = q._distinct = False
-
+ q._aliases_tail = q._aliases_head
+ q.table = q._from_obj = q.mapper.mapped_table
+ if q.mapper.with_polymorphic:
+ q._set_with_polymorphic(*q.mapper.with_polymorphic)
+
+ def _no_entities(self, meth):
+ q = self._no_statement(meth)
+ if len(q._entities) > 1 and not isinstance(q._entities[0], _PrimaryMapperEntity):
+ raise exceptions.InvalidRequestError(
+ ("Query.%s() being called on a Query with existing "
+ "additional entities or columns - can't replace columns") % meth)
+ q._entities = []
+ return q
+
def _no_statement_condition(self, q, meth):
if q._statement:
raise exceptions.InvalidRequestError(
("Query.%s() being called on a Query with an existing full "
"statement - can't apply criterion.") % meth)
-
+
def _conditional_clone(self, methname=None, conditions=None):
q = self._clone()
if conditions:
@@ -117,20 +154,18 @@ class Query(object):
q.__dict__ = self.__dict__.copy()
return q
- def _get_session(self):
+ def session(self):
if self._session is None:
return self.mapper.get_session()
else:
return self._session
-
- primary_key_columns = property(lambda s:s.select_mapper.primary_key)
- session = property(_get_session)
+ session = property(session)
def _with_current_path(self, path):
q = self._clone()
q._current_path = path
return q
-
+
def with_polymorphic(self, cls_or_mappers, selectable=None):
"""Load columns for descendant mappers of this Query's mapper.
@@ -140,12 +175,6 @@ class Query(object):
instances will also have those columns already loaded so that
no "post fetch" of those columns will be required.
- If this Query's mapper has a ``select_table`` argument,
- with_polymorphic() overrides it; the FROM clause will be against
- the local table of the base mapper outer joined with the local
- tables of each specified descendant mapper (unless ``selectable``
- is specified).
-
``cls_or_mappers`` is a single class or mapper, or list of class/mappers,
which inherit from this Query's mapper. Alternatively, it
may also be the string ``'*'``, in which case all descending
@@ -162,30 +191,12 @@ class Query(object):
clause which will usually lead to incorrect results.
"""
-
- q = self._new_base_mapper(self.mapper, 'with_polymorphic')
+ q = self._no_criterion('with_polymorphic')
- if cls_or_mappers == '*':
- cls_or_mappers = self.mapper.polymorphic_iterator()
- else:
- cls_or_mappers = util.to_list(cls_or_mappers)
-
- if selectable:
- q = q.select_from(selectable)
-
- for cls_or_mapper in cls_or_mappers:
- poly_mapper = _class_to_mapper(cls_or_mapper)
- if poly_mapper is self.mapper:
- continue
-
- q._with_polymorphic.append(poly_mapper)
- if not selectable:
- if poly_mapper.concrete:
- raise exceptions.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.")
- elif not poly_mapper.single:
- q._from_obj = q._from_obj.outerjoin(poly_mapper.local_table, poly_mapper.inherit_condition)
+ q._set_with_polymorphic(cls_or_mappers, selectable=selectable)
return q
+
def yield_per(self, count):
"""Yield only ``count`` rows at a time.
@@ -345,9 +356,22 @@ class Query(object):
if isinstance(entity, type):
entity = mapper.class_mapper(entity)
if alias is not None:
- alias = mapperutil.AliasedClauses(entity.mapped_table, alias=alias)
+ alias = mapperutil.AliasedClauses(alias)
- q._entities = q._entities + [(entity, alias, id)]
+ q._entities = q._entities + [_MapperEntity(mapper=entity, alias=alias, id=id)]
+ return q
+
+ def _values(self, *columns):
+ """Turn this query into a 'columns only' query.
+
+ The API for this method hasn't been decided yet and is subject to change.
+ """
+
+ q = self._no_entities('_values')
+ q._only_load_props = q._eager_loaders = util.Set()
+
+ for column in columns:
+ q._entities.append(self._add_column(column, None))
return q
def add_column(self, column, id=None):
@@ -370,18 +394,18 @@ class Query(object):
"""
q = self._clone()
-
- # duck type to get a ClauseElement
- if hasattr(column, 'clause_element'):
+ q._entities = q._entities + [self._add_column(column, id)]
+ return q
+
+ def _add_column(self, column, id=None):
+ if isinstance(column, interfaces.PropComparator):
column = column.clause_element()
- # alias non-labeled column elements.
- if isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'):
- column = column.label(None)
-
- q._entities = q._entities + [(column, None, id)]
- return q
+ elif not isinstance(column, (sql.ColumnElement, basestring)):
+ raise exceptions.InvalidRequestError("Invalid column expression '%r'" % column)
+ return _ColumnEntity(column=column, id=id)
+
def options(self, *args):
"""Return a new Query object, applying the given list of
MapperOptions.
@@ -409,6 +433,7 @@ class Query(object):
def with_lockmode(self, mode):
"""Return a new Query object with the specified locking mode."""
+
q = self._clone()
q._lockmode = mode
return q
@@ -420,12 +445,11 @@ class Query(object):
as the first positional argument. The reason for both is that \**kwargs is
convenient, however some parameter dictionaries contain unicode keys in which case
\**kwargs cannot be used.
- """
+ """
q = self._clone()
if len(args) == 1:
- d = args[0]
- kwargs.update(d)
+ kwargs.update(args[0])
elif len(args) > 0:
raise exceptions.ArgumentError("params() takes zero or one positional argument, which is a dictionary.")
q._params = q._params.copy()
@@ -436,16 +460,16 @@ class Query(object):
"""apply the given filtering criterion to the query and return the newly resulting ``Query``
the criterion is any sql.ClauseElement applicable to the WHERE clause of a select.
- """
+ """
if isinstance(criterion, basestring):
criterion = sql.text(criterion)
if criterion is not None and not isinstance(criterion, sql.ClauseElement):
raise exceptions.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string")
- if self._adapter is not None:
- criterion = self._adapter.traverse(criterion)
+ if self._aliases_tail:
+ criterion = self._aliases_tail.adapt_clause(criterion)
q = self._no_statement("filter")
if q._criterion is not None:
@@ -462,160 +486,12 @@ class Query(object):
return self.filter(sql.and_(*clauses))
- def _get_joinable_tables(self):
- if not self._joinable_tables or self._joinable_tables[0] is not self._from_obj:
- currenttables = [self._from_obj]
- def visit_join(join):
- currenttables.append(join.left)
- currenttables.append(join.right)
- visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False})
- self._joinable_tables = (self._from_obj, currenttables)
- return currenttables
- else:
- return self._joinable_tables[1]
-
- def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True):
- if start is None:
- start = self._joinpoint
-
- clause = self._from_obj
-
- currenttables = self._get_joinable_tables()
-
- # determine if generated joins need to be aliased on the left
- # hand side.
- if self._adapter and not self._aliases: # at the beginning of a join, look at leftmost adapter
- adapt_against = self._adapter.selectable
- elif start is self.select_mapper: # or if its our base mapper, go against our base table
- adapt_against = self.table
- elif start.select_table is not start.mapped_table: # in the middle of a join, look for a polymorphic mapper
- adapt_against = start.select_table
- elif self._aliases: # joining against aliases
- adapt_against = self._aliases.alias
- else:
- adapt_against = None
-
- mapper = start
- alias = self._aliases
-
- if not isinstance(keys, list):
- keys = [keys]
- for key in keys:
- use_selectable = None
- of_type = None
-
- if isinstance(key, tuple):
- key, use_selectable = key
-
- if isinstance(key, interfaces.PropComparator):
- prop = key.property
- if getattr(key, '_of_type', None):
- if use_selectable:
- raise exceptions.InvalidRequestError("Can't specify use_selectable along with polymorphic property created via of_type().")
- of_type = key._of_type
- use_selectable = key._of_type.select_table
- else:
- prop = mapper.get_property(key, resolve_synonyms=True)
-
- if use_selectable:
- if not use_selectable.is_derived_from(prop.mapper.mapped_table):
- raise exceptions.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (use_selectable.description, prop.mapper.mapped_table.description))
- if not isinstance(use_selectable, expression.Alias):
- use_selectable = use_selectable.alias()
-
- if prop._is_self_referential() and not create_aliases and not use_selectable:
- raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires aliased=True argument." % str(prop))
-
- if prop.select_table not in currenttables or create_aliases or use_selectable:
- if prop.secondary:
- if use_selectable or create_aliases:
- alias = mapperutil.PropertyAliasedClauses(prop,
- prop.primary_join_against(mapper, adapt_against),
- prop.secondary_join_against(mapper, toselectable=use_selectable),
- alias,
- alias=use_selectable
- )
- crit = alias.primaryjoin
- clause = clause.join(alias.secondary, crit, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin)
- else:
- crit = prop.primary_join_against(mapper, adapt_against)
- clause = clause.join(prop.secondary, crit, isouter=outerjoin)
- clause = clause.join(prop.select_table, prop.secondary_join_against(mapper), isouter=outerjoin)
- else:
- if use_selectable or create_aliases:
- alias = mapperutil.PropertyAliasedClauses(prop,
- prop.primary_join_against(mapper, adapt_against, toselectable=use_selectable),
- None,
- alias,
- alias=use_selectable
- )
- crit = alias.primaryjoin
- clause = clause.join(alias.alias, crit, isouter=outerjoin)
- else:
- crit = prop.primary_join_against(mapper, adapt_against)
- clause = clause.join(prop.select_table, crit, isouter=outerjoin)
- elif not create_aliases and prop.secondary is not None and prop.secondary not in currenttables:
- # TODO: this check is not strong enough for different paths to the same endpoint which
- # does not use secondary tables
- raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`." % prop.key)
-
- mapper = of_type or prop.mapper
-
- if use_selectable:
- adapt_against = use_selectable
- elif mapper.select_table is not mapper.mapped_table:
- adapt_against = mapper.select_table
-
- return (clause, mapper, alias)
-
- def _generative_col_aggregate(self, col, func):
- """apply the given aggregate function to the query and return the newly
- resulting ``Query``.
- """
- if self._column_aggregate is not None:
- raise exceptions.InvalidRequestError("Query already contains an aggregate column or function")
- q = self._no_statement("aggregate")
- q._column_aggregate = (col, func)
- return q
-
- def apply_min(self, col):
- """apply the SQL ``min()`` function against the given column to the
- query and return the newly resulting ``Query``.
-
- DEPRECATED.
- """
- return self._generative_col_aggregate(col, sql.func.min)
-
- def apply_max(self, col):
- """apply the SQL ``max()`` function against the given column to the
- query and return the newly resulting ``Query``.
-
- DEPRECATED.
- """
- return self._generative_col_aggregate(col, sql.func.max)
-
- def apply_sum(self, col):
- """apply the SQL ``sum()`` function against the given column to the
- query and return the newly resulting ``Query``.
-
- DEPRECATED.
- """
- return self._generative_col_aggregate(col, sql.func.sum)
-
- def apply_avg(self, col):
- """apply the SQL ``avg()`` function against the given column to the
- query and return the newly resulting ``Query``.
-
- DEPRECATED.
- """
- return self._generative_col_aggregate(col, sql.func.avg)
-
def _col_aggregate(self, col, func):
"""Execute ``func()`` function against the given column.
For performance, only use subselect if `order_by` attribute is set.
- """
+ """
ops = {'distinct':self._distinct, 'order_by':self._order_by or None, 'from_obj':self._from_obj}
if self._autoflush and not self._populate_existing:
@@ -647,31 +523,31 @@ class Query(object):
return self._col_aggregate(col, sql.func.avg)
- def order_by(self, criterion):
+ def order_by(self, *criterion):
"""apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``"""
q = self._no_statement("order_by")
- if self._adapter:
- criterion = [expression._literal_as_text(o) for o in util.to_list(criterion) or []]
- criterion = self._adapter.copy_and_process(criterion)
+ if self._aliases_tail:
+ criterion = [expression._literal_as_text(o) for o in util.starargs_as_list(*criterion)]
+ criterion = self._aliases_tail.adapt_list(criterion)
if q._order_by is False:
- q._order_by = util.to_list(criterion)
+ q._order_by = util.starargs_as_list(*criterion)
else:
- q._order_by = q._order_by + util.to_list(criterion)
+ q._order_by = q._order_by + util.starargs_as_list(*criterion)
return q
- def group_by(self, criterion):
+ def group_by(self, *criterion):
"""apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``"""
q = self._no_statement("group_by")
if q._group_by is False:
- q._group_by = util.to_list(criterion)
+ q._group_by = util.starargs_as_list(*criterion)
else:
- q._group_by = q._group_by + util.to_list(criterion)
+ q._group_by = q._group_by + util.starargs_as_list(*criterion)
return q
-
+
def having(self, criterion):
"""apply a HAVING criterion to the query and return the newly resulting ``Query``."""
@@ -681,8 +557,8 @@ class Query(object):
if criterion is not None and not isinstance(criterion, sql.ClauseElement):
raise exceptions.ArgumentError("having() argument must be of type sqlalchemy.sql.ClauseElement or string")
- if self._adapter is not None:
- criterion = self._adapter.traverse(criterion)
+ if self._aliases_tail:
+ criterion = self._aliases_tail.adapt_clause(criterion)
q = self._no_statement("having")
if q._having is not None:
@@ -733,39 +609,138 @@ class Query(object):
"""
return self._join(prop, id=id, outerjoin=True, aliased=aliased, from_joinpoint=from_joinpoint)
-
+
def _join(self, prop, id, outerjoin, aliased, from_joinpoint):
(clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased)
+ # TODO: improve the generative check here to look for primary mapped entity, etc.
q = self._no_statement("join")
q._from_obj = clause
q._joinpoint = mapper
q._aliases = aliases
-
- q._alias_ids = {}
- for k, v in self._alias_ids.items():
- if isinstance(v, list):
- q._alias_ids[k] = list(v)
- else:
- q._alias_ids[k] = v
+ q._generate_alias_ids()
if aliases:
- q._adapter = sql_util.ClauseAdapter(aliases.alias).copy_and_chain(q._adapter)
- else:
- select_mapper = mapper.get_select_mapper()
- if select_mapper._clause_adapter:
- q._adapter = select_mapper._clause_adapter.copy_and_chain(q._adapter)
+ q._aliases_tail = aliases
a = aliases
while a is not None:
- q._alias_ids.setdefault(a.mapper, []).append(a)
- q._alias_ids.setdefault(a.table, []).append(a)
- q._alias_ids.setdefault(a.alias, []).append(a)
- a = a.parentclauses
+ if isinstance(a, mapperutil.PropertyAliasedClauses):
+ q._alias_ids.setdefault(a.mapper, []).append(a)
+ q._alias_ids.setdefault(a.table, []).append(a)
+ q._alias_ids.setdefault(a.alias, []).append(a)
+ a = a.parentclauses
+ else:
+ break
if id:
- q._alias_ids[id] = aliases
+ q._alias_ids[id] = [aliases]
return q
+ def _get_joinable_tables(self):
+ if not self._joinable_tables or self._joinable_tables[0] is not self._from_obj:
+ currenttables = [self._from_obj]
+ def visit_join(join):
+ currenttables.append(join.left)
+ currenttables.append(join.right)
+ visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False})
+ self._joinable_tables = (self._from_obj, currenttables)
+ return currenttables
+ else:
+ return self._joinable_tables[1]
+
+ def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True):
+ if start is None:
+ start = self._joinpoint
+
+ clause = self._from_obj
+
+ currenttables = self._get_joinable_tables()
+
+ # determine if generated joins need to be aliased on the left
+ # hand side.
+ if self._aliases_head is self._aliases_tail is not None:
+ adapt_against = self._aliases_tail.alias
+ elif start is not self.mapper and self._aliases_tail:
+ adapt_against = self._aliases_tail.alias
+ else:
+ adapt_against = None
+
+ mapper = start
+ alias = self._aliases_tail
+
+ if not isinstance(keys, list):
+ keys = [keys]
+ for key in keys:
+ use_selectable = None
+ of_type = None
+
+ if isinstance(key, tuple):
+ key, use_selectable = key
+
+ if isinstance(key, interfaces.PropComparator):
+ prop = key.property
+ if getattr(key, '_of_type', None):
+ if use_selectable:
+ raise exceptions.InvalidRequestError("Can't specify use_selectable along with polymorphic property created via of_type().")
+ of_type = key._of_type
+ use_selectable = key._of_type.mapped_table
+ else:
+ prop = mapper.get_property(key, resolve_synonyms=True)
+
+ if use_selectable:
+ if not use_selectable.is_derived_from(prop.mapper.mapped_table):
+ raise exceptions.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (use_selectable.description, prop.mapper.mapped_table.description))
+ if not isinstance(use_selectable, expression.Alias):
+ use_selectable = use_selectable.alias()
+ elif prop.mapper.with_polymorphic:
+ use_selectable = prop.mapper._with_polymorphic_selectable()
+ if not isinstance(use_selectable, expression.Alias):
+ use_selectable = use_selectable.alias()
+
+ if prop._is_self_referential() and not create_aliases and not use_selectable:
+ raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires aliased=True argument." % str(prop))
+
+ if prop.table not in currenttables or create_aliases or use_selectable:
+ if prop.secondary:
+ if use_selectable or create_aliases:
+ alias = mapperutil.PropertyAliasedClauses(prop,
+ prop.primary_join_against(mapper, adapt_against),
+ prop.secondary_join_against(mapper, toselectable=use_selectable),
+ alias,
+ alias=use_selectable
+ )
+ crit = alias.primaryjoin
+ clause = clause.join(alias.secondary, crit, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin)
+ else:
+ crit = prop.primary_join_against(mapper, adapt_against)
+ clause = clause.join(prop.secondary, crit, isouter=outerjoin)
+ clause = clause.join(prop.table, prop.secondary_join_against(mapper), isouter=outerjoin)
+ else:
+ if use_selectable or create_aliases:
+ alias = mapperutil.PropertyAliasedClauses(prop,
+ prop.primary_join_against(mapper, adapt_against, toselectable=use_selectable),
+ None,
+ alias,
+ alias=use_selectable
+ )
+ crit = alias.primaryjoin
+ clause = clause.join(alias.alias, crit, isouter=outerjoin)
+ else:
+ crit = prop.primary_join_against(mapper, adapt_against)
+ clause = clause.join(prop.table, crit, isouter=outerjoin)
+ elif not create_aliases and prop.secondary is not None and prop.secondary not in currenttables:
+ # TODO: this check is not strong enough for different paths to the same endpoint which
+ # does not use secondary tables
+ raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`." % prop.key)
+
+ mapper = of_type or prop.mapper
+
+ if use_selectable:
+ adapt_against = use_selectable
+
+ return (clause, mapper, alias)
+
+
def reset_joinpoint(self):
"""return a new Query reset the 'joinpoint' of this Query reset
back to the starting mapper. Subsequent generative calls will
@@ -777,12 +752,12 @@ class Query(object):
q = self._no_statement("reset_joinpoint")
q._joinpoint = q.mapper
- q._aliases = None
if q.table not in q._get_joinable_tables():
- q._adapter = sql_util.ClauseAdapter(q._from_obj, equivalents=q.mapper._equivalent_columns)
+ q._aliases_head = q._aliases_tail = mapperutil.AliasedClauses(q._from_obj, equivalents=q.mapper._equivalent_columns)
+ else:
+ q._aliases_head = q._aliases_tail = None
return q
-
def select_from(self, from_obj):
"""Set the `from_obj` parameter of the query and return the newly
resulting ``Query``. This replaces the table which this Query selects
@@ -797,16 +772,9 @@ class Query(object):
util.warn_deprecated("select_from() now accepts a single Selectable as its argument, which replaces any existing FROM criterion.")
from_obj = from_obj[-1]
- if isinstance(from_obj, expression._SelectBaseMixin):
- # alias SELECTs and unions
- from_obj = from_obj.alias()
-
- new._from_obj = from_obj
-
- if new.table not in new._get_joinable_tables():
- new._adapter = sql_util.ClauseAdapter(new._from_obj, equivalents=new.mapper._equivalent_columns)
+ new._set_select_from(from_obj)
return new
-
+
def __getitem__(self, item):
if isinstance(item, slice):
start = item.start
@@ -938,55 +906,9 @@ class Query(object):
context.runid = _new_runid()
- # for with_polymorphic, instruct descendant mappers that they
- # don't need to post-fetch anything
- for m in self._with_polymorphic:
- context.attributes[('polymorphic_fetch', m)] = (self.select_mapper, [])
-
- mappers_or_columns = tuple(self._entities) + mappers_or_columns
- tuples = bool(mappers_or_columns)
-
- if context.row_adapter:
- def main(context, row):
- return self.select_mapper._instance(context, context.row_adapter(row), None,
- extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance
- )
- else:
- def main(context, row):
- return self.select_mapper._instance(context, row, None,
- extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance
- )
-
- if tuples:
- process = []
- process.append(main)
- for tup in mappers_or_columns:
- if isinstance(tup, tuple):
- (m, alias, alias_id) = tup
- clauses = self._get_entity_clauses(tup)
- else:
- clauses = alias = alias_id = None
- m = tup
-
- if isinstance(m, type):
- m = mapper.class_mapper(m)
-
- if isinstance(m, mapper.Mapper):
- def x(m):
- row_adapter = clauses is not None and clauses.row_decorator or (lambda row: row)
- def proc(context, row):
- return m._instance(context, row_adapter(row), None)
- process.append(proc)
- x(m)
- elif isinstance(m, (sql.ColumnElement, basestring)):
- def y(m):
- row_adapter = clauses is not None and clauses.row_decorator or (lambda row: row)
- def proc(context, row):
- return row_adapter(row)[m]
- process.append(proc)
- y(m)
- else:
- raise exceptions.InvalidRequestError("Invalid column expression '%r'" % m)
+ entities = self._entities + [_QueryEntity.legacy_guess_type(mc) for mc in mappers_or_columns]
+ should_unique = isinstance(entities[0], _PrimaryMapperEntity) and len(entities) == 1
+ process = [query_entity.row_processor(self, context) for query_entity in entities]
while True:
context.progress = util.Set()
@@ -999,14 +921,14 @@ class Query(object):
else:
fetch = cursor.fetchall()
- if tuples:
+ if not should_unique:
rows = util.OrderedSet()
for row in fetch:
rows.add(tuple([proc(context, row) for proc in process]))
else:
rows = util.UniqueAppender([])
for row in fetch:
- rows.append(main(context, row))
+ rows.append(process[0](context, row))
if context.refresh_instance and context.only_load_props and context.refresh_instance in context.progress:
context.refresh_instance.commit(context.only_load_props)
@@ -1043,19 +965,19 @@ class Query(object):
q = self
# dont use 'polymorphic' mapper if we are refreshing an instance
- if refresh_instance and q.select_mapper is not q.mapper:
- q = q._new_base_mapper(q.mapper, '_get')
+ if refresh_instance and q.mapper is not q.mapper:
+ q = q._reset_all(q.mapper, '_get')
if ident is not None:
q = q._no_criterion('get')
params = {}
- (_get_clause, _get_params) = q.select_mapper._get_clause
+ (_get_clause, _get_params) = q.mapper._get_clause
q = q.filter(_get_clause)
- for i, primary_key in enumerate(q.primary_key_columns):
+ for i, primary_key in enumerate(q.mapper.primary_key):
try:
params[_get_params[primary_key].key] = ident[i]
except IndexError:
- raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns]))
+ raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in q.mapper.primary_key]))
q = q.params(params)
if lockmode is not None:
@@ -1068,10 +990,14 @@ class Query(object):
except IndexError:
return None
- def _nestable(self, **kwargs):
- """Return true if the given statement options imply it should be nested."""
-
+ def _select_args(self):
+ return {'limit':self._limit, 'offset':self._offset, 'distinct':self._distinct, 'group_by':self._group_by or None, 'having':self._having or None}
+ _select_args = property(_select_args)
+
+ def _should_nest_selectable(self):
+ kwargs = self._select_args
return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False))
+ _should_nest_selectable = property(_should_nest_selectable)
def count(self, whereclause=None, params=None, **kwargs):
"""Apply this query's criterion to a SELECT COUNT statement.
@@ -1100,11 +1026,11 @@ class Query(object):
context = QueryContext(self)
from_obj = self._from_obj
- if self._nestable(**self._select_args()):
- s = sql.select([self.table], whereclause, from_obj=from_obj, **self._select_args()).alias('getcount').count()
+ if self._should_nest_selectable:
+ s = sql.select([self.table], whereclause, from_obj=from_obj, **self._select_args).alias('getcount').count()
else:
- primary_key = self.primary_key_columns
- s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **self._select_args())
+ primary_key = self.mapper.primary_key
+ s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **self._select_args)
if self._autoflush and not self._populate_existing:
self.session._autoflush()
return self.session.scalar(s, params=self._params, mapper=self.mapper)
@@ -1112,7 +1038,7 @@ class Query(object):
def compile(self):
"""compiles and returns a SQL statement based on the criterion and conditions within this Query."""
return self._compile_context().statement
-
+
def _compile_context(self):
context = QueryContext(self)
@@ -1122,18 +1048,9 @@ class Query(object):
context.statement = self._statement
return context
- whereclause = self._criterion
from_obj = self._from_obj
- adapter = self._adapter
- order_by = self._order_by
-
- if order_by is False:
- order_by = self.select_mapper.order_by
- if order_by is False:
- order_by = from_obj.default_order_by()
- if order_by is None:
- order_by = self.table.default_order_by()
-
+ adapter = self._aliases_head
+
if self._lockmode:
try:
for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode]
@@ -1142,152 +1059,86 @@ class Query(object):
else:
for_update = False
- # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
- # that we only load the appropriate types
- if self.select_mapper.single and self.select_mapper.inherits is not None and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None:
- whereclause = sql.and_(whereclause, self.select_mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.select_mapper.polymorphic_iterator()]))
-
context.from_clause = from_obj
-
- # TODO: compile eagerloads from select_mapper if polymorphic ? [ticket:917]
- if self._with_polymorphic:
- props = util.Set()
- for m in [self.select_mapper] + self._with_polymorphic:
- for value in m.iterate_properties:
- props.add(value)
- else:
- props = self.select_mapper.iterate_properties
+ context.whereclause = self._criterion
+ context.order_by = self._order_by
+
+ for entity in self._entities:
+ entity.setup_context(self, context)
- for value in props:
- if self._only_load_props and value.key not in self._only_load_props:
- continue
- context.exec_with_path(self.select_mapper, value.key, value.setup, context, only_load_props=self._only_load_props)
-
- # additional entities/columns, add those to selection criterion
- for tup in self._entities:
- (m, alias, alias_id) = tup
- clauses = self._get_entity_clauses(tup)
- if isinstance(m, mapper.Mapper):
- for value in m.iterate_properties:
- context.exec_with_path(m, value.key, value.setup, context, parentclauses=clauses)
- elif isinstance(m, sql.ColumnElement):
- if clauses is not None:
- m = clauses.aliased_column(m)
- context.secondary_columns.append(m)
-
- if self._eager_loaders and self._nestable(**self._select_args()):
+ if self._eager_loaders and self._should_nest_selectable:
# eager loaders are present, and the SELECT has limiting criterion
# produce a "wrapped" selectable.
-
- # locate all embedded Column clauses so they can be added to the
- # "inner" select statement where they'll be available to the enclosing
- # statement's "order by"
- cf = util.Set()
- if order_by:
- order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []]
- for o in order_by:
- cf.update(sql_util.find_columns(o))
-
+
+ if context.order_by:
+ context.order_by = [expression._literal_as_text(o) for o in util.to_list(context.order_by) or []]
+ if adapter:
+ context.order_by = adapter.adapt_list(context.order_by)
+ # locate all embedded Column clauses so they can be added to the
+ # "inner" select statement where they'll be available to the enclosing
+ # statement's "order by"
+ # TODO: this likely doesn't work with very involved ORDER BY expressions,
+ # such as those including subqueries
+ order_by_col_expr = list(chain(*[sql_util.find_columns(o) for o in context.order_by]))
+ else:
+ context.order_by = None
+ order_by_col_expr = []
+
if adapter:
- # TODO: make usage of the ClauseAdapter here to create the list
- # of primary columns ?
- context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns]
- cf = [from_obj.corresponding_column(c) or c for c in cf]
-
- s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clause, use_labels=True, correlate=False, order_by=util.to_list(order_by), **self._select_args())
-
- s3 = s2.alias()
+ context.primary_columns = adapter.adapt_list(context.primary_columns)
+
+ inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=context.from_clause, use_labels=True, correlate=False, order_by=context.order_by, **self._select_args).alias()
+ local_adapter = sql_util.ClauseAdapter(inner)
- context.row_adapter = mapperutil.create_row_adapter(s3, self.table)
+ context.row_adapter = mapperutil.create_row_adapter(inner, equivalent_columns=self.mapper._equivalent_columns)
- statement = sql.select([s3] + context.secondary_columns, for_update=for_update, use_labels=True)
+ statement = sql.select([inner] + context.secondary_columns, for_update=for_update, use_labels=True)
if context.eager_joins:
- eager_joins = sql_util.ClauseAdapter(s3).traverse(context.eager_joins)
+ eager_joins = local_adapter.traverse(context.eager_joins)
statement.append_from(eager_joins, _copy_collection=False)
- if order_by:
- statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
+ if context.order_by:
+ statement.append_order_by(*local_adapter.copy_and_process(context.order_by))
statement.append_order_by(*context.eager_order_by)
else:
- order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []]
+ if context.order_by:
+ context.order_by = [expression._literal_as_text(o) for o in util.to_list(context.order_by) or []]
+ if adapter:
+ context.order_by = adapter.adapt_list(context.order_by)
+ else:
+ context.order_by = None
if adapter:
- # TODO: make usage of the ClauseAdapter here to create row adapter, list
- # of primary columns ?
- context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns]
- context.row_adapter = mapperutil.create_row_adapter(from_obj, self.table)
- order_by = adapter.copy_and_process(order_by)
+ context.primary_columns = adapter.adapt_list(context.primary_columns)
+ context.row_adapter = mapperutil.create_row_adapter(adapter.alias, equivalent_columns=self.mapper._equivalent_columns)
- if self._distinct:
-
- if self._distinct and order_by:
- cf = util.Set()
- for o in order_by:
- cf.update(sql_util.find_columns(o))
- for c in cf:
- context.primary_columns.append(c)
+ if self._distinct and context.order_by:
+ order_by_col_expr = list(chain(*[sql_util.find_columns(o) for o in context.order_by]))
+ context.primary_columns += order_by_col_expr
- statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, order_by=util.to_list(order_by), **self._select_args())
+ statement = sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, order_by=context.order_by, **self._select_args)
if context.eager_joins:
if adapter:
- context.eager_joins = adapter.traverse(context.eager_joins)
+ context.eager_joins = adapter.adapt_clause(context.eager_joins)
statement.append_from(context.eager_joins, _copy_collection=False)
if context.eager_order_by:
if adapter:
- context.eager_order_by = adapter.copy_and_process(context.eager_order_by)
+ context.eager_order_by = adapter.adapt_list(context.eager_order_by)
statement.append_order_by(*context.eager_order_by)
+ # polymorphic mappers which have concrete tables in their hierarchy usually
+ # require row aliasing unconditionally.
+ if not context.row_adapter and self.mapper._requires_row_aliasing:
+ context.row_adapter = mapperutil.create_row_adapter(self.table, equivalent_columns=self.mapper._equivalent_columns)
+
context.statement = statement
return context
- def _select_args(self):
- """Return a dictionary of attributes that can be applied to a ``sql.Select`` statement.
- """
- return {'limit':self._limit, 'offset':self._offset, 'distinct':self._distinct, 'group_by':self._group_by or None, 'having':self._having or None}
-
-
- def _get_entity_clauses(self, m):
- """for tuples added via add_entity() or add_column(), attempt to locate
- an AliasedClauses object which should be used to formulate the query as well
- as to process result rows."""
-
- (m, alias, alias_id) = m
- if alias is not None:
- return alias
- if alias_id is not None:
- try:
- return self._alias_ids[alias_id]
- except KeyError:
- raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % alias_id)
-
- if isinstance(m, type):
- m = mapper.class_mapper(m)
- if isinstance(m, mapper.Mapper):
- l = self._alias_ids.get(m)
- if l:
- if len(l) > 1:
- raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_entity()" % str(m))
- else:
- return l[0]
- else:
- return None
- elif isinstance(m, sql.ColumnElement):
- aliases = []
- for table in sql_util.find_tables(m, check_columns=True):
- for a in self._alias_ids.get(table, []):
- aliases.append(a)
- if len(aliases) > 1:
- raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_column()" % str(m))
- elif len(aliases) == 1:
- return aliases[0]
- else:
- return None
-
def __log_debug(self, msg):
self.logger.debug(msg)
@@ -1296,6 +1147,48 @@ class Query(object):
# DEPRECATED LAND !
+ def _generative_col_aggregate(self, col, func):
+ """apply the given aggregate function to the query and return the newly
+ resulting ``Query``. (deprecated)
+ """
+ if self._column_aggregate is not None:
+ raise exceptions.InvalidRequestError("Query already contains an aggregate column or function")
+ q = self._no_statement("aggregate")
+ q._column_aggregate = (col, func)
+ return q
+
+ def apply_min(self, col):
+ """apply the SQL ``min()`` function against the given column to the
+ query and return the newly resulting ``Query``.
+
+ DEPRECATED.
+ """
+ return self._generative_col_aggregate(col, sql.func.min)
+
+ def apply_max(self, col):
+ """apply the SQL ``max()`` function against the given column to the
+ query and return the newly resulting ``Query``.
+
+ DEPRECATED.
+ """
+ return self._generative_col_aggregate(col, sql.func.max)
+
+ def apply_sum(self, col):
+ """apply the SQL ``sum()`` function against the given column to the
+ query and return the newly resulting ``Query``.
+
+ DEPRECATED.
+ """
+ return self._generative_col_aggregate(col, sql.func.sum)
+
+ def apply_avg(self, col):
+ """apply the SQL ``avg()`` function against the given column to the
+ query and return the newly resulting ``Query``.
+
+ DEPRECATED.
+ """
+ return self._generative_col_aggregate(col, sql.func.avg)
+
def list(self): #pragma: no cover
"""DEPRECATED. use all()"""
@@ -1521,6 +1414,150 @@ for deprecated_method in ('list', 'scalar', 'count_by',
util.deprecated(getattr(Query, deprecated_method),
add_deprecation_to_docstring=False))
+class _QueryEntity(object):
+ """represent an entity column returned within a Query result."""
+
+ def legacy_guess_type(self, e):
+ if isinstance(e, type):
+ return _MapperEntity(mapper=mapper.class_mapper(e))
+ elif isinstance(e, mapper.Mapper):
+ return _MapperEntity(mapper=e)
+ else:
+ return _ColumnEntity(column=e)
+ legacy_guess_type=classmethod(legacy_guess_type)
+
+class _MapperEntity(_QueryEntity):
+ """entity column corresponding to mapped ORM instances."""
+
+ def __init__(self, mapper, alias=None, id=None):
+ self.mapper = mapper
+ self.alias = alias
+ self.alias_id = id
+
+ def _get_entity_clauses(self, query):
+ if self.alias:
+ return self.alias
+ elif self.alias_id:
+ try:
+ return query._alias_ids[self.alias_id][0]
+ except KeyError:
+ raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % self.alias_id)
+
+ l = query._alias_ids.get(self.mapper)
+ if l:
+ if len(l) > 1:
+ raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_entity()" % str(self.mapper))
+ return l[0]
+ else:
+ return None
+
+ def row_processor(self, query, context):
+ clauses = self._get_entity_clauses(query)
+ if clauses:
+ def proc(context, row):
+ return self.mapper._instance(context, clauses.row_decorator(row), None)
+ else:
+ def proc(context, row):
+ return self.mapper._instance(context, row, None)
+
+ return proc
+
+ def setup_context(self, query, context):
+ clauses = self._get_entity_clauses(query)
+ for value in self.mapper.iterate_properties:
+ context.exec_with_path(self.mapper, value.key, value.setup, context, parentclauses=clauses)
+
+ def __str__(self):
+ return str(self.mapper)
+
+class _PrimaryMapperEntity(_MapperEntity):
+ """entity column corresponding to the 'primary' (first) mapped ORM instance."""
+
+ def row_processor(self, query, context):
+ if context.row_adapter:
+ def main(context, row):
+ return self.mapper._instance(context, context.row_adapter(row), None,
+ extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance
+ )
+ else:
+ def main(context, row):
+ return self.mapper._instance(context, row, None,
+ extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance
+ )
+ return main
+
+ def setup_context(self, query, context):
+ # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
+ # that we only load the appropriate types
+ if self.mapper.single and self.mapper.inherits is not None and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
+ context.whereclause = sql.and_(context.whereclause, self.mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.mapper.polymorphic_iterator()]))
+
+ if context.order_by is False:
+ if self.mapper.order_by:
+ context.order_by = self.mapper.order_by
+ elif context.from_clause.default_order_by():
+ context.order_by = context.from_clause.default_order_by()
+
+ for value in self.mapper._iterate_polymorphic_properties(query._with_polymorphic, context.from_clause):
+ if query._only_load_props and value.key not in query._only_load_props:
+ continue
+ context.exec_with_path(self.mapper, value.key, value.setup, context, only_load_props=query._only_load_props)
+
+
+class _ColumnEntity(_QueryEntity):
+ """entity column corresponding to Table or selectable columns."""
+
+ def __init__(self, column, id=None):
+ if column and isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'):
+ column = column.label(None)
+ self.column = column
+ self.alias_id = id
+ self.__tables = None
+
+ def _tables(self):
+ if not self.__tables:
+ self.__tables = sql_util.find_tables(self.column, check_columns=True)
+ return self.__tables
+ _tables = property(_tables)
+
+ def _get_entity_clauses(self, query):
+ if self.alias_id:
+ try:
+ return query._alias_ids[self.alias_id][0]
+ except KeyError:
+ raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % self.alias_id)
+
+ if isinstance(self.column, sql.ColumnElement):
+ aliases = list(chain(*[query._alias_ids[t] for t in self._tables if t in query._alias_ids]))
+ if len(aliases) > 1:
+ raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_column()" % str(self.column))
+ elif len(aliases) == 1:
+ return aliases[0]
+
+ return None
+
+ def row_processor(self, query, context):
+ clauses = self._get_entity_clauses(query)
+ if clauses:
+ def proc(context, row):
+ return clauses.row_decorator(row)[self.column]
+ else:
+ def proc(context, row):
+ return row[self.column]
+ return proc
+
+ def setup_context(self, query, context):
+ clauses = self._get_entity_clauses(query)
+ if clauses:
+ context.secondary_columns.append(clauses.aliased_column(self.column))
+ else:
+ context.secondary_columns.append(self.column)
+
+ def __str__(self):
+ return str(self.column)
+
+
+
Query.logger = logging.class_logger(Query)
class QueryContext(object):
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index a7ab3d005..57237e08f 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -223,7 +223,7 @@ class UndeferGroupOption(MapperOption):
class AbstractRelationLoader(LoaderStrategy):
def init(self):
super(AbstractRelationLoader, self).init()
- for attr in ['primaryjoin', 'secondaryjoin', 'secondary', 'foreign_keys', 'mapper', 'select_mapper', 'target', 'select_table', 'loads_polymorphic', 'uselist', 'cascade', 'attributeext', 'order_by', 'remote_side', 'polymorphic_primaryjoin', 'polymorphic_secondaryjoin', 'direction']:
+ for attr in ['primaryjoin', 'secondaryjoin', 'secondary', 'foreign_keys', 'mapper', 'target', 'table', 'uselist', 'cascade', 'attributeext', 'order_by', 'remote_side', 'direction']:
setattr(self, attr, getattr(self.parent_property, attr))
self._should_log_debug = logging.is_debug_enabled(self.logger)
@@ -362,7 +362,7 @@ class LazyLoader(AbstractRelationLoader):
return (new_execute, None, None)
def _create_lazy_clause(cls, prop, reverse_direction=False):
- (primaryjoin, secondaryjoin, remote_side) = (prop.polymorphic_primaryjoin, prop.polymorphic_secondaryjoin, prop.remote_side)
+ (primaryjoin, secondaryjoin, remote_side) = (prop.primaryjoin, prop.secondaryjoin, prop.remote_side)
binds = {}
equated_columns = {}
@@ -461,7 +461,7 @@ class LoadLazyAttribute(object):
if strategy.use_get:
ident = []
allnulls = True
- for primary_key in prop.select_mapper.primary_key:
+ for primary_key in prop.mapper.primary_key:
val = instance_mapper._get_committed_attr_by_column(instance, strategy.equated_columns[primary_key])
allnulls = allnulls and val is None
ident.append(val)
@@ -537,7 +537,7 @@ class EagerLoader(AbstractRelationLoader):
try:
clauses = self.clauses[path]
except KeyError:
- clauses = mapperutil.PropertyAliasedClauses(self.parent_property, self.parent_property.polymorphic_primaryjoin, self.parent_property.polymorphic_secondaryjoin, parentclauses)
+ clauses = mapperutil.PropertyAliasedClauses(self.parent_property, self.parent_property.primaryjoin, self.parent_property.secondaryjoin, parentclauses)
self.clauses[path] = clauses
# place the "row_decorator" from the AliasedClauses into the QueryContext, where it will
@@ -554,7 +554,6 @@ class EagerLoader(AbstractRelationLoader):
context.eager_order_by += clauses.secondary.default_order_by()
else:
context.eager_joins = towrap.outerjoin(clauses.alias, clauses.primaryjoin)
-
# ensure all the cols on the parent side are actually in the
# columns clause (i.e. are not deferred), so that aliasing applied by the Query propagates
# those columns outward. This has the effect of "undefering" those columns.
@@ -568,8 +567,8 @@ class EagerLoader(AbstractRelationLoader):
if clauses.order_by:
context.eager_order_by += util.to_list(clauses.order_by)
- for value in self.select_mapper.iterate_properties:
- context.exec_with_path(self.select_mapper, value.key, value.setup, context, parentclauses=clauses, parentmapper=self.select_mapper)
+ for value in self.mapper._iterate_polymorphic_properties():
+ context.exec_with_path(self.mapper, value.key, value.setup, context, parentclauses=clauses, parentmapper=self.mapper)
def _create_row_decorator(self, selectcontext, row, path):
"""Create a *row decorating* function that will apply eager
@@ -593,7 +592,7 @@ class EagerLoader(AbstractRelationLoader):
try:
decorated_row = decorator(row)
# check for identity key
- identity_key = self.select_mapper.identity_key_from_row(decorated_row)
+ identity_key = self.mapper.identity_key_from_row(decorated_row)
# and its good
return decorator
except KeyError, k:
@@ -605,6 +604,7 @@ class EagerLoader(AbstractRelationLoader):
def create_row_processor(self, selectcontext, mapper, row):
row_decorator = self._create_row_decorator(selectcontext, row, selectcontext.path)
+ pathstr = ','.join(str(x) for x in selectcontext.path)
if row_decorator is not None:
def execute(instance, row, isnew, **flags):
decorated_row = row_decorator(row)
@@ -617,11 +617,11 @@ class EagerLoader(AbstractRelationLoader):
# parent object, bypassing InstrumentedAttribute
# event handlers.
#
- instance.__dict__[self.key] = self.select_mapper._instance(selectcontext, decorated_row, None)
+ instance.__dict__[self.key] = self.mapper._instance(selectcontext, decorated_row, None)
else:
# call _instance on the row, even though the object has been created,
# so that we further descend into properties
- self.select_mapper._instance(selectcontext, decorated_row, None)
+ self.mapper._instance(selectcontext, decorated_row, None)
else:
if isnew or self.key not in instance._state.appenders:
# appender_key can be absent from selectcontext.attributes with isnew=False
@@ -639,8 +639,8 @@ class EagerLoader(AbstractRelationLoader):
result_list = instance._state.appenders[self.key]
if self._should_log_debug:
self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key))
-
- self.select_mapper._instance(selectcontext, decorated_row, result_list)
+
+ self.mapper._instance(selectcontext, decorated_row, result_list)
if self._should_log_debug:
self.logger.debug("Returning eager instance loader for %s" % str(self))
@@ -689,22 +689,6 @@ class EagerLazyOption(StrategizedOption):
EagerLazyOption.logger = logging.class_logger(EagerLazyOption)
-# TODO: enable FetchMode option. currently
-# this class does nothing. will require Query
-# to swich between using its "polymorphic" selectable
-# and its regular selectable in order to make decisions
-# (therefore might require that FetchModeOperation is performed
-# only as the first operation on a Query.)
-class FetchModeOption(PropertyOption):
- def __init__(self, key, type):
- super(FetchModeOption, self).__init__(key)
- if type not in ('join', 'select'):
- raise exceptions.ArgumentError("Fetchmode must be one of 'join' or 'select'")
- self.type = type
-
- def process_query_property(self, query, properties):
- query.attributes[('fetchmode', properties[-1])] = self.type
-
class RowDecorateOption(PropertyOption):
def __init__(self, key, decorator=None, alias=None):
super(RowDecorateOption, self).__init__(key)
@@ -719,7 +703,7 @@ class RowDecorateOption(PropertyOption):
if isinstance(self.alias, basestring):
self.alias = prop.target.alias(self.alias)
- self.decorator = mapperutil.create_row_adapter(self.alias, prop.target)
+ self.decorator = mapperutil.create_row_adapter(self.alias)
query._attributes[("eager_row_processor", paths[-1])] = self.decorator
RowDecorateOption.logger = logging.class_logger(RowDecorateOption)
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 8a3583c36..6975f10f8 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -143,82 +143,74 @@ class AliasedClauses(object):
"""Creates aliases of a mapped tables for usage in ORM queries.
"""
- def __init__(self, mapped_table, alias=None):
- if alias:
- self.alias = alias
- else:
- self.alias = mapped_table.alias()
- self.mapped_table = mapped_table
+ def __init__(self, alias, equivalents=None, chain_to=None):
+ self.alias = alias
+ self.equivalents = equivalents
self.row_decorator = self._create_row_adapter()
-
+ self.adapter = sql_util.ClauseAdapter(self.alias, equivalents=equivalents)
+ if chain_to:
+ self.adapter.chain(chain_to.adapter)
+
def aliased_column(self, column):
- """return the aliased version of the given column, creating a new label for it if not already
- present in this AliasedClauses."""
-
+
conv = self.alias.corresponding_column(column)
if conv:
return conv
-
+
aliased_column = column
- # for column-level subqueries, swap out its selectable with our
- # eager version as appropriate, and manually build the
- # "correlation" list of the subquery.
class ModifySubquery(visitors.ClauseVisitor):
def visit_select(s, select):
select._should_correlate = False
select.append_correlation(self.alias)
- aliased_column = sql_util.ClauseAdapter(self.alias).chain(ModifySubquery()).traverse(aliased_column, clone=True)
+ aliased_column = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).chain(ModifySubquery()).traverse(aliased_column, clone=True)
aliased_column = aliased_column.label(None)
- self.row_decorator.map[column] = aliased_column
+ self.row_decorator({}).map[column] = aliased_column
return aliased_column
def adapt_clause(self, clause):
- return sql_util.ClauseAdapter(self.alias).traverse(clause, clone=True)
+ return self.adapter.traverse(clause, clone=True)
def adapt_list(self, clauses):
- return sql_util.ClauseAdapter(self.alias).copy_and_process(clauses)
+ return self.adapter.copy_and_process(clauses)
def _create_row_adapter(self):
- """Return a callable which,
- when passed a RowProxy, will return a new dict-like object
- that translates Column objects to that of this object's Alias before calling upon the row.
-
- This allows a regular Table to be used to target columns in a row that was in reality generated from an alias
- of that table, in such a way that the row can be passed to logic which knows nothing about the aliased form
- of the table.
- """
- return create_row_adapter(self.alias, self.mapped_table)
+ return create_row_adapter(self.alias, equivalent_columns=self.equivalents)
class PropertyAliasedClauses(AliasedClauses):
"""extends AliasedClauses to add support for primary/secondary joins on a relation()."""
def __init__(self, prop, primaryjoin, secondaryjoin, parentclauses=None, alias=None):
- super(PropertyAliasedClauses, self).__init__(prop.select_table, alias=alias)
-
+ self.prop = prop
+ self.mapper = self.prop.mapper
+ self.table = self.prop.table
self.parentclauses = parentclauses
- self.prop = prop
+ if not alias:
+ from_obj = self.mapper._with_polymorphic_selectable()
+ alias = from_obj.alias()
+
+ super(PropertyAliasedClauses, self).__init__(alias, equivalents=self.mapper._equivalent_columns, chain_to=parentclauses)
if prop.secondary:
self.secondary = prop.secondary.alias()
if parentclauses is not None:
- primary_aliasizer = sql_util.ClauseAdapter(self.secondary).chain(sql_util.ClauseAdapter(parentclauses.alias))
- secondary_aliasizer = sql_util.ClauseAdapter(self.alias).chain(sql_util.ClauseAdapter(self.secondary))
+ primary_aliasizer = sql_util.ClauseAdapter(self.secondary).chain(sql_util.ClauseAdapter(parentclauses.alias, equivalents=parentclauses.equivalents))
+ secondary_aliasizer = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).chain(sql_util.ClauseAdapter(self.secondary))
else:
primary_aliasizer = sql_util.ClauseAdapter(self.secondary)
- secondary_aliasizer = sql_util.ClauseAdapter(self.alias).chain(sql_util.ClauseAdapter(self.secondary))
+ secondary_aliasizer = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).chain(sql_util.ClauseAdapter(self.secondary))
self.secondaryjoin = secondary_aliasizer.traverse(secondaryjoin, clone=True)
self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True)
else:
if parentclauses is not None:
- primary_aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side)
- primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side))
+ primary_aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side, equivalents=self.equivalents)
+ primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side, equivalents=parentclauses.equivalents))
else:
- primary_aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side)
-
+ primary_aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side, equivalents=self.equivalents)
+
self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True)
self.secondary = None
self.secondaryjoin = None
@@ -233,10 +225,6 @@ class PropertyAliasedClauses(AliasedClauses):
else:
self.order_by = None
-
- mapper = property(lambda self:self.prop.mapper)
- table = property(lambda self:self.prop.select_table)
-
def instance_str(instance):
"""Return a string describing an instance."""
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index f4611de6d..3d95948cb 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1621,6 +1621,15 @@ class FromClause(Selectable):
from sqlalchemy.sql.util import ClauseAdapter
return ClauseAdapter(alias).traverse(self, clone=True)
+ def correspond_on_equivalents(self, column, equivalents):
+ col = self.corresponding_column(column, require_embedded=True)
+ if col is None and col in equivalents:
+ for equiv in equivalents[col]:
+ nc = self.corresponding_column(equiv, require_embedded=True)
+ if nc:
+ return nc
+ return col
+
def corresponding_column(self, column, require_embedded=False):
"""Given a ``ColumnElement``, return the exported ``ColumnElement``
object from this ``Selectable`` which corresponds to that
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index d4163b73b..8ed561e5f 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -93,46 +93,51 @@ def reduce_columns(columns, *clauses):
return expression.ColumnSet(columns.difference(omit))
-def row_adapter(from_, to, equivalent_columns=None):
- """create a row adapter between two selectables.
+class AliasedRow(object):
+
+ def __init__(self, row, map):
+ # AliasedRow objects don't nest, so un-nest
+ # if another AliasedRow was passed
+ if isinstance(row, AliasedRow):
+ self.row = row.row
+ else:
+ self.row = row
+ self.map = map
+
+ def __contains__(self, key):
+ return self.map[key] in self.row
- The returned adapter is a class that can be instantiated repeatedly for any number
- of rows; this is an inexpensive process. However, the creation of the row
- adapter class itself *is* fairly expensive so caching should be used to prevent
- repeated calls to this function.
- """
+ def has_key(self, key):
+ return key in self
- map = {}
- for c in to.c:
- corr = from_.corresponding_column(c)
- if corr:
- map[c] = corr
- elif equivalent_columns:
- if c in equivalent_columns:
- for c2 in equivalent_columns[c]:
- corr = from_.corresponding_column(c2)
- if corr:
- map[c] = corr
- break
-
- class AliasedRow(object):
- def __init__(self, row):
- self.row = row
- def __contains__(self, key):
- if key in map:
- return map[key] in self.row
- else:
- return key in self.row
- def has_key(self, key):
- return key in self
- def __getitem__(self, key):
- if key in map:
- key = map[key]
- return self.row[key]
- def keys(self):
- return map.keys()
- AliasedRow.map = map
- return AliasedRow
+ def __getitem__(self, key):
+ return self.row[self.map[key]]
+
+ def keys(self):
+ return self.row.keys()
+
+def row_adapter(from_, equivalent_columns=None):
+ """create a row adapter against a selectable."""
+
+ if equivalent_columns is None:
+ equivalent_columns = {}
+
+ def locate_col(col):
+ c = from_.corresponding_column(col)
+ if c:
+ return c
+ elif col in equivalent_columns:
+ for c2 in equivalent_columns[col]:
+ corr = from_.corresponding_column(c2)
+ if corr:
+ return corr
+ return col
+
+ map = util.PopulateDict(locate_col)
+
+ def adapt(row):
+ return AliasedRow(row, map)
+ return adapt
class ColumnsInClause(visitors.ClauseVisitor):
"""Given a selectable, visit clauses and determine if any columns
@@ -189,7 +194,7 @@ class ClauseAdapter(visitors.ClauseVisitor):
if not clone:
raise exceptions.ArgumentError("ClauseAdapter 'clone' argument must be True")
return visitors.ClauseVisitor.traverse(self, obj, clone=True)
-
+
def copy_and_chain(self, adapter):
"""create a copy of this adapter and chain to the given adapter.
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 09d5a0982..7eccc9b89 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -37,17 +37,17 @@ class ClauseVisitor(object):
traverse_chained = traverse_single
def iterate(self, obj):
- """traverse the given expression structure, and return an iterator of all elements."""
+ """traverse the given expression structure, returning an iterator of all elements."""
stack = [obj]
- traversal = []
- while len(stack) > 0:
+ traversal = util.deque()
+ while stack:
t = stack.pop()
- yield t
- traversal.insert(0, t)
+ traversal.appendleft(t)
for c in t.get_children(**self.__traverse_options__):
stack.append(c)
-
+ return iter(traversal)
+
def traverse(self, obj, clone=False):
"""traverse and visit the given expression structure.
@@ -119,32 +119,19 @@ class ClauseVisitor(object):
def clone(element):
return self._clone_element(element, stop_on, cloned)
elem._copy_internals(clone=clone)
-
- for v in self._iterate_visitors:
- meth = getattr(v, "visit_%s" % elem.__visit_name__, None)
- if meth:
- meth(elem)
+
+ self.traverse_single(elem)
for e in elem.get_children(**self.__traverse_options__):
if e not in stop_on:
self._cloned_traversal_impl(e, stop_on, cloned)
return elem
-
+
def _non_cloned_traversal(self, obj):
"""a non-recursive, non-cloning traversal."""
-
- stack = [obj]
- traversal = []
- while len(stack) > 0:
- t = stack.pop()
- traversal.insert(0, t)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
- for target in traversal:
- for v in self._iterate_visitors:
- meth = getattr(v, "visit_%s" % target.__visit_name__, None)
- if meth:
- meth(target)
+
+ for target in self.iterate(obj):
+ self.traverse_single(target)
return obj
def _iterate_visitors(self):
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index af77d792e..90332fdc0 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -163,6 +163,23 @@ except ImportError:
return 'defaultdict(%s, %s)' % (self.default_factory,
dict.__repr__(self))
+try:
+ from collections import deque
+except ImportError:
+ class deque(list):
+ def appendleft(self, x):
+ self.insert(0, x)
+
+ def extendleft(self, iterable):
+ self[0:0] = list(iterable)
+
+ def popleft(self):
+ return self.pop(0)
+
+ def rotate(self, n):
+ for i in xrange(n):
+ self.appendleft(self.pop())
+
def to_list(x, default=None):
if x is None:
return default
@@ -171,6 +188,16 @@ def to_list(x, default=None):
else:
return x
+def starargs_as_list(*args):
+ """interpret the given *args as either a list of *args,
+ or detect if it's a single list and return that.
+
+ """
+ if len(args) == 1:
+ return to_list(args[0], [])
+ else:
+ return list(args)
+
def to_set(x):
if x is None:
return Set()
@@ -1018,7 +1045,36 @@ class symbol(object):
return sym
finally:
symbol._lock.release()
-
+
+def conditional_cache_decorator(func):
+ """apply conditional caching to the return value of a function."""
+
+ return cache_decorator(func, conditional=True)
+
+def cache_decorator(func, conditional=False):
+ """apply caching to the return value of a function."""
+
+ name = '_cached_' + func.__name__
+
+ def do_with_cache(self, *args, **kwargs):
+ if conditional:
+ cache = kwargs.pop('cache', False)
+ if not cache:
+ return func(self, *args, **kwargs)
+ try:
+ return getattr(self, name)
+ except AttributeError:
+ value = func(self, *args, **kwargs)
+ setattr(self, name, value)
+ return value
+ return do_with_cache
+
+def reset_cached(instance, name):
+ try:
+ delattr(instance, '_cached_' + name)
+ except AttributeError:
+ pass
+
def warn(msg):
if isinstance(msg, basestring):
warnings.warn(msg, exceptions.SAWarning, stacklevel=3)