summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/util.py')
-rw-r--r--lib/sqlalchemy/orm/util.py462
1 files changed, 283 insertions, 179 deletions
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 43709a58c..a1b0cd5da 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -12,27 +12,51 @@ from .interfaces import PropComparator, MapperProperty
from . import attributes
import re
-from .base import instance_str, state_str, state_class_str, attribute_str, \
- state_attribute_str, object_mapper, object_state, _none_set, _never_set
+from .base import (
+ instance_str,
+ state_str,
+ state_class_str,
+ attribute_str,
+ state_attribute_str,
+ object_mapper,
+ object_state,
+ _none_set,
+ _never_set,
+)
from .base import class_mapper, _class_to_mapper
from .base import InspectionAttr
from .path_registry import PathRegistry
-all_cascades = frozenset(("delete", "delete-orphan", "all", "merge",
- "expunge", "save-update", "refresh-expire",
- "none"))
+all_cascades = frozenset(
+ (
+ "delete",
+ "delete-orphan",
+ "all",
+ "merge",
+ "expunge",
+ "save-update",
+ "refresh-expire",
+ "none",
+ )
+)
class CascadeOptions(frozenset):
"""Keeps track of the options sent to relationship().cascade"""
- _add_w_all_cascades = all_cascades.difference([
- 'all', 'none', 'delete-orphan'])
+ _add_w_all_cascades = all_cascades.difference(
+ ["all", "none", "delete-orphan"]
+ )
_allowed_cascades = all_cascades
__slots__ = (
- 'save_update', 'delete', 'refresh_expire', 'merge',
- 'expunge', 'delete_orphan')
+ "save_update",
+ "delete",
+ "refresh_expire",
+ "merge",
+ "expunge",
+ "delete_orphan",
+ )
def __new__(cls, value_list):
if isinstance(value_list, util.string_types) or value_list is None:
@@ -40,60 +64,62 @@ class CascadeOptions(frozenset):
values = set(value_list)
if values.difference(cls._allowed_cascades):
raise sa_exc.ArgumentError(
- "Invalid cascade option(s): %s" %
- ", ".join([repr(x) for x in
- sorted(values.difference(cls._allowed_cascades))]))
+ "Invalid cascade option(s): %s"
+ % ", ".join(
+ [
+ repr(x)
+ for x in sorted(
+ values.difference(cls._allowed_cascades)
+ )
+ ]
+ )
+ )
if "all" in values:
values.update(cls._add_w_all_cascades)
if "none" in values:
values.clear()
- values.discard('all')
+ values.discard("all")
self = frozenset.__new__(CascadeOptions, values)
- self.save_update = 'save-update' in values
- self.delete = 'delete' in values
- self.refresh_expire = 'refresh-expire' in values
- self.merge = 'merge' in values
- self.expunge = 'expunge' in values
+ self.save_update = "save-update" in values
+ self.delete = "delete" in values
+ self.refresh_expire = "refresh-expire" in values
+ self.merge = "merge" in values
+ self.expunge = "expunge" in values
self.delete_orphan = "delete-orphan" in values
if self.delete_orphan and not self.delete:
- util.warn("The 'delete-orphan' cascade "
- "option requires 'delete'.")
+ util.warn(
+ "The 'delete-orphan' cascade " "option requires 'delete'."
+ )
return self
def __repr__(self):
- return "CascadeOptions(%r)" % (
- ",".join([x for x in sorted(self)])
- )
+ return "CascadeOptions(%r)" % (",".join([x for x in sorted(self)]))
@classmethod
def from_string(cls, arg):
- values = [
- c for c
- in re.split(r'\s*,\s*', arg or "")
- if c
- ]
+ values = [c for c in re.split(r"\s*,\s*", arg or "") if c]
return cls(values)
-def _validator_events(
- desc, key, validator, include_removes, include_backrefs):
+def _validator_events(desc, key, validator, include_removes, include_backrefs):
"""Runs a validation method on an attribute value to be set or
appended.
"""
if not include_backrefs:
+
def detect_is_backref(state, initiator):
impl = state.manager[key].impl
return initiator.impl is not impl
if include_removes:
+
def append(state, value, initiator):
- if (
- initiator.op is not attributes.OP_BULK_REPLACE and
- (include_backrefs or not detect_is_backref(state, initiator))
+ if initiator.op is not attributes.OP_BULK_REPLACE and (
+ include_backrefs or not detect_is_backref(state, initiator)
):
return validator(state.obj(), key, value, False)
else:
@@ -103,7 +129,8 @@ def _validator_events(
if include_backrefs or not detect_is_backref(state, initiator):
obj = state.obj()
values[:] = [
- validator(obj, key, value, False) for value in values]
+ validator(obj, key, value, False) for value in values
+ ]
def set_(state, value, oldvalue, initiator):
if include_backrefs or not detect_is_backref(state, initiator):
@@ -116,10 +143,10 @@ def _validator_events(
validator(state.obj(), key, value, True)
else:
+
def append(state, value, initiator):
- if (
- initiator.op is not attributes.OP_BULK_REPLACE and
- (include_backrefs or not detect_is_backref(state, initiator))
+ if initiator.op is not attributes.OP_BULK_REPLACE and (
+ include_backrefs or not detect_is_backref(state, initiator)
):
return validator(state.obj(), key, value)
else:
@@ -128,8 +155,7 @@ def _validator_events(
def bulk_set(state, values, initiator):
if include_backrefs or not detect_is_backref(state, initiator):
obj = state.obj()
- values[:] = [
- validator(obj, key, value) for value in values]
+ values[:] = [validator(obj, key, value) for value in values]
def set_(state, value, oldvalue, initiator):
if include_backrefs or not detect_is_backref(state, initiator):
@@ -137,15 +163,16 @@ def _validator_events(
else:
return value
- event.listen(desc, 'append', append, raw=True, retval=True)
- event.listen(desc, 'bulk_replace', bulk_set, raw=True)
- event.listen(desc, 'set', set_, raw=True, retval=True)
+ event.listen(desc, "append", append, raw=True, retval=True)
+ event.listen(desc, "bulk_replace", bulk_set, raw=True)
+ event.listen(desc, "set", set_, raw=True, retval=True)
if include_removes:
event.listen(desc, "remove", remove, raw=True, retval=True)
-def polymorphic_union(table_map, typecolname,
- aliasname='p_union', cast_nulls=True):
+def polymorphic_union(
+ table_map, typecolname, aliasname="p_union", cast_nulls=True
+):
"""Create a ``UNION`` statement used by a polymorphic mapper.
See :ref:`concrete_inheritance` for an example of how
@@ -197,14 +224,22 @@ def polymorphic_union(table_map, typecolname,
for type, table in table_map.items():
if typecolname is not None:
result.append(
- sql.select([col(name, table) for name in colnames] +
- [sql.literal_column(
- sql_util._quote_ddl_expr(type)).
- label(typecolname)],
- from_obj=[table]))
+ sql.select(
+ [col(name, table) for name in colnames]
+ + [
+ sql.literal_column(
+ sql_util._quote_ddl_expr(type)
+ ).label(typecolname)
+ ],
+ from_obj=[table],
+ )
+ )
else:
- result.append(sql.select([col(name, table) for name in colnames],
- from_obj=[table]))
+ result.append(
+ sql.select(
+ [col(name, table) for name in colnames], from_obj=[table]
+ )
+ )
return sql.union_all(*result).alias(aliasname)
@@ -284,25 +319,29 @@ first()
class_, ident = args
else:
raise sa_exc.ArgumentError(
- "expected up to three positional arguments, "
- "got %s" % largs)
+ "expected up to three positional arguments, " "got %s" % largs
+ )
identity_token = kwargs.pop("identity_token", None)
if kwargs:
- raise sa_exc.ArgumentError("unknown keyword arguments: %s"
- % ", ".join(kwargs))
+ raise sa_exc.ArgumentError(
+ "unknown keyword arguments: %s" % ", ".join(kwargs)
+ )
mapper = class_mapper(class_)
if row is None:
return mapper.identity_key_from_primary_key(
- util.to_list(ident), identity_token=identity_token)
+ util.to_list(ident), identity_token=identity_token
+ )
else:
return mapper.identity_key_from_row(
- row, identity_token=identity_token)
+ row, identity_token=identity_token
+ )
else:
instance = kwargs.pop("instance")
if kwargs:
- raise sa_exc.ArgumentError("unknown keyword arguments: %s"
- % ", ".join(kwargs.keys))
+ raise sa_exc.ArgumentError(
+ "unknown keyword arguments: %s" % ", ".join(kwargs.keys)
+ )
mapper = object_mapper(instance)
return mapper.identity_key_from_instance(instance)
@@ -313,9 +352,15 @@ class ORMAdapter(sql_util.ColumnAdapter):
"""
- def __init__(self, entity, equivalents=None, adapt_required=False,
- chain_to=None, allow_label_resolve=True,
- anonymize_labels=False):
+ def __init__(
+ self,
+ entity,
+ equivalents=None,
+ adapt_required=False,
+ chain_to=None,
+ allow_label_resolve=True,
+ anonymize_labels=False,
+ ):
info = inspection.inspect(entity)
self.mapper = info.mapper
@@ -327,15 +372,18 @@ class ORMAdapter(sql_util.ColumnAdapter):
self.aliased_class = None
sql_util.ColumnAdapter.__init__(
- self, selectable, equivalents, chain_to,
+ self,
+ selectable,
+ equivalents,
+ chain_to,
adapt_required=adapt_required,
allow_label_resolve=allow_label_resolve,
anonymize_labels=anonymize_labels,
- include_fn=self._include_fn
+ include_fn=self._include_fn,
)
def _include_fn(self, elem):
- entity = elem._annotations.get('parentmapper', None)
+ entity = elem._annotations.get("parentmapper", None)
return not entity or entity.isa(self.mapper)
@@ -380,20 +428,25 @@ class AliasedClass(object):
"""
- def __init__(self, cls, alias=None,
- name=None,
- flat=False,
- adapt_on_names=False,
- # TODO: None for default here?
- with_polymorphic_mappers=(),
- with_polymorphic_discriminator=None,
- base_alias=None,
- use_mapper_path=False,
- represents_outer_join=False):
+ def __init__(
+ self,
+ cls,
+ alias=None,
+ name=None,
+ flat=False,
+ adapt_on_names=False,
+ # TODO: None for default here?
+ with_polymorphic_mappers=(),
+ with_polymorphic_discriminator=None,
+ base_alias=None,
+ use_mapper_path=False,
+ represents_outer_join=False,
+ ):
mapper = _class_to_mapper(cls)
if alias is None:
alias = mapper._with_polymorphic_selectable.alias(
- name=name, flat=flat)
+ name=name, flat=flat
+ )
self._aliased_insp = AliasedInsp(
self,
@@ -409,14 +462,14 @@ class AliasedClass(object):
base_alias,
use_mapper_path,
adapt_on_names,
- represents_outer_join
+ represents_outer_join,
)
- self.__name__ = 'AliasedClass_%s' % mapper.class_.__name__
+ self.__name__ = "AliasedClass_%s" % mapper.class_.__name__
def __getattr__(self, key):
try:
- _aliased_insp = self.__dict__['_aliased_insp']
+ _aliased_insp = self.__dict__["_aliased_insp"]
except KeyError:
raise AttributeError()
else:
@@ -434,13 +487,13 @@ class AliasedClass(object):
ret = attr.adapt_to_entity(_aliased_insp)
setattr(self, key, ret)
return ret
- elif hasattr(attr, 'func_code'):
+ elif hasattr(attr, "func_code"):
is_method = getattr(_aliased_insp._target, key, None)
if is_method and is_method.__self__ is not None:
return util.types.MethodType(attr.__func__, self, self)
else:
return None
- elif hasattr(attr, '__get__'):
+ elif hasattr(attr, "__get__"):
ret = attr.__get__(None, self)
if isinstance(ret, PropComparator):
return ret.adapt_to_entity(_aliased_insp)
@@ -450,8 +503,10 @@ class AliasedClass(object):
return attr
def __repr__(self):
- return '<AliasedClass at 0x%x; %s>' % (
- id(self), self._aliased_insp._target.__name__)
+ return "<AliasedClass at 0x%x; %s>" % (
+ id(self),
+ self._aliased_insp._target.__name__,
+ )
class AliasedInsp(InspectionAttr):
@@ -490,10 +545,19 @@ class AliasedInsp(InspectionAttr):
"""
- def __init__(self, entity, mapper, selectable, name,
- with_polymorphic_mappers, polymorphic_on,
- _base_alias, _use_mapper_path, adapt_on_names,
- represents_outer_join):
+ def __init__(
+ self,
+ entity,
+ mapper,
+ selectable,
+ name,
+ with_polymorphic_mappers,
+ polymorphic_on,
+ _base_alias,
+ _use_mapper_path,
+ adapt_on_names,
+ represents_outer_join,
+ ):
self.entity = entity
self.mapper = mapper
self.selectable = selectable
@@ -505,18 +569,28 @@ class AliasedInsp(InspectionAttr):
self.represents_outer_join = represents_outer_join
self._adapter = sql_util.ColumnAdapter(
- selectable, equivalents=mapper._equivalent_columns,
- adapt_on_names=adapt_on_names, anonymize_labels=True)
+ selectable,
+ equivalents=mapper._equivalent_columns,
+ adapt_on_names=adapt_on_names,
+ anonymize_labels=True,
+ )
self._adapt_on_names = adapt_on_names
self._target = mapper.class_
for poly in self.with_polymorphic_mappers:
if poly is not mapper:
- setattr(self.entity, poly.class_.__name__,
- AliasedClass(poly.class_, selectable, base_alias=self,
- adapt_on_names=adapt_on_names,
- use_mapper_path=_use_mapper_path))
+ setattr(
+ self.entity,
+ poly.class_.__name__,
+ AliasedClass(
+ poly.class_,
+ selectable,
+ base_alias=self,
+ adapt_on_names=adapt_on_names,
+ use_mapper_path=_use_mapper_path,
+ ),
+ )
is_aliased_class = True
"always returns True"
@@ -536,39 +610,35 @@ class AliasedInsp(InspectionAttr):
def __getstate__(self):
return {
- 'entity': self.entity,
- 'mapper': self.mapper,
- 'alias': self.selectable,
- 'name': self.name,
- 'adapt_on_names': self._adapt_on_names,
- 'with_polymorphic_mappers':
- self.with_polymorphic_mappers,
- 'with_polymorphic_discriminator':
- self.polymorphic_on,
- 'base_alias': self._base_alias,
- 'use_mapper_path': self._use_mapper_path,
- 'represents_outer_join': self.represents_outer_join
+ "entity": self.entity,
+ "mapper": self.mapper,
+ "alias": self.selectable,
+ "name": self.name,
+ "adapt_on_names": self._adapt_on_names,
+ "with_polymorphic_mappers": self.with_polymorphic_mappers,
+ "with_polymorphic_discriminator": self.polymorphic_on,
+ "base_alias": self._base_alias,
+ "use_mapper_path": self._use_mapper_path,
+ "represents_outer_join": self.represents_outer_join,
}
def __setstate__(self, state):
self.__init__(
- state['entity'],
- state['mapper'],
- state['alias'],
- state['name'],
- state['with_polymorphic_mappers'],
- state['with_polymorphic_discriminator'],
- state['base_alias'],
- state['use_mapper_path'],
- state['adapt_on_names'],
- state['represents_outer_join']
+ state["entity"],
+ state["mapper"],
+ state["alias"],
+ state["name"],
+ state["with_polymorphic_mappers"],
+ state["with_polymorphic_discriminator"],
+ state["base_alias"],
+ state["use_mapper_path"],
+ state["adapt_on_names"],
+ state["represents_outer_join"],
)
def _adapt_element(self, elem):
- return self._adapter.traverse(elem).\
- _annotate({
- 'parententity': self,
- 'parentmapper': self.mapper}
+ return self._adapter.traverse(elem)._annotate(
+ {"parententity": self, "parentmapper": self.mapper}
)
def _entity_for_mapper(self, mapper):
@@ -578,12 +648,12 @@ class AliasedInsp(InspectionAttr):
return self
else:
return getattr(
- self.entity, mapper.class_.__name__)._aliased_insp
+ self.entity, mapper.class_.__name__
+ )._aliased_insp
elif mapper.isa(self.mapper):
return self
else:
- assert False, "mapper %s doesn't correspond to %s" % (
- mapper, self)
+ assert False, "mapper %s doesn't correspond to %s" % (mapper, self)
@util.memoized_property
def _memoized_values(self):
@@ -599,11 +669,15 @@ class AliasedInsp(InspectionAttr):
def __repr__(self):
if self.with_polymorphic_mappers:
with_poly = "(%s)" % ", ".join(
- mp.class_.__name__ for mp in self.with_polymorphic_mappers)
+ mp.class_.__name__ for mp in self.with_polymorphic_mappers
+ )
else:
with_poly = ""
- return '<AliasedInsp at 0x%x; %s%s>' % (
- id(self), self.class_.__name__, with_poly)
+ return "<AliasedInsp at 0x%x; %s%s>" % (
+ id(self),
+ self.class_.__name__,
+ with_poly,
+ )
inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
@@ -700,15 +774,26 @@ def aliased(element, alias=None, name=None, flat=False, adapt_on_names=False):
)
return element.alias(name, flat=flat)
else:
- return AliasedClass(element, alias=alias, flat=flat,
- name=name, adapt_on_names=adapt_on_names)
+ return AliasedClass(
+ element,
+ alias=alias,
+ flat=flat,
+ name=name,
+ adapt_on_names=adapt_on_names,
+ )
-def with_polymorphic(base, classes, selectable=False,
- flat=False,
- polymorphic_on=None, aliased=False,
- innerjoin=False, _use_mapper_path=False,
- _existing_alias=None):
+def with_polymorphic(
+ base,
+ classes,
+ selectable=False,
+ flat=False,
+ polymorphic_on=None,
+ aliased=False,
+ innerjoin=False,
+ _use_mapper_path=False,
+ _existing_alias=None,
+):
"""Produce an :class:`.AliasedClass` construct which specifies
columns for descendant mappers of the given base.
@@ -777,24 +862,26 @@ def with_polymorphic(base, classes, selectable=False,
if _existing_alias:
assert _existing_alias.mapper is primary_mapper
classes = util.to_set(classes)
- new_classes = set([
- mp.class_ for mp in
- _existing_alias.with_polymorphic_mappers])
+ new_classes = set(
+ [mp.class_ for mp in _existing_alias.with_polymorphic_mappers]
+ )
if classes == new_classes:
return _existing_alias
else:
classes = classes.union(new_classes)
- mappers, selectable = primary_mapper.\
- _with_polymorphic_args(classes, selectable,
- innerjoin=innerjoin)
+ mappers, selectable = primary_mapper._with_polymorphic_args(
+ classes, selectable, innerjoin=innerjoin
+ )
if aliased or flat:
selectable = selectable.alias(flat=flat)
- return AliasedClass(base,
- selectable,
- with_polymorphic_mappers=mappers,
- with_polymorphic_discriminator=polymorphic_on,
- use_mapper_path=_use_mapper_path,
- represents_outer_join=not innerjoin)
+ return AliasedClass(
+ base,
+ selectable,
+ with_polymorphic_mappers=mappers,
+ with_polymorphic_discriminator=polymorphic_on,
+ use_mapper_path=_use_mapper_path,
+ represents_outer_join=not innerjoin,
+ )
def _orm_annotate(element, exclude=None):
@@ -804,7 +891,7 @@ def _orm_annotate(element, exclude=None):
Elements within the exclude collection will be cloned but not annotated.
"""
- return sql_util._deep_annotate(element, {'_orm_adapt': True}, exclude)
+ return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude)
def _orm_deannotate(element):
@@ -816,9 +903,9 @@ def _orm_deannotate(element):
"""
- return sql_util._deep_deannotate(element,
- values=("_orm_adapt", "parententity")
- )
+ return sql_util._deep_deannotate(
+ element, values=("_orm_adapt", "parententity")
+ )
def _orm_full_deannotate(element):
@@ -831,12 +918,18 @@ class _ORMJoin(expression.Join):
__visit_name__ = expression.Join.__visit_name__
def __init__(
- self,
- left, right, onclause=None, isouter=False,
- full=False, _left_memo=None, _right_memo=None):
+ self,
+ left,
+ right,
+ onclause=None,
+ isouter=False,
+ full=False,
+ _left_memo=None,
+ _right_memo=None,
+ ):
left_info = inspection.inspect(left)
- left_orm_info = getattr(left, '_joined_from_info', left_info)
+ left_orm_info = getattr(left, "_joined_from_info", left_info)
right_info = inspection.inspect(right)
adapt_to = right_info.selectable
@@ -859,19 +952,18 @@ class _ORMJoin(expression.Join):
prop = None
if prop:
- if sql_util.clause_is_present(
- on_selectable, left_info.selectable):
+ if sql_util.clause_is_present(on_selectable, left_info.selectable):
adapt_from = on_selectable
else:
adapt_from = left_info.selectable
- pj, sj, source, dest, \
- secondary, target_adapter = prop._create_joins(
- source_selectable=adapt_from,
- dest_selectable=adapt_to,
- source_polymorphic=True,
- dest_polymorphic=True,
- of_type=right_info.mapper)
+ pj, sj, source, dest, secondary, target_adapter = prop._create_joins(
+ source_selectable=adapt_from,
+ dest_selectable=adapt_to,
+ source_polymorphic=True,
+ dest_polymorphic=True,
+ of_type=right_info.mapper,
+ )
if sj is not None:
if isouter:
@@ -887,8 +979,11 @@ class _ORMJoin(expression.Join):
expression.Join.__init__(self, left, right, onclause, isouter, full)
- if not prop and getattr(right_info, 'mapper', None) \
- and right_info.mapper.single:
+ if (
+ not prop
+ and getattr(right_info, "mapper", None)
+ and right_info.mapper.single
+ ):
# if single inheritance target and we are using a manual
# or implicit ON clause, augment it the same way we'd augment the
# WHERE.
@@ -911,33 +1006,39 @@ class _ORMJoin(expression.Join):
assert self.right is leftmost
left = _ORMJoin(
- self.left, other.left,
- self.onclause, isouter=self.isouter,
+ self.left,
+ other.left,
+ self.onclause,
+ isouter=self.isouter,
_left_memo=self._left_memo,
- _right_memo=other._left_memo
+ _right_memo=other._left_memo,
)
return _ORMJoin(
left,
other.right,
- other.onclause, isouter=other.isouter,
- _right_memo=other._right_memo
+ other.onclause,
+ isouter=other.isouter,
+ _right_memo=other._right_memo,
)
def join(
- self, right, onclause=None,
- isouter=False, full=False, join_to_left=None):
+ self,
+ right,
+ onclause=None,
+ isouter=False,
+ full=False,
+ join_to_left=None,
+ ):
return _ORMJoin(self, right, onclause, full, isouter)
- def outerjoin(
- self, right, onclause=None,
- full=False, join_to_left=None):
+ def outerjoin(self, right, onclause=None, full=False, join_to_left=None):
return _ORMJoin(self, right, onclause, True, full=full)
def join(
- left, right, onclause=None, isouter=False,
- full=False, join_to_left=None):
+ left, right, onclause=None, isouter=False, full=False, join_to_left=None
+):
r"""Produce an inner join between left and right clauses.
:func:`.orm.join` is an extension to the core join interface
@@ -1085,8 +1186,9 @@ def _entity_isa(given, mapper):
"""
if given.is_aliased_class:
- return mapper in given.with_polymorphic_mappers or \
- given.mapper.isa(mapper)
+ return mapper in given.with_polymorphic_mappers or given.mapper.isa(
+ mapper
+ )
elif given.with_polymorphic_mappers:
return mapper in given.with_polymorphic_mappers
else:
@@ -1126,5 +1228,7 @@ def randomize_unitofwork():
from sqlalchemy.orm import unitofwork, session, mapper, dependency
from sqlalchemy.util import topological
from sqlalchemy.testing.util import RandomSet
- topological.set = unitofwork.set = session.set = mapper.set = \
- dependency.set = RandomSet
+
+ topological.set = (
+ unitofwork.set
+ ) = session.set = mapper.set = dependency.set = RandomSet