summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/properties.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/properties.py')
-rw-r--r--lib/sqlalchemy/orm/properties.py269
1 files changed, 208 insertions, 61 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index a00a35ab6..6ce9fd706 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -15,8 +15,11 @@ from sqlalchemy import sql, schema, util, exceptions, sql_util, logging
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
-import sets, random
-from sqlalchemy.orm.interfaces import *
+import operator
+from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator
+from sqlalchemy.exceptions import ArgumentError
+
+__all__ = ['ColumnProperty', 'CompositeProperty', 'PropertyLoader', 'BackRef']
class ColumnProperty(StrategizedProperty):
"""Describes an object attribute that corresponds to a table column."""
@@ -31,17 +34,27 @@ class ColumnProperty(StrategizedProperty):
self.columns = list(columns)
self.group = kwargs.pop('group', None)
self.deferred = kwargs.pop('deferred', False)
-
+ self.comparator = ColumnProperty.ColumnComparator(self)
+ # sanity check
+ for col in columns:
+ if not hasattr(col, 'name'):
+ if hasattr(col, 'label'):
+ raise ArgumentError('ColumnProperties must be named for the mapper to work with them. Try .label() to fix this')
+ raise ArgumentError('%r is not a valid candidate for ColumnProperty' % col)
+
def create_strategy(self):
if self.deferred:
return strategies.DeferredColumnLoader(self)
else:
return strategies.ColumnLoader(self)
-
- def getattr(self, object):
+
+ def copy(self):
+ return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns)
+
+ def getattr(self, object, column):
return getattr(object, self.key)
- def setattr(self, object, value):
+ def setattr(self, object, value, column):
setattr(object, self.key, value)
def get_history(self, obj, passive=False):
@@ -50,19 +63,69 @@ class ColumnProperty(StrategizedProperty):
def merge(self, session, source, dest, _recursive):
setattr(dest, self.key, getattr(source, self.key, None))
- def compare(self, value):
- return self.columns[0] == value
+ def get_col_value(self, column, value):
+ return value
+
+ class ColumnComparator(PropComparator):
+ def clause_element(self):
+ return self.prop.columns[0]
+
+ def operate(self, op, other):
+ return op(self.prop.columns[0], other)
+
+ def reverse_operate(self, op, other):
+ col = self.prop.columns[0]
+ return op(col._bind_param(other), col)
+
ColumnProperty.logger = logging.class_logger(ColumnProperty)
mapper.ColumnProperty = ColumnProperty
+class CompositeProperty(ColumnProperty):
+ """subclasses ColumnProperty to provide composite type support."""
+
+ def __init__(self, class_, *columns, **kwargs):
+ super(CompositeProperty, self).__init__(*columns, **kwargs)
+ self.composite_class = class_
+ self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator(self))
+
+ def copy(self):
+ return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns)
+
+ def getattr(self, object, column):
+ obj = getattr(object, self.key)
+ return self.get_col_value(column, obj)
+
+ def setattr(self, object, value, column):
+ obj = getattr(object, self.key, None)
+ if obj is None:
+ obj = self.composite_class(*[None for c in self.columns])
+ for a, b in zip(self.columns, value.__colset__()):
+ if a is column:
+ setattr(obj, b, value)
+
+ def get_col_value(self, column, value):
+ for a, b in zip(self.columns, value.__colset__()):
+ if a is column:
+ return b
+
+ class Comparator(PropComparator):
+ def __eq__(self, other):
+ if other is None:
+ return sql.and_(*[a==None for a in self.prop.columns])
+ else:
+ return sql.and_(*[a==b for a, b in zip(self.prop.columns, other.__colset__())])
+
+ def __ne__(self, other):
+ return sql.or_(*[a!=b for a, b in zip(self.prop.columns, other.__colset__())])
+
class PropertyLoader(StrategizedProperty):
"""Describes an object property that holds a single item or list
of items that correspond to a related database table.
"""
- def __init__(self, argument, secondary, primaryjoin, secondaryjoin, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True):
+ def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True, join_depth=None):
self.uselist = uselist
self.argument = argument
self.entity_name = entity_name
@@ -80,7 +143,9 @@ class PropertyLoader(StrategizedProperty):
self.remote_side = util.to_set(remote_side)
self.enable_typechecks = enable_typechecks
self._parent_join_cache = {}
-
+ self.comparator = PropertyLoader.Comparator(self)
+ self.join_depth = join_depth
+
if cascade is not None:
self.cascade = mapperutil.CascadeOptions(cascade)
else:
@@ -91,7 +156,7 @@ class PropertyLoader(StrategizedProperty):
self.association = association
self.order_by = order_by
- self.attributeext = attributeext
+ self.attributeext=attributeext
if isinstance(backref, str):
# propigate explicitly sent primary/secondary join conditions to the BackRef object if
# just a string was sent
@@ -104,9 +169,96 @@ class PropertyLoader(StrategizedProperty):
self.backref = backref
self.is_backref = is_backref
- def compare(self, value):
- return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))])
-
+ class Comparator(PropComparator):
+ def __eq__(self, other):
+ if other is None:
+ return ~sql.exists([1], self.prop.primaryjoin)
+ elif self.prop.uselist:
+ if not hasattr(other, '__iter__'):
+ raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object.")
+ else:
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ clauses = []
+ for o in other:
+ clauses.append(
+ sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(o))]))
+ )
+ return sql.and_(*clauses)
+ else:
+ return self.prop._optimized_compare(other)
+
+ def any(self, criterion=None, **kwargs):
+ if not self.prop.uselist:
+ raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ for k in kwargs:
+ crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
+ if criterion is None:
+ criterion = crit
+ else:
+ criterion = criterion & crit
+ return sql.exists([1], j & criterion)
+
+ def has(self, criterion=None, **kwargs):
+ if self.prop.uselist:
+ raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().")
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ for k in kwargs:
+ crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
+ if criterion is None:
+ criterion = crit
+ else:
+ criterion = criterion & crit
+ return sql.exists([1], j & criterion)
+
+ def contains(self, other):
+ if not self.prop.uselist:
+ raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==")
+ clause = self.prop._optimized_compare(other)
+
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+
+ clause.negation_clause = ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+ return clause
+
+ def __ne__(self, other):
+ if self.prop.uselist and not hasattr(other, '__iter__'):
+ raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
+
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+
+ def compare(self, op, value, value_is_parent=False):
+ if op == operator.eq:
+ if value is None:
+ return ~sql.exists([1], self.prop.mapper.mapped_table, self.prop.primaryjoin)
+ else:
+ return self._optimized_compare(value, value_is_parent=value_is_parent)
+ else:
+ return op(self.comparator, value)
+
+ def _optimized_compare(self, value, value_is_parent=False):
+ # optimized operation for ==, uses a lazy clause.
+ (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent)
+ bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+
+ class Visitor(sql.ClauseVisitor):
+ def visit_bindparam(s, bindparam):
+ mapper = value_is_parent and self.parent or self.mapper
+ bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key])
+ Visitor().traverse(criterion)
+ return criterion
+
private = property(lambda s:s.cascade.delete_orphan)
def create_strategy(self):
@@ -127,12 +279,13 @@ class PropertyLoader(StrategizedProperty):
if childlist is None:
return
if self.uselist:
- # sets a blank list according to the correct list class
- dest_list = getattr(self.parent.class_, self.key).initialize(dest)
+ # sets a blank collection according to the correct list class
+ dest_list = sessionlib.attribute_manager.init_collection(dest, self.key)
for current in list(childlist):
obj = session.merge(current, entity_name=self.mapper.entity_name, _recursive=_recursive)
if obj is not None:
- dest_list.append(obj)
+ #dest_list.append_without_event(obj)
+ dest_list.append_with_event(obj)
else:
current = list(childlist)[0]
if current is not None:
@@ -267,7 +420,7 @@ class PropertyLoader(StrategizedProperty):
if len(self.foreign_keys):
self._opposite_side = util.Set()
def visit_binary(binary):
- if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
if binary.left in self.foreign_keys:
self._opposite_side.add(binary.right)
@@ -280,7 +433,7 @@ class PropertyLoader(StrategizedProperty):
self.foreign_keys = util.Set()
self._opposite_side = util.Set()
def visit_binary(binary):
- if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
# this check is for when the user put the "view_only" flag on and has tables that have nothing
@@ -362,16 +515,13 @@ class PropertyLoader(StrategizedProperty):
"argument." % (str(self)))
def _determine_remote_side(self):
- if len(self.remote_side):
- return
- self.remote_side = util.Set()
+ if not len(self.remote_side):
+ if self.direction is sync.MANYTOONE:
+ self.remote_side = util.Set(self._opposite_side)
+ elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY:
+ self.remote_side = util.Set(self.foreign_keys)
- if self.direction is sync.MANYTOONE:
- for c in self._opposite_side:
- self.remote_side.add(c)
- elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY:
- for c in self.foreign_keys:
- self.remote_side.add(c)
+ self.local_side = util.Set(self._opposite_side).union(util.Set(self.foreign_keys)).difference(self.remote_side)
def _create_polymorphic_joins(self):
# get ready to create "polymorphic" primary/secondary join clauses.
@@ -383,27 +533,26 @@ class PropertyLoader(StrategizedProperty):
# as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out,
# first create maps of all the "equivalent" columns, since polymorphic selectables will often munge
# several "equivalent" columns (such as parent/child fk cols) into just one column.
- target_equivalents = self.mapper._get_inherited_column_equivalents()
+ target_equivalents = self.mapper._get_equivalent_columns()
+
# if the target mapper loads polymorphically, adapt the clauses to the target's selectable
if self.loads_polymorphic:
if self.secondaryjoin:
- self.polymorphic_secondaryjoin = self.secondaryjoin.copy_container()
- sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.polymorphic_secondaryjoin)
- self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
+ self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True)
+ self.polymorphic_primaryjoin = self.primaryjoin
else:
- self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
if self.direction is sync.ONETOMANY:
- sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin)
+ self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
elif self.direction is sync.MANYTOONE:
- sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin)
+ self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
self.polymorphic_secondaryjoin = None
# load "polymorphic" versions of the columns present in "remote_side" - this is
# important for lazy-clause generation which goes off the polymorphic target selectable
for c in list(self.remote_side):
- if self.secondary and c in self.secondary.columns:
+ if self.secondary and self.secondary.columns.contains_column(c):
continue
- for equiv in [c] + (c in target_equivalents and target_equivalents[c] or []):
+ for equiv in [c] + (c in target_equivalents and list(target_equivalents[c]) or []):
corr = self.mapper.select_table.corresponding_column(equiv, raiseerr=False)
if corr:
self.remote_side.add(corr)
@@ -411,8 +560,8 @@ class PropertyLoader(StrategizedProperty):
else:
raise exceptions.AssertionError(str(self) + ": Could not find corresponding column for " + str(c) + " in selectable " + str(self.mapper.select_table))
else:
- self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
- self.polymorphic_secondaryjoin = self.secondaryjoin and self.secondaryjoin.copy_container() or None
+ self.polymorphic_primaryjoin = self.primaryjoin
+ self.polymorphic_secondaryjoin = self.secondaryjoin
def _post_init(self):
if logging.is_info_enabled(self.logger):
@@ -450,22 +599,20 @@ class PropertyLoader(StrategizedProperty):
def _is_self_referential(self):
return self.parent.mapped_table is self.target or self.parent.select_table is self.target
- def get_join(self, parent, primary=True, secondary=True):
+ def get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True):
try:
- return self._parent_join_cache[(parent, primary, secondary)]
+ return self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)]
except KeyError:
- parent_equivalents = parent._get_inherited_column_equivalents()
- primaryjoin = self.polymorphic_primaryjoin.copy_container()
- if self.secondaryjoin is not None:
- secondaryjoin = self.polymorphic_secondaryjoin.copy_container()
- else:
- secondaryjoin = None
- if self.direction is sync.ONETOMANY:
- sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
- elif self.direction is sync.MANYTOONE:
- sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
- elif self.secondaryjoin:
- sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
+ parent_equivalents = parent._get_equivalent_columns()
+ secondaryjoin = self.polymorphic_secondaryjoin
+ if polymorphic_parent:
+ # adapt the "parent" side of our join condition to the "polymorphic" select of the parent
+ if self.direction is sync.ONETOMANY:
+ primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
+ elif self.direction is sync.MANYTOONE:
+ primaryjoin = sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
+ elif self.secondaryjoin:
+ primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
if secondaryjoin is not None:
if secondary and not primary:
@@ -476,7 +623,7 @@ class PropertyLoader(StrategizedProperty):
j = primaryjoin
else:
j = primaryjoin
- self._parent_join_cache[(parent, primary, secondary)] = j
+ self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] = j
return j
def register_dependencies(self, uowcommit):
@@ -501,7 +648,7 @@ class BackRef(object):
# try to set a LazyLoader on our mapper referencing the parent mapper
mapper = prop.mapper.primary_mapper()
- if not mapper.props.has_key(self.key):
+ if not mapper.get_property(self.key, raiseerr=False) is not None:
pj = self.kwargs.pop('primaryjoin', None)
sj = self.kwargs.pop('secondaryjoin', None)
# the backref property is set on the primary mapper
@@ -512,26 +659,26 @@ class BackRef(object):
backref=prop.key, is_backref=True,
**self.kwargs)
mapper._compile_property(self.key, relation);
- elif not isinstance(mapper.props[self.key], PropertyLoader):
+ elif not isinstance(mapper.get_property(self.key), PropertyLoader):
raise exceptions.ArgumentError(
"Can't create backref '%s' on mapper '%s'; an incompatible "
"property of that name already exists" % (self.key, str(mapper)))
else:
# else set one of us as the "backreference"
parent = prop.parent.primary_mapper()
- if parent.class_ is not mapper.props[self.key]._get_target_class():
+ if parent.class_ is not mapper.get_property(self.key)._get_target_class():
raise exceptions.ArgumentError(
"Backrefs do not match: backref '%s' expects to connect to %s, "
"but found a backref already connected to %s" %
- (self.key, str(parent.class_), str(mapper.props[self.key].mapper.class_)))
- if not mapper.props[self.key].is_backref:
+ (self.key, str(parent.class_), str(mapper.get_property(self.key).mapper.class_)))
+ if not mapper.get_property(self.key).is_backref:
prop.is_backref=True
if not prop.viewonly:
prop._dependency_processor.is_backref=True
# reverse_property used by dependencies.ManyToManyDP to check
# association table operations
- prop.reverse_property = mapper.props[self.key]
- mapper.props[self.key].reverse_property = prop
+ prop.reverse_property = mapper.get_property(self.key)
+ mapper.get_property(self.key).reverse_property = prop
def get_extension(self):
"""Return an attribute extension to use with this backreference."""