summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/horizontal_shard.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/horizontal_shard.py')
-rw-r--r--lib/sqlalchemy/ext/horizontal_shard.py138
1 files changed, 80 insertions, 58 deletions
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index 69767ad6c..fd53c6046 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -13,11 +13,14 @@ distribute queries and persistence operations across multiple databases.
For a usage example, see the :ref:`examples_sharding` example included in
the source distribution.
-.. legacy:: The horizontal sharding API is not fully updated for the
- SQLAlchemy 2.0 API, and still relies in part on the
- legacy :class:`.Query` architecture, in particular as part of the
- signature for the :paramref:`.ShardedSession.id_chooser` parameter.
- This may change in a future release.
+.. deepalchemy:: The horizontal sharding extension is an advanced feature,
+ involving a complex statement -> database interaction as well as
+ use of semi-public APIs for non-trivial cases. Simpler approaches to
+ refering to multiple database "shards", most commonly using a distinct
+ :class:`_orm.Session` per "shard", should always be considered first
+ before using this more complex and less-production-tested system.
+
+
"""
from __future__ import annotations
@@ -38,8 +41,11 @@ from .. import exc
from .. import inspect
from .. import util
from ..orm import PassiveFlag
+from ..orm._typing import OrmExecuteOptionsParameter
from ..orm.mapper import Mapper
from ..orm.query import Query
+from ..orm.session import _BindArguments
+from ..orm.session import _PKIdentityArgument
from ..orm.session import Session
from ..util.typing import Protocol
@@ -80,6 +86,20 @@ class ShardChooser(Protocol):
...
+class IdentityChooser(Protocol):
+ def __call__(
+ self,
+ mapper: Mapper[_T],
+ primary_key: _PKIdentityArgument,
+ *,
+ lazy_loaded_from: Optional[InstanceState[Any]],
+ execution_options: OrmExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ **kw: Any,
+ ) -> Any:
+ ...
+
+
class ShardedQuery(Query[_T]):
"""Query class used with :class:`.ShardedSession`.
@@ -94,8 +114,7 @@ class ShardedQuery(Query[_T]):
super().__init__(*args, **kwargs)
assert isinstance(self.session, ShardedSession)
- self.id_chooser = self.session.id_chooser
- self.query_chooser = self.session.query_chooser
+ self.identity_chooser = self.session.identity_chooser
self.execute_chooser = self.session.execute_chooser
self._shard_id = None
@@ -119,19 +138,22 @@ class ShardedQuery(Query[_T]):
class ShardedSession(Session):
shard_chooser: ShardChooser
- id_chooser: Callable[[Query[Any], Iterable[Any]], Iterable[Any]]
+ identity_chooser: IdentityChooser
execute_chooser: Callable[[ORMExecuteState], Iterable[Any]]
def __init__(
self,
shard_chooser: ShardChooser,
- id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]],
+ identity_chooser: Optional[IdentityChooser] = None,
execute_chooser: Optional[
Callable[[ORMExecuteState], Iterable[Any]]
] = None,
shards: Optional[Dict[str, Any]] = None,
query_cls: Type[Query[_T]] = ShardedQuery,
*,
+ id_chooser: Optional[
+ Callable[[Query[_T], Iterable[_T]], Iterable[Any]]
+ ] = None,
query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None,
**kwargs: Any,
) -> None:
@@ -171,12 +193,41 @@ class ShardedSession(Session):
self, "do_orm_execute", execute_and_instances, retval=True
)
self.shard_chooser = shard_chooser
- self.id_chooser = id_chooser
+
+ if id_chooser:
+ _id_chooser = id_chooser
+ util.warn_deprecated(
+ "The ``id_chooser`` parameter is deprecated; "
+ "please use ``identity_chooser``.",
+ "2.0",
+ )
+
+ def _legacy_identity_chooser(
+ mapper: Mapper[_T],
+ primary_key: _PKIdentityArgument,
+ *,
+ lazy_loaded_from: Optional[InstanceState[Any]],
+ execution_options: OrmExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ **kw: Any,
+ ) -> Any:
+ q = self.query(mapper)
+ if lazy_loaded_from:
+ q = q._set_lazyload_from(lazy_loaded_from)
+ return _id_chooser(q, primary_key)
+
+ self.identity_chooser = _legacy_identity_chooser
+ elif identity_chooser:
+ self.identity_chooser = identity_chooser
+ else:
+ raise exc.ArgumentError(
+ "identity_chooser or id_chooser is required"
+ )
if query_chooser:
_query_chooser = query_chooser
util.warn_deprecated(
- "The ``query_choser`` parameter is deprecated; "
+ "The ``query_chooser`` parameter is deprecated; "
"please use ``execute_chooser``.",
"1.4",
)
@@ -199,7 +250,6 @@ class ShardedSession(Session):
"execute_chooser or query_chooser is required"
)
self.execute_chooser = execute_chooser
- self.query_chooser = query_chooser
self.__shards: Dict[_ShardKey, _SessionBind] = {}
if shards is not None:
for k in shards:
@@ -212,6 +262,8 @@ class ShardedSession(Session):
identity_token: Optional[Any] = None,
passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
lazy_loaded_from: Optional[InstanceState[Any]] = None,
+ execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> Union[Optional[_O], LoaderCallableStatus]:
"""override the default :meth:`.Session._identity_lookup` method so
@@ -233,10 +285,13 @@ class ShardedSession(Session):
return obj
else:
- q = self.query(mapper)
- if lazy_loaded_from:
- q = q._set_lazyload_from(lazy_loaded_from)
- for shard_id in self.id_chooser(q, primary_key_identity):
+ for shard_id in self.identity_chooser(
+ mapper,
+ primary_key_identity,
+ lazy_loaded_from=lazy_loaded_from,
+ execution_options=execution_options,
+ bind_arguments=dict(bind_arguments) if bind_arguments else {},
+ ):
obj2 = super()._identity_lookup(
mapper,
primary_key_identity,
@@ -325,11 +380,6 @@ class ShardedSession(Session):
def execute_and_instances(
orm_context: ORMExecuteState,
) -> Union[Result[_T], IteratorResult[_TP]]:
- update_options: Union[
- None,
- BulkUDCompileState.default_update_options,
- Type[BulkUDCompileState.default_update_options],
- ]
active_options: Union[
None,
QueryContext.default_load_options,
@@ -337,58 +387,30 @@ def execute_and_instances(
BulkUDCompileState.default_update_options,
Type[BulkUDCompileState.default_update_options],
]
- load_options: Union[
- None,
- QueryContext.default_load_options,
- Type[QueryContext.default_load_options],
- ]
if orm_context.is_select:
- load_options = active_options = orm_context.load_options
- update_options = None
+ active_options = orm_context.load_options
elif orm_context.is_update or orm_context.is_delete:
- load_options = None
- update_options = active_options = orm_context.update_delete_options
+ active_options = orm_context.update_delete_options
else:
- load_options = update_options = active_options = None
+ active_options = None
session = orm_context.session
assert isinstance(session, ShardedSession)
def iter_for_shard(
shard_id: str,
- load_options: Union[
- None,
- QueryContext.default_load_options,
- Type[QueryContext.default_load_options],
- ],
- update_options: Union[
- None,
- BulkUDCompileState.default_update_options,
- Type[BulkUDCompileState.default_update_options],
- ],
) -> Union[Result[_T], IteratorResult[_TP]]:
- execution_options = dict(orm_context.local_execution_options)
bind_arguments = dict(orm_context.bind_arguments)
bind_arguments["shard_id"] = shard_id
- if orm_context.is_select:
- assert load_options is not None
- load_options += {"_refresh_identity_token": shard_id}
- execution_options["_sa_orm_load_options"] = load_options
- elif orm_context.is_update or orm_context.is_delete:
- assert update_options is not None
- update_options += {"_refresh_identity_token": shard_id}
- execution_options["_sa_orm_update_options"] = update_options
-
- return orm_context.invoke_statement(
- bind_arguments=bind_arguments, execution_options=execution_options
- )
+ orm_context.update_execution_options(identity_token=shard_id)
+ return orm_context.invoke_statement(bind_arguments=bind_arguments)
- if active_options and active_options._refresh_identity_token is not None:
- shard_id = active_options._refresh_identity_token
+ if active_options and active_options._identity_token is not None:
+ shard_id = active_options._identity_token
elif "_sa_shard_id" in orm_context.execution_options:
shard_id = orm_context.execution_options["_sa_shard_id"]
elif "shard_id" in orm_context.bind_arguments:
@@ -397,10 +419,10 @@ def execute_and_instances(
shard_id = None
if shard_id is not None:
- return iter_for_shard(shard_id, load_options, update_options)
+ return iter_for_shard(shard_id)
else:
partial = []
for shard_id in session.execute_chooser(orm_context):
- result_ = iter_for_shard(shard_id, load_options, update_options)
+ result_ = iter_for_shard(shard_id)
partial.append(result_)
return partial[0].merge(*partial[1:])