diff options
Diffstat (limited to 'lib/sqlalchemy/orm/base.py')
-rw-r--r-- | lib/sqlalchemy/orm/base.py | 77 |
1 files changed, 54 insertions, 23 deletions
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 3fa855a4b..054d52d83 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -26,24 +26,25 @@ from typing import TypeVar from typing import Union from . import exc +from ._typing import insp_is_mapper from .. import exc as sa_exc from .. import inspection from .. import util from ..sql.elements import SQLCoreOperations from ..util import FastIntFlag from ..util.langhelpers import TypingOnly -from ..util.typing import Concatenate from ..util.typing import Literal -from ..util.typing import ParamSpec from ..util.typing import Self if typing.TYPE_CHECKING: from ._typing import _InternalEntityType from .attributes import InstrumentedAttribute + from .instrumentation import ClassManager from .mapper import Mapper from .state import InstanceState from ..sql._typing import _InfoType + _T = TypeVar("_T", bound=Any) _O = TypeVar("_O", bound=object) @@ -246,21 +247,15 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") _RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE") -_Fn = TypeVar("_Fn", bound=Callable) -_Args = ParamSpec("_Args") +_F = TypeVar("_F", bound=Callable) _Self = TypeVar("_Self") def _assertions( *assertions: Any, -) -> Callable[ - [Callable[Concatenate[_Self, _Fn, _Args], _Self]], - Callable[Concatenate[_Self, _Fn, _Args], _Self], -]: +) -> Callable[[_F], _F]: @util.decorator - def generate( - fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs - ) -> _Self: + def generate(fn: _F, self: _Self, *args: Any, **kw: Any) -> _Self: for assertion in assertions: assertion(self, fn.__name__) fn(self, *args, **kw) @@ -269,13 +264,13 @@ def _assertions( return generate -# these can be replaced by sqlalchemy.ext.instrumentation -# if augmented class instrumentation is enabled. -def manager_of_class(cls): - return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None) +if TYPE_CHECKING: + def manager_of_class(cls: Type[Any]) -> ClassManager: + ... -if TYPE_CHECKING: + def opt_manager_of_class(cls: Type[Any]) -> Optional[ClassManager]: + ... def instance_state(instance: _O) -> InstanceState[_O]: ... @@ -284,6 +279,20 @@ if TYPE_CHECKING: ... else: + # these can be replaced by sqlalchemy.ext.instrumentation + # if augmented class instrumentation is enabled. + + def manager_of_class(cls): + try: + return cls.__dict__[DEFAULT_MANAGER_ATTR] + except KeyError as ke: + raise exc.UnmappedClassError( + cls, f"Can't locate an instrumentation manager for class {cls}" + ) from ke + + def opt_manager_of_class(cls): + return cls.__dict__.get(DEFAULT_MANAGER_ATTR) + instance_state = operator.attrgetter(DEFAULT_STATE_ATTR) instance_dict = operator.attrgetter("__dict__") @@ -458,11 +467,12 @@ else: _state_mapper = util.dottedgetter("manager.mapper") -@inspection._inspects(type) -def _inspect_mapped_class(class_, configure=False): +def _inspect_mapped_class( + class_: Type[_O], configure: bool = False +) -> Optional[Mapper[_O]]: try: - class_manager = manager_of_class(class_) - if not class_manager.is_mapped: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: return None mapper = class_manager.mapper except exc.NO_STATE: @@ -473,7 +483,28 @@ def _inspect_mapped_class(class_, configure=False): return mapper -def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]: +@inspection._inspects(type) +def _inspect_mc(class_: Type[_O]) -> Optional[Mapper[_O]]: + try: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: + return None + mapper = class_manager.mapper + except exc.NO_STATE: + return None + else: + return mapper + + +def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]: + insp = inspection.inspect(arg, raiseerr=False) + if insp_is_mapper(insp): + return insp + + raise sa_exc.ArgumentError(f"Mapper or mapped class expected, got {arg!r}") + + +def class_mapper(class_: Type[_O], configure: bool = True) -> Mapper[_O]: """Given a class, return the primary :class:`_orm.Mapper` associated with the key. @@ -502,8 +533,8 @@ def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]: class InspectionAttr: - """A base class applied to all ORM objects that can be returned - by the :func:`_sa.inspect` function. + """A base class applied to all ORM objects and attributes that are + related to things that can be returned by the :func:`_sa.inspect` function. The attributes defined here allow the usage of simple boolean checks to test basic facts about the object returned. |