summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/base.py')
-rw-r--r--lib/sqlalchemy/orm/base.py77
1 files changed, 54 insertions, 23 deletions
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
index 3fa855a4b..054d52d83 100644
--- a/lib/sqlalchemy/orm/base.py
+++ b/lib/sqlalchemy/orm/base.py
@@ -26,24 +26,25 @@ from typing import TypeVar
from typing import Union
from . import exc
+from ._typing import insp_is_mapper
from .. import exc as sa_exc
from .. import inspection
from .. import util
from ..sql.elements import SQLCoreOperations
from ..util import FastIntFlag
from ..util.langhelpers import TypingOnly
-from ..util.typing import Concatenate
from ..util.typing import Literal
-from ..util.typing import ParamSpec
from ..util.typing import Self
if typing.TYPE_CHECKING:
from ._typing import _InternalEntityType
from .attributes import InstrumentedAttribute
+ from .instrumentation import ClassManager
from .mapper import Mapper
from .state import InstanceState
from ..sql._typing import _InfoType
+
_T = TypeVar("_T", bound=Any)
_O = TypeVar("_O", bound=object)
@@ -246,21 +247,15 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE")
_RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE")
-_Fn = TypeVar("_Fn", bound=Callable)
-_Args = ParamSpec("_Args")
+_F = TypeVar("_F", bound=Callable)
_Self = TypeVar("_Self")
def _assertions(
*assertions: Any,
-) -> Callable[
- [Callable[Concatenate[_Self, _Fn, _Args], _Self]],
- Callable[Concatenate[_Self, _Fn, _Args], _Self],
-]:
+) -> Callable[[_F], _F]:
@util.decorator
- def generate(
- fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs
- ) -> _Self:
+ def generate(fn: _F, self: _Self, *args: Any, **kw: Any) -> _Self:
for assertion in assertions:
assertion(self, fn.__name__)
fn(self, *args, **kw)
@@ -269,13 +264,13 @@ def _assertions(
return generate
-# these can be replaced by sqlalchemy.ext.instrumentation
-# if augmented class instrumentation is enabled.
-def manager_of_class(cls):
- return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None)
+if TYPE_CHECKING:
+ def manager_of_class(cls: Type[Any]) -> ClassManager:
+ ...
-if TYPE_CHECKING:
+ def opt_manager_of_class(cls: Type[Any]) -> Optional[ClassManager]:
+ ...
def instance_state(instance: _O) -> InstanceState[_O]:
...
@@ -284,6 +279,20 @@ if TYPE_CHECKING:
...
else:
+ # these can be replaced by sqlalchemy.ext.instrumentation
+ # if augmented class instrumentation is enabled.
+
+ def manager_of_class(cls):
+ try:
+ return cls.__dict__[DEFAULT_MANAGER_ATTR]
+ except KeyError as ke:
+ raise exc.UnmappedClassError(
+ cls, f"Can't locate an instrumentation manager for class {cls}"
+ ) from ke
+
+ def opt_manager_of_class(cls):
+ return cls.__dict__.get(DEFAULT_MANAGER_ATTR)
+
instance_state = operator.attrgetter(DEFAULT_STATE_ATTR)
instance_dict = operator.attrgetter("__dict__")
@@ -458,11 +467,12 @@ else:
_state_mapper = util.dottedgetter("manager.mapper")
-@inspection._inspects(type)
-def _inspect_mapped_class(class_, configure=False):
+def _inspect_mapped_class(
+ class_: Type[_O], configure: bool = False
+) -> Optional[Mapper[_O]]:
try:
- class_manager = manager_of_class(class_)
- if not class_manager.is_mapped:
+ class_manager = opt_manager_of_class(class_)
+ if class_manager is None or not class_manager.is_mapped:
return None
mapper = class_manager.mapper
except exc.NO_STATE:
@@ -473,7 +483,28 @@ def _inspect_mapped_class(class_, configure=False):
return mapper
-def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]:
+@inspection._inspects(type)
+def _inspect_mc(class_: Type[_O]) -> Optional[Mapper[_O]]:
+ try:
+ class_manager = opt_manager_of_class(class_)
+ if class_manager is None or not class_manager.is_mapped:
+ return None
+ mapper = class_manager.mapper
+ except exc.NO_STATE:
+ return None
+ else:
+ return mapper
+
+
+def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]:
+ insp = inspection.inspect(arg, raiseerr=False)
+ if insp_is_mapper(insp):
+ return insp
+
+ raise sa_exc.ArgumentError(f"Mapper or mapped class expected, got {arg!r}")
+
+
+def class_mapper(class_: Type[_O], configure: bool = True) -> Mapper[_O]:
"""Given a class, return the primary :class:`_orm.Mapper` associated
with the key.
@@ -502,8 +533,8 @@ def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]:
class InspectionAttr:
- """A base class applied to all ORM objects that can be returned
- by the :func:`_sa.inspect` function.
+ """A base class applied to all ORM objects and attributes that are
+ related to things that can be returned by the :func:`_sa.inspect` function.
The attributes defined here allow the usage of simple boolean
checks to test basic facts about the object returned.