diff options
Diffstat (limited to 'lib/sqlalchemy/orm/interfaces.py')
-rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 321 |
1 files changed, 45 insertions, 276 deletions
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 2f4aa5208..18723e4f6 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -21,7 +21,6 @@ from __future__ import absolute_import from .. import exc as sa_exc, util, inspect from ..sql import operators from collections import deque -from .base import _is_aliased_class, _class_to_mapper from .base import ONETOMANY, MANYTOONE, MANYTOMANY, EXT_CONTINUE, EXT_STOP, NOT_EXTENSION from .base import _InspectionAttr, _MappedAttribute from .path_registry import PathRegistry @@ -424,51 +423,57 @@ class StrategizedProperty(MapperProperty): strategy_wildcard_key = None - @util.memoized_property - def _wildcard_path(self): - if self.strategy_wildcard_key: - return ('loaderstrategy', (self.strategy_wildcard_key,)) - else: - return None + def _get_context_loader(self, context, path): + load = None - def _get_context_strategy(self, context, path): - strategy_cls = path._inlined_get_for(self, context, 'loaderstrategy') + # use EntityRegistry.__getitem__()->PropRegistry here so + # that the path is stated in terms of our base + search_path = dict.__getitem__(path, self) - if not strategy_cls: - wc_key = self._wildcard_path - if wc_key and wc_key in context.attributes: - strategy_cls = context.attributes[wc_key] + # search among: exact match, "attr.*", "default" strategy + # if any. + for path_key in ( + search_path._loader_key, + search_path._wildcard_path_loader_key, + search_path._default_path_loader_key + ): + if path_key in context.attributes: + load = context.attributes[path_key] + break - if strategy_cls: - try: - return self._strategies[strategy_cls] - except KeyError: - return self.__init_strategy(strategy_cls) - return self.strategy + return load - def _get_strategy(self, cls): + def _get_strategy(self, key): try: - return self._strategies[cls] + return self._strategies[key] except KeyError: - return self.__init_strategy(cls) + cls = self._strategy_lookup(*key) + self._strategies[key] = self._strategies[cls] = strategy = cls(self) + return strategy - def __init_strategy(self, cls): - self._strategies[cls] = strategy = cls(self) - return strategy + def _get_strategy_by_cls(self, cls): + return self._get_strategy(cls._strategy_keys[0]) def setup(self, context, entity, path, adapter, **kwargs): - self._get_context_strategy(context, path).\ - setup_query(context, entity, path, - adapter, **kwargs) + loader = self._get_context_loader(context, path) + if loader and loader.strategy: + strat = self._get_strategy(loader.strategy) + else: + strat = self.strategy + strat.setup_query(context, entity, path, loader, adapter, **kwargs) def create_row_processor(self, context, path, mapper, row, adapter): - return self._get_context_strategy(context, path).\ - create_row_processor(context, path, + loader = self._get_context_loader(context, path) + if loader and loader.strategy: + strat = self._get_strategy(loader.strategy) + else: + strat = self.strategy + return strat.create_row_processor(context, path, loader, mapper, row, adapter) def do_init(self): self._strategies = {} - self.strategy = self.__init_strategy(self.strategy_class) + self.strategy = self._get_strategy_by_cls(self.strategy_class) def post_instrument_class(self, mapper): if self.is_primary() and \ @@ -479,17 +484,17 @@ class StrategizedProperty(MapperProperty): _strategies = collections.defaultdict(dict) @classmethod - def _strategy_for(cls, *keys): + def strategy_for(cls, **kw): def decorate(dec_cls): - for key in keys: - key = tuple(sorted(key.items())) - cls._strategies[cls][key] = dec_cls + dec_cls._strategy_keys = [] + key = tuple(sorted(kw.items())) + cls._strategies[cls][key] = dec_cls + dec_cls._strategy_keys.append(key) return dec_cls return decorate @classmethod - def _strategy_lookup(cls, **kw): - key = tuple(sorted(kw.items())) + def _strategy_lookup(cls, *key): for prop_cls in cls.__mro__: if prop_cls in cls._strategies: strategies = cls._strategies[prop_cls] @@ -497,7 +502,7 @@ class StrategizedProperty(MapperProperty): return strategies[key] except KeyError: pass - raise Exception("can't locate strategy for %s %s" % (cls, kw)) + raise Exception("can't locate strategy for %s %s" % (cls, key)) class MapperOption(object): @@ -521,242 +526,6 @@ class MapperOption(object): self.process_query(query) -class PropertyOption(MapperOption): - """A MapperOption that is applied to a property off the mapper or - one of its child mappers, identified by a dot-separated key - or list of class-bound attributes. """ - - def __init__(self, key, mapper=None): - self.key = key - self.mapper = mapper - - def process_query(self, query): - self._process(query, True) - - def process_query_conditionally(self, query): - self._process(query, False) - - def _process(self, query, raiseerr): - paths = self._process_paths(query, raiseerr) - if paths: - self.process_query_property(query, paths) - - def process_query_property(self, query, paths): - pass - - def __getstate__(self): - d = self.__dict__.copy() - d['key'] = ret = [] - for token in util.to_list(self.key): - if isinstance(token, PropComparator): - ret.append((token._parentmapper.class_, token.key)) - else: - ret.append(token) - return d - - def __setstate__(self, state): - ret = [] - for key in state['key']: - if isinstance(key, tuple): - cls, propkey = key - ret.append(getattr(cls, propkey)) - else: - ret.append(key) - state['key'] = tuple(ret) - self.__dict__ = state - - def _find_entity_prop_comparator(self, query, token, mapper, raiseerr): - if _is_aliased_class(mapper): - searchfor = mapper - else: - searchfor = _class_to_mapper(mapper) - for ent in query._mapper_entities: - if ent.corresponds_to(searchfor): - return ent - else: - if raiseerr: - if not list(query._mapper_entities): - raise sa_exc.ArgumentError( - "Query has only expression-based entities - " - "can't find property named '%s'." - % (token, ) - ) - else: - raise sa_exc.ArgumentError( - "Can't find property '%s' on any entity " - "specified in this Query. Note the full path " - "from root (%s) to target entity must be specified." - % (token, ",".join(str(x) for - x in query._mapper_entities)) - ) - else: - return None - - def _find_entity_basestring(self, query, token, raiseerr): - for ent in query._mapper_entities: - # return only the first _MapperEntity when searching - # based on string prop name. Ideally object - # attributes are used to specify more exactly. - return ent - else: - if raiseerr: - raise sa_exc.ArgumentError( - "Query has only expression-based entities - " - "can't find property named '%s'." - % (token, ) - ) - else: - return None - - @util.dependencies("sqlalchemy.orm.util") - def _process_paths(self, orm_util, query, raiseerr): - """reconcile the 'key' for this PropertyOption with - the current path and entities of the query. - - Return a list of affected paths. - - """ - path = PathRegistry.root - entity = None - paths = [] - no_result = [] - - # _current_path implies we're in a - # secondary load with an existing path - current_path = list(query._current_path.path) - - tokens = deque(self.key) - while tokens: - token = tokens.popleft() - if isinstance(token, str): - # wildcard token - if token.endswith(':*'): - return [path.token(token)] - sub_tokens = token.split(".", 1) - token = sub_tokens[0] - tokens.extendleft(sub_tokens[1:]) - - # exhaust current_path before - # matching tokens to entities - if current_path: - if current_path[1].key == token: - current_path = current_path[2:] - continue - else: - return no_result - - if not entity: - entity = self._find_entity_basestring( - query, - token, - raiseerr) - if entity is None: - return no_result - path_element = entity.entity_zero - mapper = entity.mapper - - if hasattr(mapper.class_, token): - prop = getattr(mapper.class_, token).property - else: - if raiseerr: - raise sa_exc.ArgumentError( - "Can't find property named '%s' on the " - "mapped entity %s in this Query. " % ( - token, mapper) - ) - else: - return no_result - elif isinstance(token, PropComparator): - prop = token.property - - # exhaust current_path before - # matching tokens to entities - if current_path: - if current_path[0:2] == \ - [token._parententity, prop]: - current_path = current_path[2:] - continue - else: - return no_result - - if not entity: - entity = self._find_entity_prop_comparator( - query, - prop.key, - token._parententity, - raiseerr) - if not entity: - return no_result - - path_element = entity.entity_zero - mapper = entity.mapper - else: - raise sa_exc.ArgumentError( - "mapper option expects " - "string key or list of attributes") - assert prop is not None - if raiseerr and not prop.parent.common_parent(mapper): - raise sa_exc.ArgumentError("Attribute '%s' does not " - "link from element '%s'" % (token, path_element)) - - path = path[path_element][prop] - - paths.append(path) - - if getattr(token, '_of_type', None): - ac = token._of_type - ext_info = inspect(ac) - path_element = mapper = ext_info.mapper - if not ext_info.is_aliased_class: - ac = orm_util.with_polymorphic( - ext_info.mapper.base_mapper, - ext_info.mapper, aliased=True, - _use_mapper_path=True) - ext_info = inspect(ac) - path.set(query._attributes, "path_with_polymorphic", ext_info) - else: - path_element = mapper = getattr(prop, 'mapper', None) - if mapper is None and tokens: - raise sa_exc.ArgumentError( - "Attribute '%s' of entity '%s' does not " - "refer to a mapped entity" % - (token, entity) - ) - - if current_path: - # ran out of tokens before - # current_path was exhausted. - assert not tokens - return no_result - - return paths - - -class StrategizedOption(PropertyOption): - """A MapperOption that affects which LoaderStrategy will be used - for an operation by a StrategizedProperty. - """ - - chained = False - - def process_query_property(self, query, paths): - strategy = self.get_strategy_class() - if self.chained: - for path in paths: - path.set( - query._attributes, - "loaderstrategy", - strategy - ) - else: - paths[-1].set( - query._attributes, - "loaderstrategy", - strategy - ) - - def get_strategy_class(self): - raise NotImplementedError() class LoaderStrategy(object): @@ -791,10 +560,10 @@ class LoaderStrategy(object): def init_class_attribute(self, mapper): pass - def setup_query(self, context, entity, path, adapter, **kwargs): + def setup_query(self, context, entity, path, loadopt, adapter, **kwargs): pass - def create_row_processor(self, context, path, mapper, + def create_row_processor(self, context, path, loadopt, mapper, row, adapter): """Return row processing functions which fulfill the contract specified by MapperProperty.create_row_processor. |