diff options
Diffstat (limited to 'lib/sqlalchemy/orm/properties.py')
-rw-r--r-- | lib/sqlalchemy/orm/properties.py | 269 |
1 files changed, 208 insertions, 61 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index a00a35ab6..6ce9fd706 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -15,8 +15,11 @@ from sqlalchemy import sql, schema, util, exceptions, sql_util, logging from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil -import sets, random -from sqlalchemy.orm.interfaces import * +import operator +from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator +from sqlalchemy.exceptions import ArgumentError + +__all__ = ['ColumnProperty', 'CompositeProperty', 'PropertyLoader', 'BackRef'] class ColumnProperty(StrategizedProperty): """Describes an object attribute that corresponds to a table column.""" @@ -31,17 +34,27 @@ class ColumnProperty(StrategizedProperty): self.columns = list(columns) self.group = kwargs.pop('group', None) self.deferred = kwargs.pop('deferred', False) - + self.comparator = ColumnProperty.ColumnComparator(self) + # sanity check + for col in columns: + if not hasattr(col, 'name'): + if hasattr(col, 'label'): + raise ArgumentError('ColumnProperties must be named for the mapper to work with them. Try .label() to fix this') + raise ArgumentError('%r is not a valid candidate for ColumnProperty' % col) + def create_strategy(self): if self.deferred: return strategies.DeferredColumnLoader(self) else: return strategies.ColumnLoader(self) - - def getattr(self, object): + + def copy(self): + return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns) + + def getattr(self, object, column): return getattr(object, self.key) - def setattr(self, object, value): + def setattr(self, object, value, column): setattr(object, self.key, value) def get_history(self, obj, passive=False): @@ -50,19 +63,69 @@ class ColumnProperty(StrategizedProperty): def merge(self, session, source, dest, _recursive): setattr(dest, self.key, getattr(source, self.key, None)) - def compare(self, value): - return self.columns[0] == value + def get_col_value(self, column, value): + return value + + class ColumnComparator(PropComparator): + def clause_element(self): + return self.prop.columns[0] + + def operate(self, op, other): + return op(self.prop.columns[0], other) + + def reverse_operate(self, op, other): + col = self.prop.columns[0] + return op(col._bind_param(other), col) + ColumnProperty.logger = logging.class_logger(ColumnProperty) mapper.ColumnProperty = ColumnProperty +class CompositeProperty(ColumnProperty): + """subclasses ColumnProperty to provide composite type support.""" + + def __init__(self, class_, *columns, **kwargs): + super(CompositeProperty, self).__init__(*columns, **kwargs) + self.composite_class = class_ + self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator(self)) + + def copy(self): + return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns) + + def getattr(self, object, column): + obj = getattr(object, self.key) + return self.get_col_value(column, obj) + + def setattr(self, object, value, column): + obj = getattr(object, self.key, None) + if obj is None: + obj = self.composite_class(*[None for c in self.columns]) + for a, b in zip(self.columns, value.__colset__()): + if a is column: + setattr(obj, b, value) + + def get_col_value(self, column, value): + for a, b in zip(self.columns, value.__colset__()): + if a is column: + return b + + class Comparator(PropComparator): + def __eq__(self, other): + if other is None: + return sql.and_(*[a==None for a in self.prop.columns]) + else: + return sql.and_(*[a==b for a, b in zip(self.prop.columns, other.__colset__())]) + + def __ne__(self, other): + return sql.or_(*[a!=b for a, b in zip(self.prop.columns, other.__colset__())]) + class PropertyLoader(StrategizedProperty): """Describes an object property that holds a single item or list of items that correspond to a related database table. """ - def __init__(self, argument, secondary, primaryjoin, secondaryjoin, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True): + def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True, join_depth=None): self.uselist = uselist self.argument = argument self.entity_name = entity_name @@ -80,7 +143,9 @@ class PropertyLoader(StrategizedProperty): self.remote_side = util.to_set(remote_side) self.enable_typechecks = enable_typechecks self._parent_join_cache = {} - + self.comparator = PropertyLoader.Comparator(self) + self.join_depth = join_depth + if cascade is not None: self.cascade = mapperutil.CascadeOptions(cascade) else: @@ -91,7 +156,7 @@ class PropertyLoader(StrategizedProperty): self.association = association self.order_by = order_by - self.attributeext = attributeext + self.attributeext=attributeext if isinstance(backref, str): # propigate explicitly sent primary/secondary join conditions to the BackRef object if # just a string was sent @@ -104,9 +169,96 @@ class PropertyLoader(StrategizedProperty): self.backref = backref self.is_backref = is_backref - def compare(self, value): - return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))]) - + class Comparator(PropComparator): + def __eq__(self, other): + if other is None: + return ~sql.exists([1], self.prop.primaryjoin) + elif self.prop.uselist: + if not hasattr(other, '__iter__'): + raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object.") + else: + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + clauses = [] + for o in other: + clauses.append( + sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(o))])) + ) + return sql.and_(*clauses) + else: + return self.prop._optimized_compare(other) + + def any(self, criterion=None, **kwargs): + if not self.prop.uselist: + raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().") + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + for k in kwargs: + crit = (getattr(self.prop.mapper.class_, k) == kwargs[k]) + if criterion is None: + criterion = crit + else: + criterion = criterion & crit + return sql.exists([1], j & criterion) + + def has(self, criterion=None, **kwargs): + if self.prop.uselist: + raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().") + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + for k in kwargs: + crit = (getattr(self.prop.mapper.class_, k) == kwargs[k]) + if criterion is None: + criterion = crit + else: + criterion = criterion & crit + return sql.exists([1], j & criterion) + + def contains(self, other): + if not self.prop.uselist: + raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==") + clause = self.prop._optimized_compare(other) + + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + + clause.negation_clause = ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])) + return clause + + def __ne__(self, other): + if self.prop.uselist and not hasattr(other, '__iter__'): + raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object") + + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])) + + def compare(self, op, value, value_is_parent=False): + if op == operator.eq: + if value is None: + return ~sql.exists([1], self.prop.mapper.mapped_table, self.prop.primaryjoin) + else: + return self._optimized_compare(value, value_is_parent=value_is_parent) + else: + return op(self.comparator, value) + + def _optimized_compare(self, value, value_is_parent=False): + # optimized operation for ==, uses a lazy clause. + (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent) + bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) + + class Visitor(sql.ClauseVisitor): + def visit_bindparam(s, bindparam): + mapper = value_is_parent and self.parent or self.mapper + bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key]) + Visitor().traverse(criterion) + return criterion + private = property(lambda s:s.cascade.delete_orphan) def create_strategy(self): @@ -127,12 +279,13 @@ class PropertyLoader(StrategizedProperty): if childlist is None: return if self.uselist: - # sets a blank list according to the correct list class - dest_list = getattr(self.parent.class_, self.key).initialize(dest) + # sets a blank collection according to the correct list class + dest_list = sessionlib.attribute_manager.init_collection(dest, self.key) for current in list(childlist): obj = session.merge(current, entity_name=self.mapper.entity_name, _recursive=_recursive) if obj is not None: - dest_list.append(obj) + #dest_list.append_without_event(obj) + dest_list.append_with_event(obj) else: current = list(childlist)[0] if current is not None: @@ -267,7 +420,7 @@ class PropertyLoader(StrategizedProperty): if len(self.foreign_keys): self._opposite_side = util.Set() def visit_binary(binary): - if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): + if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return if binary.left in self.foreign_keys: self._opposite_side.add(binary.right) @@ -280,7 +433,7 @@ class PropertyLoader(StrategizedProperty): self.foreign_keys = util.Set() self._opposite_side = util.Set() def visit_binary(binary): - if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): + if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return # this check is for when the user put the "view_only" flag on and has tables that have nothing @@ -362,16 +515,13 @@ class PropertyLoader(StrategizedProperty): "argument." % (str(self))) def _determine_remote_side(self): - if len(self.remote_side): - return - self.remote_side = util.Set() + if not len(self.remote_side): + if self.direction is sync.MANYTOONE: + self.remote_side = util.Set(self._opposite_side) + elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY: + self.remote_side = util.Set(self.foreign_keys) - if self.direction is sync.MANYTOONE: - for c in self._opposite_side: - self.remote_side.add(c) - elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY: - for c in self.foreign_keys: - self.remote_side.add(c) + self.local_side = util.Set(self._opposite_side).union(util.Set(self.foreign_keys)).difference(self.remote_side) def _create_polymorphic_joins(self): # get ready to create "polymorphic" primary/secondary join clauses. @@ -383,27 +533,26 @@ class PropertyLoader(StrategizedProperty): # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out, # first create maps of all the "equivalent" columns, since polymorphic selectables will often munge # several "equivalent" columns (such as parent/child fk cols) into just one column. - target_equivalents = self.mapper._get_inherited_column_equivalents() + target_equivalents = self.mapper._get_equivalent_columns() + # if the target mapper loads polymorphically, adapt the clauses to the target's selectable if self.loads_polymorphic: if self.secondaryjoin: - self.polymorphic_secondaryjoin = self.secondaryjoin.copy_container() - sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.polymorphic_secondaryjoin) - self.polymorphic_primaryjoin = self.primaryjoin.copy_container() + self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True) + self.polymorphic_primaryjoin = self.primaryjoin else: - self.polymorphic_primaryjoin = self.primaryjoin.copy_container() if self.direction is sync.ONETOMANY: - sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin) + self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True) elif self.direction is sync.MANYTOONE: - sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin) + self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True) self.polymorphic_secondaryjoin = None # load "polymorphic" versions of the columns present in "remote_side" - this is # important for lazy-clause generation which goes off the polymorphic target selectable for c in list(self.remote_side): - if self.secondary and c in self.secondary.columns: + if self.secondary and self.secondary.columns.contains_column(c): continue - for equiv in [c] + (c in target_equivalents and target_equivalents[c] or []): + for equiv in [c] + (c in target_equivalents and list(target_equivalents[c]) or []): corr = self.mapper.select_table.corresponding_column(equiv, raiseerr=False) if corr: self.remote_side.add(corr) @@ -411,8 +560,8 @@ class PropertyLoader(StrategizedProperty): else: raise exceptions.AssertionError(str(self) + ": Could not find corresponding column for " + str(c) + " in selectable " + str(self.mapper.select_table)) else: - self.polymorphic_primaryjoin = self.primaryjoin.copy_container() - self.polymorphic_secondaryjoin = self.secondaryjoin and self.secondaryjoin.copy_container() or None + self.polymorphic_primaryjoin = self.primaryjoin + self.polymorphic_secondaryjoin = self.secondaryjoin def _post_init(self): if logging.is_info_enabled(self.logger): @@ -450,22 +599,20 @@ class PropertyLoader(StrategizedProperty): def _is_self_referential(self): return self.parent.mapped_table is self.target or self.parent.select_table is self.target - def get_join(self, parent, primary=True, secondary=True): + def get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True): try: - return self._parent_join_cache[(parent, primary, secondary)] + return self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] except KeyError: - parent_equivalents = parent._get_inherited_column_equivalents() - primaryjoin = self.polymorphic_primaryjoin.copy_container() - if self.secondaryjoin is not None: - secondaryjoin = self.polymorphic_secondaryjoin.copy_container() - else: - secondaryjoin = None - if self.direction is sync.ONETOMANY: - sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) - elif self.direction is sync.MANYTOONE: - sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) - elif self.secondaryjoin: - sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) + parent_equivalents = parent._get_equivalent_columns() + secondaryjoin = self.polymorphic_secondaryjoin + if polymorphic_parent: + # adapt the "parent" side of our join condition to the "polymorphic" select of the parent + if self.direction is sync.ONETOMANY: + primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) + elif self.direction is sync.MANYTOONE: + primaryjoin = sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) + elif self.secondaryjoin: + primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) if secondaryjoin is not None: if secondary and not primary: @@ -476,7 +623,7 @@ class PropertyLoader(StrategizedProperty): j = primaryjoin else: j = primaryjoin - self._parent_join_cache[(parent, primary, secondary)] = j + self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] = j return j def register_dependencies(self, uowcommit): @@ -501,7 +648,7 @@ class BackRef(object): # try to set a LazyLoader on our mapper referencing the parent mapper mapper = prop.mapper.primary_mapper() - if not mapper.props.has_key(self.key): + if not mapper.get_property(self.key, raiseerr=False) is not None: pj = self.kwargs.pop('primaryjoin', None) sj = self.kwargs.pop('secondaryjoin', None) # the backref property is set on the primary mapper @@ -512,26 +659,26 @@ class BackRef(object): backref=prop.key, is_backref=True, **self.kwargs) mapper._compile_property(self.key, relation); - elif not isinstance(mapper.props[self.key], PropertyLoader): + elif not isinstance(mapper.get_property(self.key), PropertyLoader): raise exceptions.ArgumentError( "Can't create backref '%s' on mapper '%s'; an incompatible " "property of that name already exists" % (self.key, str(mapper))) else: # else set one of us as the "backreference" parent = prop.parent.primary_mapper() - if parent.class_ is not mapper.props[self.key]._get_target_class(): + if parent.class_ is not mapper.get_property(self.key)._get_target_class(): raise exceptions.ArgumentError( "Backrefs do not match: backref '%s' expects to connect to %s, " "but found a backref already connected to %s" % - (self.key, str(parent.class_), str(mapper.props[self.key].mapper.class_))) - if not mapper.props[self.key].is_backref: + (self.key, str(parent.class_), str(mapper.get_property(self.key).mapper.class_))) + if not mapper.get_property(self.key).is_backref: prop.is_backref=True if not prop.viewonly: prop._dependency_processor.is_backref=True # reverse_property used by dependencies.ManyToManyDP to check # association table operations - prop.reverse_property = mapper.props[self.key] - mapper.props[self.key].reverse_property = prop + prop.reverse_property = mapper.get_property(self.key) + mapper.get_property(self.key).reverse_property = prop def get_extension(self): """Return an attribute extension to use with this backreference.""" |