diff options
Diffstat (limited to 'lib/sqlalchemy/ext/horizontal_shard.py')
-rw-r--r-- | lib/sqlalchemy/ext/horizontal_shard.py | 138 |
1 files changed, 80 insertions, 58 deletions
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 69767ad6c..fd53c6046 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -13,11 +13,14 @@ distribute queries and persistence operations across multiple databases. For a usage example, see the :ref:`examples_sharding` example included in the source distribution. -.. legacy:: The horizontal sharding API is not fully updated for the - SQLAlchemy 2.0 API, and still relies in part on the - legacy :class:`.Query` architecture, in particular as part of the - signature for the :paramref:`.ShardedSession.id_chooser` parameter. - This may change in a future release. +.. deepalchemy:: The horizontal sharding extension is an advanced feature, + involving a complex statement -> database interaction as well as + use of semi-public APIs for non-trivial cases. Simpler approaches to + refering to multiple database "shards", most commonly using a distinct + :class:`_orm.Session` per "shard", should always be considered first + before using this more complex and less-production-tested system. + + """ from __future__ import annotations @@ -38,8 +41,11 @@ from .. import exc from .. import inspect from .. import util from ..orm import PassiveFlag +from ..orm._typing import OrmExecuteOptionsParameter from ..orm.mapper import Mapper from ..orm.query import Query +from ..orm.session import _BindArguments +from ..orm.session import _PKIdentityArgument from ..orm.session import Session from ..util.typing import Protocol @@ -80,6 +86,20 @@ class ShardChooser(Protocol): ... +class IdentityChooser(Protocol): + def __call__( + self, + mapper: Mapper[_T], + primary_key: _PKIdentityArgument, + *, + lazy_loaded_from: Optional[InstanceState[Any]], + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + **kw: Any, + ) -> Any: + ... + + class ShardedQuery(Query[_T]): """Query class used with :class:`.ShardedSession`. @@ -94,8 +114,7 @@ class ShardedQuery(Query[_T]): super().__init__(*args, **kwargs) assert isinstance(self.session, ShardedSession) - self.id_chooser = self.session.id_chooser - self.query_chooser = self.session.query_chooser + self.identity_chooser = self.session.identity_chooser self.execute_chooser = self.session.execute_chooser self._shard_id = None @@ -119,19 +138,22 @@ class ShardedQuery(Query[_T]): class ShardedSession(Session): shard_chooser: ShardChooser - id_chooser: Callable[[Query[Any], Iterable[Any]], Iterable[Any]] + identity_chooser: IdentityChooser execute_chooser: Callable[[ORMExecuteState], Iterable[Any]] def __init__( self, shard_chooser: ShardChooser, - id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]], + identity_chooser: Optional[IdentityChooser] = None, execute_chooser: Optional[ Callable[[ORMExecuteState], Iterable[Any]] ] = None, shards: Optional[Dict[str, Any]] = None, query_cls: Type[Query[_T]] = ShardedQuery, *, + id_chooser: Optional[ + Callable[[Query[_T], Iterable[_T]], Iterable[Any]] + ] = None, query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None, **kwargs: Any, ) -> None: @@ -171,12 +193,41 @@ class ShardedSession(Session): self, "do_orm_execute", execute_and_instances, retval=True ) self.shard_chooser = shard_chooser - self.id_chooser = id_chooser + + if id_chooser: + _id_chooser = id_chooser + util.warn_deprecated( + "The ``id_chooser`` parameter is deprecated; " + "please use ``identity_chooser``.", + "2.0", + ) + + def _legacy_identity_chooser( + mapper: Mapper[_T], + primary_key: _PKIdentityArgument, + *, + lazy_loaded_from: Optional[InstanceState[Any]], + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + **kw: Any, + ) -> Any: + q = self.query(mapper) + if lazy_loaded_from: + q = q._set_lazyload_from(lazy_loaded_from) + return _id_chooser(q, primary_key) + + self.identity_chooser = _legacy_identity_chooser + elif identity_chooser: + self.identity_chooser = identity_chooser + else: + raise exc.ArgumentError( + "identity_chooser or id_chooser is required" + ) if query_chooser: _query_chooser = query_chooser util.warn_deprecated( - "The ``query_choser`` parameter is deprecated; " + "The ``query_chooser`` parameter is deprecated; " "please use ``execute_chooser``.", "1.4", ) @@ -199,7 +250,6 @@ class ShardedSession(Session): "execute_chooser or query_chooser is required" ) self.execute_chooser = execute_chooser - self.query_chooser = query_chooser self.__shards: Dict[_ShardKey, _SessionBind] = {} if shards is not None: for k in shards: @@ -212,6 +262,8 @@ class ShardedSession(Session): identity_token: Optional[Any] = None, passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, lazy_loaded_from: Optional[InstanceState[Any]] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, **kw: Any, ) -> Union[Optional[_O], LoaderCallableStatus]: """override the default :meth:`.Session._identity_lookup` method so @@ -233,10 +285,13 @@ class ShardedSession(Session): return obj else: - q = self.query(mapper) - if lazy_loaded_from: - q = q._set_lazyload_from(lazy_loaded_from) - for shard_id in self.id_chooser(q, primary_key_identity): + for shard_id in self.identity_chooser( + mapper, + primary_key_identity, + lazy_loaded_from=lazy_loaded_from, + execution_options=execution_options, + bind_arguments=dict(bind_arguments) if bind_arguments else {}, + ): obj2 = super()._identity_lookup( mapper, primary_key_identity, @@ -325,11 +380,6 @@ class ShardedSession(Session): def execute_and_instances( orm_context: ORMExecuteState, ) -> Union[Result[_T], IteratorResult[_TP]]: - update_options: Union[ - None, - BulkUDCompileState.default_update_options, - Type[BulkUDCompileState.default_update_options], - ] active_options: Union[ None, QueryContext.default_load_options, @@ -337,58 +387,30 @@ def execute_and_instances( BulkUDCompileState.default_update_options, Type[BulkUDCompileState.default_update_options], ] - load_options: Union[ - None, - QueryContext.default_load_options, - Type[QueryContext.default_load_options], - ] if orm_context.is_select: - load_options = active_options = orm_context.load_options - update_options = None + active_options = orm_context.load_options elif orm_context.is_update or orm_context.is_delete: - load_options = None - update_options = active_options = orm_context.update_delete_options + active_options = orm_context.update_delete_options else: - load_options = update_options = active_options = None + active_options = None session = orm_context.session assert isinstance(session, ShardedSession) def iter_for_shard( shard_id: str, - load_options: Union[ - None, - QueryContext.default_load_options, - Type[QueryContext.default_load_options], - ], - update_options: Union[ - None, - BulkUDCompileState.default_update_options, - Type[BulkUDCompileState.default_update_options], - ], ) -> Union[Result[_T], IteratorResult[_TP]]: - execution_options = dict(orm_context.local_execution_options) bind_arguments = dict(orm_context.bind_arguments) bind_arguments["shard_id"] = shard_id - if orm_context.is_select: - assert load_options is not None - load_options += {"_refresh_identity_token": shard_id} - execution_options["_sa_orm_load_options"] = load_options - elif orm_context.is_update or orm_context.is_delete: - assert update_options is not None - update_options += {"_refresh_identity_token": shard_id} - execution_options["_sa_orm_update_options"] = update_options - - return orm_context.invoke_statement( - bind_arguments=bind_arguments, execution_options=execution_options - ) + orm_context.update_execution_options(identity_token=shard_id) + return orm_context.invoke_statement(bind_arguments=bind_arguments) - if active_options and active_options._refresh_identity_token is not None: - shard_id = active_options._refresh_identity_token + if active_options and active_options._identity_token is not None: + shard_id = active_options._identity_token elif "_sa_shard_id" in orm_context.execution_options: shard_id = orm_context.execution_options["_sa_shard_id"] elif "shard_id" in orm_context.bind_arguments: @@ -397,10 +419,10 @@ def execute_and_instances( shard_id = None if shard_id is not None: - return iter_for_shard(shard_id, load_options, update_options) + return iter_for_shard(shard_id) else: partial = [] for shard_id in session.execute_chooser(orm_context): - result_ = iter_for_shard(shard_id, load_options, update_options) + result_ = iter_for_shard(shard_id) partial.append(result_) return partial[0].merge(*partial[1:]) |