diff options
Diffstat (limited to 'lib/sqlalchemy/orm/util.py')
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 462 |
1 files changed, 283 insertions, 179 deletions
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 43709a58c..a1b0cd5da 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -12,27 +12,51 @@ from .interfaces import PropComparator, MapperProperty from . import attributes import re -from .base import instance_str, state_str, state_class_str, attribute_str, \ - state_attribute_str, object_mapper, object_state, _none_set, _never_set +from .base import ( + instance_str, + state_str, + state_class_str, + attribute_str, + state_attribute_str, + object_mapper, + object_state, + _none_set, + _never_set, +) from .base import class_mapper, _class_to_mapper from .base import InspectionAttr from .path_registry import PathRegistry -all_cascades = frozenset(("delete", "delete-orphan", "all", "merge", - "expunge", "save-update", "refresh-expire", - "none")) +all_cascades = frozenset( + ( + "delete", + "delete-orphan", + "all", + "merge", + "expunge", + "save-update", + "refresh-expire", + "none", + ) +) class CascadeOptions(frozenset): """Keeps track of the options sent to relationship().cascade""" - _add_w_all_cascades = all_cascades.difference([ - 'all', 'none', 'delete-orphan']) + _add_w_all_cascades = all_cascades.difference( + ["all", "none", "delete-orphan"] + ) _allowed_cascades = all_cascades __slots__ = ( - 'save_update', 'delete', 'refresh_expire', 'merge', - 'expunge', 'delete_orphan') + "save_update", + "delete", + "refresh_expire", + "merge", + "expunge", + "delete_orphan", + ) def __new__(cls, value_list): if isinstance(value_list, util.string_types) or value_list is None: @@ -40,60 +64,62 @@ class CascadeOptions(frozenset): values = set(value_list) if values.difference(cls._allowed_cascades): raise sa_exc.ArgumentError( - "Invalid cascade option(s): %s" % - ", ".join([repr(x) for x in - sorted(values.difference(cls._allowed_cascades))])) + "Invalid cascade option(s): %s" + % ", ".join( + [ + repr(x) + for x in sorted( + values.difference(cls._allowed_cascades) + ) + ] + ) + ) if "all" in values: values.update(cls._add_w_all_cascades) if "none" in values: values.clear() - values.discard('all') + values.discard("all") self = frozenset.__new__(CascadeOptions, values) - self.save_update = 'save-update' in values - self.delete = 'delete' in values - self.refresh_expire = 'refresh-expire' in values - self.merge = 'merge' in values - self.expunge = 'expunge' in values + self.save_update = "save-update" in values + self.delete = "delete" in values + self.refresh_expire = "refresh-expire" in values + self.merge = "merge" in values + self.expunge = "expunge" in values self.delete_orphan = "delete-orphan" in values if self.delete_orphan and not self.delete: - util.warn("The 'delete-orphan' cascade " - "option requires 'delete'.") + util.warn( + "The 'delete-orphan' cascade " "option requires 'delete'." + ) return self def __repr__(self): - return "CascadeOptions(%r)" % ( - ",".join([x for x in sorted(self)]) - ) + return "CascadeOptions(%r)" % (",".join([x for x in sorted(self)])) @classmethod def from_string(cls, arg): - values = [ - c for c - in re.split(r'\s*,\s*', arg or "") - if c - ] + values = [c for c in re.split(r"\s*,\s*", arg or "") if c] return cls(values) -def _validator_events( - desc, key, validator, include_removes, include_backrefs): +def _validator_events(desc, key, validator, include_removes, include_backrefs): """Runs a validation method on an attribute value to be set or appended. """ if not include_backrefs: + def detect_is_backref(state, initiator): impl = state.manager[key].impl return initiator.impl is not impl if include_removes: + def append(state, value, initiator): - if ( - initiator.op is not attributes.OP_BULK_REPLACE and - (include_backrefs or not detect_is_backref(state, initiator)) + if initiator.op is not attributes.OP_BULK_REPLACE and ( + include_backrefs or not detect_is_backref(state, initiator) ): return validator(state.obj(), key, value, False) else: @@ -103,7 +129,8 @@ def _validator_events( if include_backrefs or not detect_is_backref(state, initiator): obj = state.obj() values[:] = [ - validator(obj, key, value, False) for value in values] + validator(obj, key, value, False) for value in values + ] def set_(state, value, oldvalue, initiator): if include_backrefs or not detect_is_backref(state, initiator): @@ -116,10 +143,10 @@ def _validator_events( validator(state.obj(), key, value, True) else: + def append(state, value, initiator): - if ( - initiator.op is not attributes.OP_BULK_REPLACE and - (include_backrefs or not detect_is_backref(state, initiator)) + if initiator.op is not attributes.OP_BULK_REPLACE and ( + include_backrefs or not detect_is_backref(state, initiator) ): return validator(state.obj(), key, value) else: @@ -128,8 +155,7 @@ def _validator_events( def bulk_set(state, values, initiator): if include_backrefs or not detect_is_backref(state, initiator): obj = state.obj() - values[:] = [ - validator(obj, key, value) for value in values] + values[:] = [validator(obj, key, value) for value in values] def set_(state, value, oldvalue, initiator): if include_backrefs or not detect_is_backref(state, initiator): @@ -137,15 +163,16 @@ def _validator_events( else: return value - event.listen(desc, 'append', append, raw=True, retval=True) - event.listen(desc, 'bulk_replace', bulk_set, raw=True) - event.listen(desc, 'set', set_, raw=True, retval=True) + event.listen(desc, "append", append, raw=True, retval=True) + event.listen(desc, "bulk_replace", bulk_set, raw=True) + event.listen(desc, "set", set_, raw=True, retval=True) if include_removes: event.listen(desc, "remove", remove, raw=True, retval=True) -def polymorphic_union(table_map, typecolname, - aliasname='p_union', cast_nulls=True): +def polymorphic_union( + table_map, typecolname, aliasname="p_union", cast_nulls=True +): """Create a ``UNION`` statement used by a polymorphic mapper. See :ref:`concrete_inheritance` for an example of how @@ -197,14 +224,22 @@ def polymorphic_union(table_map, typecolname, for type, table in table_map.items(): if typecolname is not None: result.append( - sql.select([col(name, table) for name in colnames] + - [sql.literal_column( - sql_util._quote_ddl_expr(type)). - label(typecolname)], - from_obj=[table])) + sql.select( + [col(name, table) for name in colnames] + + [ + sql.literal_column( + sql_util._quote_ddl_expr(type) + ).label(typecolname) + ], + from_obj=[table], + ) + ) else: - result.append(sql.select([col(name, table) for name in colnames], - from_obj=[table])) + result.append( + sql.select( + [col(name, table) for name in colnames], from_obj=[table] + ) + ) return sql.union_all(*result).alias(aliasname) @@ -284,25 +319,29 @@ first() class_, ident = args else: raise sa_exc.ArgumentError( - "expected up to three positional arguments, " - "got %s" % largs) + "expected up to three positional arguments, " "got %s" % largs + ) identity_token = kwargs.pop("identity_token", None) if kwargs: - raise sa_exc.ArgumentError("unknown keyword arguments: %s" - % ", ".join(kwargs)) + raise sa_exc.ArgumentError( + "unknown keyword arguments: %s" % ", ".join(kwargs) + ) mapper = class_mapper(class_) if row is None: return mapper.identity_key_from_primary_key( - util.to_list(ident), identity_token=identity_token) + util.to_list(ident), identity_token=identity_token + ) else: return mapper.identity_key_from_row( - row, identity_token=identity_token) + row, identity_token=identity_token + ) else: instance = kwargs.pop("instance") if kwargs: - raise sa_exc.ArgumentError("unknown keyword arguments: %s" - % ", ".join(kwargs.keys)) + raise sa_exc.ArgumentError( + "unknown keyword arguments: %s" % ", ".join(kwargs.keys) + ) mapper = object_mapper(instance) return mapper.identity_key_from_instance(instance) @@ -313,9 +352,15 @@ class ORMAdapter(sql_util.ColumnAdapter): """ - def __init__(self, entity, equivalents=None, adapt_required=False, - chain_to=None, allow_label_resolve=True, - anonymize_labels=False): + def __init__( + self, + entity, + equivalents=None, + adapt_required=False, + chain_to=None, + allow_label_resolve=True, + anonymize_labels=False, + ): info = inspection.inspect(entity) self.mapper = info.mapper @@ -327,15 +372,18 @@ class ORMAdapter(sql_util.ColumnAdapter): self.aliased_class = None sql_util.ColumnAdapter.__init__( - self, selectable, equivalents, chain_to, + self, + selectable, + equivalents, + chain_to, adapt_required=adapt_required, allow_label_resolve=allow_label_resolve, anonymize_labels=anonymize_labels, - include_fn=self._include_fn + include_fn=self._include_fn, ) def _include_fn(self, elem): - entity = elem._annotations.get('parentmapper', None) + entity = elem._annotations.get("parentmapper", None) return not entity or entity.isa(self.mapper) @@ -380,20 +428,25 @@ class AliasedClass(object): """ - def __init__(self, cls, alias=None, - name=None, - flat=False, - adapt_on_names=False, - # TODO: None for default here? - with_polymorphic_mappers=(), - with_polymorphic_discriminator=None, - base_alias=None, - use_mapper_path=False, - represents_outer_join=False): + def __init__( + self, + cls, + alias=None, + name=None, + flat=False, + adapt_on_names=False, + # TODO: None for default here? + with_polymorphic_mappers=(), + with_polymorphic_discriminator=None, + base_alias=None, + use_mapper_path=False, + represents_outer_join=False, + ): mapper = _class_to_mapper(cls) if alias is None: alias = mapper._with_polymorphic_selectable.alias( - name=name, flat=flat) + name=name, flat=flat + ) self._aliased_insp = AliasedInsp( self, @@ -409,14 +462,14 @@ class AliasedClass(object): base_alias, use_mapper_path, adapt_on_names, - represents_outer_join + represents_outer_join, ) - self.__name__ = 'AliasedClass_%s' % mapper.class_.__name__ + self.__name__ = "AliasedClass_%s" % mapper.class_.__name__ def __getattr__(self, key): try: - _aliased_insp = self.__dict__['_aliased_insp'] + _aliased_insp = self.__dict__["_aliased_insp"] except KeyError: raise AttributeError() else: @@ -434,13 +487,13 @@ class AliasedClass(object): ret = attr.adapt_to_entity(_aliased_insp) setattr(self, key, ret) return ret - elif hasattr(attr, 'func_code'): + elif hasattr(attr, "func_code"): is_method = getattr(_aliased_insp._target, key, None) if is_method and is_method.__self__ is not None: return util.types.MethodType(attr.__func__, self, self) else: return None - elif hasattr(attr, '__get__'): + elif hasattr(attr, "__get__"): ret = attr.__get__(None, self) if isinstance(ret, PropComparator): return ret.adapt_to_entity(_aliased_insp) @@ -450,8 +503,10 @@ class AliasedClass(object): return attr def __repr__(self): - return '<AliasedClass at 0x%x; %s>' % ( - id(self), self._aliased_insp._target.__name__) + return "<AliasedClass at 0x%x; %s>" % ( + id(self), + self._aliased_insp._target.__name__, + ) class AliasedInsp(InspectionAttr): @@ -490,10 +545,19 @@ class AliasedInsp(InspectionAttr): """ - def __init__(self, entity, mapper, selectable, name, - with_polymorphic_mappers, polymorphic_on, - _base_alias, _use_mapper_path, adapt_on_names, - represents_outer_join): + def __init__( + self, + entity, + mapper, + selectable, + name, + with_polymorphic_mappers, + polymorphic_on, + _base_alias, + _use_mapper_path, + adapt_on_names, + represents_outer_join, + ): self.entity = entity self.mapper = mapper self.selectable = selectable @@ -505,18 +569,28 @@ class AliasedInsp(InspectionAttr): self.represents_outer_join = represents_outer_join self._adapter = sql_util.ColumnAdapter( - selectable, equivalents=mapper._equivalent_columns, - adapt_on_names=adapt_on_names, anonymize_labels=True) + selectable, + equivalents=mapper._equivalent_columns, + adapt_on_names=adapt_on_names, + anonymize_labels=True, + ) self._adapt_on_names = adapt_on_names self._target = mapper.class_ for poly in self.with_polymorphic_mappers: if poly is not mapper: - setattr(self.entity, poly.class_.__name__, - AliasedClass(poly.class_, selectable, base_alias=self, - adapt_on_names=adapt_on_names, - use_mapper_path=_use_mapper_path)) + setattr( + self.entity, + poly.class_.__name__, + AliasedClass( + poly.class_, + selectable, + base_alias=self, + adapt_on_names=adapt_on_names, + use_mapper_path=_use_mapper_path, + ), + ) is_aliased_class = True "always returns True" @@ -536,39 +610,35 @@ class AliasedInsp(InspectionAttr): def __getstate__(self): return { - 'entity': self.entity, - 'mapper': self.mapper, - 'alias': self.selectable, - 'name': self.name, - 'adapt_on_names': self._adapt_on_names, - 'with_polymorphic_mappers': - self.with_polymorphic_mappers, - 'with_polymorphic_discriminator': - self.polymorphic_on, - 'base_alias': self._base_alias, - 'use_mapper_path': self._use_mapper_path, - 'represents_outer_join': self.represents_outer_join + "entity": self.entity, + "mapper": self.mapper, + "alias": self.selectable, + "name": self.name, + "adapt_on_names": self._adapt_on_names, + "with_polymorphic_mappers": self.with_polymorphic_mappers, + "with_polymorphic_discriminator": self.polymorphic_on, + "base_alias": self._base_alias, + "use_mapper_path": self._use_mapper_path, + "represents_outer_join": self.represents_outer_join, } def __setstate__(self, state): self.__init__( - state['entity'], - state['mapper'], - state['alias'], - state['name'], - state['with_polymorphic_mappers'], - state['with_polymorphic_discriminator'], - state['base_alias'], - state['use_mapper_path'], - state['adapt_on_names'], - state['represents_outer_join'] + state["entity"], + state["mapper"], + state["alias"], + state["name"], + state["with_polymorphic_mappers"], + state["with_polymorphic_discriminator"], + state["base_alias"], + state["use_mapper_path"], + state["adapt_on_names"], + state["represents_outer_join"], ) def _adapt_element(self, elem): - return self._adapter.traverse(elem).\ - _annotate({ - 'parententity': self, - 'parentmapper': self.mapper} + return self._adapter.traverse(elem)._annotate( + {"parententity": self, "parentmapper": self.mapper} ) def _entity_for_mapper(self, mapper): @@ -578,12 +648,12 @@ class AliasedInsp(InspectionAttr): return self else: return getattr( - self.entity, mapper.class_.__name__)._aliased_insp + self.entity, mapper.class_.__name__ + )._aliased_insp elif mapper.isa(self.mapper): return self else: - assert False, "mapper %s doesn't correspond to %s" % ( - mapper, self) + assert False, "mapper %s doesn't correspond to %s" % (mapper, self) @util.memoized_property def _memoized_values(self): @@ -599,11 +669,15 @@ class AliasedInsp(InspectionAttr): def __repr__(self): if self.with_polymorphic_mappers: with_poly = "(%s)" % ", ".join( - mp.class_.__name__ for mp in self.with_polymorphic_mappers) + mp.class_.__name__ for mp in self.with_polymorphic_mappers + ) else: with_poly = "" - return '<AliasedInsp at 0x%x; %s%s>' % ( - id(self), self.class_.__name__, with_poly) + return "<AliasedInsp at 0x%x; %s%s>" % ( + id(self), + self.class_.__name__, + with_poly, + ) inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) @@ -700,15 +774,26 @@ def aliased(element, alias=None, name=None, flat=False, adapt_on_names=False): ) return element.alias(name, flat=flat) else: - return AliasedClass(element, alias=alias, flat=flat, - name=name, adapt_on_names=adapt_on_names) + return AliasedClass( + element, + alias=alias, + flat=flat, + name=name, + adapt_on_names=adapt_on_names, + ) -def with_polymorphic(base, classes, selectable=False, - flat=False, - polymorphic_on=None, aliased=False, - innerjoin=False, _use_mapper_path=False, - _existing_alias=None): +def with_polymorphic( + base, + classes, + selectable=False, + flat=False, + polymorphic_on=None, + aliased=False, + innerjoin=False, + _use_mapper_path=False, + _existing_alias=None, +): """Produce an :class:`.AliasedClass` construct which specifies columns for descendant mappers of the given base. @@ -777,24 +862,26 @@ def with_polymorphic(base, classes, selectable=False, if _existing_alias: assert _existing_alias.mapper is primary_mapper classes = util.to_set(classes) - new_classes = set([ - mp.class_ for mp in - _existing_alias.with_polymorphic_mappers]) + new_classes = set( + [mp.class_ for mp in _existing_alias.with_polymorphic_mappers] + ) if classes == new_classes: return _existing_alias else: classes = classes.union(new_classes) - mappers, selectable = primary_mapper.\ - _with_polymorphic_args(classes, selectable, - innerjoin=innerjoin) + mappers, selectable = primary_mapper._with_polymorphic_args( + classes, selectable, innerjoin=innerjoin + ) if aliased or flat: selectable = selectable.alias(flat=flat) - return AliasedClass(base, - selectable, - with_polymorphic_mappers=mappers, - with_polymorphic_discriminator=polymorphic_on, - use_mapper_path=_use_mapper_path, - represents_outer_join=not innerjoin) + return AliasedClass( + base, + selectable, + with_polymorphic_mappers=mappers, + with_polymorphic_discriminator=polymorphic_on, + use_mapper_path=_use_mapper_path, + represents_outer_join=not innerjoin, + ) def _orm_annotate(element, exclude=None): @@ -804,7 +891,7 @@ def _orm_annotate(element, exclude=None): Elements within the exclude collection will be cloned but not annotated. """ - return sql_util._deep_annotate(element, {'_orm_adapt': True}, exclude) + return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude) def _orm_deannotate(element): @@ -816,9 +903,9 @@ def _orm_deannotate(element): """ - return sql_util._deep_deannotate(element, - values=("_orm_adapt", "parententity") - ) + return sql_util._deep_deannotate( + element, values=("_orm_adapt", "parententity") + ) def _orm_full_deannotate(element): @@ -831,12 +918,18 @@ class _ORMJoin(expression.Join): __visit_name__ = expression.Join.__visit_name__ def __init__( - self, - left, right, onclause=None, isouter=False, - full=False, _left_memo=None, _right_memo=None): + self, + left, + right, + onclause=None, + isouter=False, + full=False, + _left_memo=None, + _right_memo=None, + ): left_info = inspection.inspect(left) - left_orm_info = getattr(left, '_joined_from_info', left_info) + left_orm_info = getattr(left, "_joined_from_info", left_info) right_info = inspection.inspect(right) adapt_to = right_info.selectable @@ -859,19 +952,18 @@ class _ORMJoin(expression.Join): prop = None if prop: - if sql_util.clause_is_present( - on_selectable, left_info.selectable): + if sql_util.clause_is_present(on_selectable, left_info.selectable): adapt_from = on_selectable else: adapt_from = left_info.selectable - pj, sj, source, dest, \ - secondary, target_adapter = prop._create_joins( - source_selectable=adapt_from, - dest_selectable=adapt_to, - source_polymorphic=True, - dest_polymorphic=True, - of_type=right_info.mapper) + pj, sj, source, dest, secondary, target_adapter = prop._create_joins( + source_selectable=adapt_from, + dest_selectable=adapt_to, + source_polymorphic=True, + dest_polymorphic=True, + of_type=right_info.mapper, + ) if sj is not None: if isouter: @@ -887,8 +979,11 @@ class _ORMJoin(expression.Join): expression.Join.__init__(self, left, right, onclause, isouter, full) - if not prop and getattr(right_info, 'mapper', None) \ - and right_info.mapper.single: + if ( + not prop + and getattr(right_info, "mapper", None) + and right_info.mapper.single + ): # if single inheritance target and we are using a manual # or implicit ON clause, augment it the same way we'd augment the # WHERE. @@ -911,33 +1006,39 @@ class _ORMJoin(expression.Join): assert self.right is leftmost left = _ORMJoin( - self.left, other.left, - self.onclause, isouter=self.isouter, + self.left, + other.left, + self.onclause, + isouter=self.isouter, _left_memo=self._left_memo, - _right_memo=other._left_memo + _right_memo=other._left_memo, ) return _ORMJoin( left, other.right, - other.onclause, isouter=other.isouter, - _right_memo=other._right_memo + other.onclause, + isouter=other.isouter, + _right_memo=other._right_memo, ) def join( - self, right, onclause=None, - isouter=False, full=False, join_to_left=None): + self, + right, + onclause=None, + isouter=False, + full=False, + join_to_left=None, + ): return _ORMJoin(self, right, onclause, full, isouter) - def outerjoin( - self, right, onclause=None, - full=False, join_to_left=None): + def outerjoin(self, right, onclause=None, full=False, join_to_left=None): return _ORMJoin(self, right, onclause, True, full=full) def join( - left, right, onclause=None, isouter=False, - full=False, join_to_left=None): + left, right, onclause=None, isouter=False, full=False, join_to_left=None +): r"""Produce an inner join between left and right clauses. :func:`.orm.join` is an extension to the core join interface @@ -1085,8 +1186,9 @@ def _entity_isa(given, mapper): """ if given.is_aliased_class: - return mapper in given.with_polymorphic_mappers or \ - given.mapper.isa(mapper) + return mapper in given.with_polymorphic_mappers or given.mapper.isa( + mapper + ) elif given.with_polymorphic_mappers: return mapper in given.with_polymorphic_mappers else: @@ -1126,5 +1228,7 @@ def randomize_unitofwork(): from sqlalchemy.orm import unitofwork, session, mapper, dependency from sqlalchemy.util import topological from sqlalchemy.testing.util import RandomSet - topological.set = unitofwork.set = session.set = mapper.set = \ - dependency.set = RandomSet + + topological.set = ( + unitofwork.set + ) = session.set = mapper.set = dependency.set = RandomSet |