summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/decl_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/decl_base.py')
-rw-r--r--lib/sqlalchemy/orm/decl_base.py288
1 files changed, 204 insertions, 84 deletions
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index fb736806c..342aa772b 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -5,23 +5,34 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Internal implementation for declarative."""
+
+from __future__ import annotations
+
import collections
+from typing import Any
+from typing import Dict
+from typing import Tuple
import weakref
-from sqlalchemy.orm import attributes
-from sqlalchemy.orm import instrumentation
+from . import attributes
from . import clsregistry
from . import exc as orm_exc
+from . import instrumentation
from . import mapperlib
from .attributes import InstrumentedAttribute
from .attributes import QueryableAttribute
from .base import _is_mapped_class
from .base import InspectionAttr
-from .descriptor_props import CompositeProperty
-from .descriptor_props import SynonymProperty
+from .descriptor_props import Composite
+from .descriptor_props import Synonym
+from .interfaces import _IntrospectsAnnotations
+from .interfaces import _MappedAttribute
+from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .mapper import Mapper as mapper
from .properties import ColumnProperty
+from .properties import MappedColumn
+from .util import _is_mapped_annotation
from .util import class_mapper
from .. import event
from .. import exc
@@ -130,7 +141,7 @@ def _mapper(registry, cls, table, mapper_kw):
@util.preload_module("sqlalchemy.orm.decl_api")
-def _is_declarative_props(obj):
+def _is_declarative_props(obj: Any) -> bool:
declared_attr = util.preloaded.orm_decl_api.declared_attr
return isinstance(obj, (declared_attr, util.classproperty))
@@ -208,7 +219,7 @@ class _MapperConfig:
class _ImperativeMapperConfig(_MapperConfig):
- __slots__ = ("dict_", "local_table", "inherits")
+ __slots__ = ("local_table", "inherits")
def __init__(
self,
@@ -221,7 +232,6 @@ class _ImperativeMapperConfig(_MapperConfig):
registry, cls_, mapper_kw
)
- self.dict_ = {}
self.local_table = self.set_cls_attribute("__table__", table)
with mapperlib._CONFIGURE_MUTEX:
@@ -277,7 +287,10 @@ class _ImperativeMapperConfig(_MapperConfig):
class _ClassScanMapperConfig(_MapperConfig):
__slots__ = (
- "dict_",
+ "registry",
+ "clsdict_view",
+ "collected_attributes",
+ "collected_annotations",
"local_table",
"persist_selectable",
"declared_columns",
@@ -299,11 +312,17 @@ class _ClassScanMapperConfig(_MapperConfig):
):
super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw)
-
- self.dict_ = dict(dict_) if dict_ else {}
+ self.registry = registry
self.persist_selectable = None
- self.declared_columns = set()
+
+ self.clsdict_view = (
+ util.immutabledict(dict_) if dict_ else util.EMPTY_DICT
+ )
+ self.collected_attributes = {}
+ self.collected_annotations: Dict[str, Tuple[Any, bool]] = {}
+ self.declared_columns = util.OrderedSet()
self.column_copies = {}
+
self._setup_declared_events()
self._scan_attributes()
@@ -407,6 +426,19 @@ class _ClassScanMapperConfig(_MapperConfig):
return attribute_is_overridden
+ _skip_attrs = frozenset(
+ [
+ "__module__",
+ "__annotations__",
+ "__doc__",
+ "__dict__",
+ "__weakref__",
+ "_sa_class_manager",
+ "__dict__",
+ "__weakref__",
+ ]
+ )
+
def _cls_attr_resolver(self, cls):
"""produce a function to iterate the "attributes" of a class,
adjusting for SQLAlchemy fields embedded in dataclass fields.
@@ -416,31 +448,52 @@ class _ClassScanMapperConfig(_MapperConfig):
cls, "__sa_dataclass_metadata_key__", None
)
+ cls_annotations = util.get_annotations(cls)
+
+ cls_vars = vars(cls)
+
+ skip = self._skip_attrs
+
+ names = util.merge_lists_w_ordering(
+ [n for n in cls_vars if n not in skip], list(cls_annotations)
+ )
if sa_dataclass_metadata_key is None:
def local_attributes_for_class():
- for name, obj in vars(cls).items():
- yield name, obj, False
+ return (
+ (
+ name,
+ cls_vars.get(name),
+ cls_annotations.get(name),
+ False,
+ )
+ for name in names
+ )
else:
- field_names = set()
+ dataclass_fields = {
+ field.name: field for field in util.local_dataclass_fields(cls)
+ }
def local_attributes_for_class():
- for field in util.local_dataclass_fields(cls):
- if sa_dataclass_metadata_key in field.metadata:
- field_names.add(field.name)
+ for name in names:
+ field = dataclass_fields.get(name, None)
+ if field and sa_dataclass_metadata_key in field.metadata:
yield field.name, _as_dc_declaredattr(
field.metadata, sa_dataclass_metadata_key
- ), True
- for name, obj in vars(cls).items():
- if name not in field_names:
- yield name, obj, False
+ ), cls_annotations.get(field.name), True
+ else:
+ yield name, cls_vars.get(name), cls_annotations.get(
+ name
+ ), False
return local_attributes_for_class
def _scan_attributes(self):
cls = self.cls
- dict_ = self.dict_
+
+ clsdict_view = self.clsdict_view
+ collected_attributes = self.collected_attributes
column_copies = self.column_copies
mapper_args_fn = None
table_args = inherited_table_args = None
@@ -462,10 +515,16 @@ class _ClassScanMapperConfig(_MapperConfig):
if not class_mapped and base is not cls:
self._produce_column_copies(
- local_attributes_for_class, attribute_is_overridden
+ local_attributes_for_class,
+ attribute_is_overridden,
)
- for name, obj, is_dataclass in local_attributes_for_class():
+ for (
+ name,
+ obj,
+ annotation,
+ is_dataclass,
+ ) in local_attributes_for_class():
if name == "__mapper_args__":
check_decl = _check_declared_props_nocascade(
obj, name, cls
@@ -514,7 +573,12 @@ class _ClassScanMapperConfig(_MapperConfig):
elif base is not cls:
# we're a mixin, abstract base, or something that is
# acting like that for now.
- if isinstance(obj, Column):
+
+ if isinstance(obj, (Column, MappedColumn)):
+ self.collected_annotations[name] = (
+ annotation,
+ False,
+ )
# already copied columns to the mapped class.
continue
elif isinstance(obj, MapperProperty):
@@ -526,8 +590,12 @@ class _ClassScanMapperConfig(_MapperConfig):
"field() objects, use a lambda:"
)
elif _is_declarative_props(obj):
+ # tried to get overloads to tell this to
+ # pylance, no luck
+ assert obj is not None
+
if obj._cascading:
- if name in dict_:
+ if name in clsdict_view:
# unfortunately, while we can use the user-
# defined attribute here to allow a clean
# override, if there's another
@@ -541,7 +609,7 @@ class _ClassScanMapperConfig(_MapperConfig):
"@declared_attr.cascading; "
"skipping" % (name, cls)
)
- dict_[name] = column_copies[
+ collected_attributes[name] = column_copies[
obj
] = ret = obj.__get__(obj, cls)
setattr(cls, name, ret)
@@ -579,19 +647,36 @@ class _ClassScanMapperConfig(_MapperConfig):
):
ret = ret.descriptor
- dict_[name] = column_copies[obj] = ret
+ collected_attributes[name] = column_copies[
+ obj
+ ] = ret
if (
isinstance(ret, (Column, MapperProperty))
and ret.doc is None
):
ret.doc = obj.__doc__
- # here, the attribute is some other kind of property that
- # we assume is not part of the declarative mapping.
- # however, check for some more common mistakes
+
+ self.collected_annotations[name] = (
+ obj._collect_return_annotation(),
+ False,
+ )
+ elif _is_mapped_annotation(annotation, cls):
+ self.collected_annotations[name] = (
+ annotation,
+ is_dataclass,
+ )
+ if obj is None:
+ collected_attributes[name] = MappedColumn()
+ else:
+ collected_attributes[name] = obj
else:
+ # here, the attribute is some other kind of
+ # property that we assume is not part of the
+ # declarative mapping. however, check for some
+ # more common mistakes
self._warn_for_decl_attributes(base, name, obj)
elif is_dataclass and (
- name not in dict_ or dict_[name] is not obj
+ name not in clsdict_view or clsdict_view[name] is not obj
):
# here, we are definitely looking at the target class
# and not a superclass. this is currently a
@@ -606,7 +691,20 @@ class _ClassScanMapperConfig(_MapperConfig):
if _is_declarative_props(obj):
obj = obj.fget()
- dict_[name] = obj
+ collected_attributes[name] = obj
+ self.collected_annotations[name] = (
+ annotation,
+ True,
+ )
+ else:
+ self.collected_annotations[name] = (
+ annotation,
+ False,
+ )
+ if obj is None and _is_mapped_annotation(annotation, cls):
+ collected_attributes[name] = MappedColumn()
+ elif name in clsdict_view:
+ collected_attributes[name] = obj
if inherited_table_args and not tablename:
table_args = None
@@ -618,46 +716,55 @@ class _ClassScanMapperConfig(_MapperConfig):
def _warn_for_decl_attributes(self, cls, key, c):
if isinstance(c, expression.ColumnClause):
util.warn(
- "Attribute '%s' on class %s appears to be a non-schema "
- "'sqlalchemy.sql.column()' "
+ f"Attribute '{key}' on class {cls} appears to "
+ "be a non-schema 'sqlalchemy.sql.column()' "
"object; this won't be part of the declarative mapping"
- % (key, cls)
)
def _produce_column_copies(
self, attributes_for_class, attribute_is_overridden
):
cls = self.cls
- dict_ = self.dict_
+ dict_ = self.clsdict_view
+ collected_attributes = self.collected_attributes
column_copies = self.column_copies
# copy mixin columns to the mapped class
- for name, obj, is_dataclass in attributes_for_class():
- if isinstance(obj, Column):
+ for name, obj, annotation, is_dataclass in attributes_for_class():
+ if isinstance(obj, (Column, MappedColumn)):
if attribute_is_overridden(name, obj):
# if column has been overridden
# (like by the InstrumentedAttribute of the
# superclass), skip
continue
- elif obj.foreign_keys:
- raise exc.InvalidRequestError(
- "Columns with foreign keys to other columns "
- "must be declared as @declared_attr callables "
- "on declarative mixin classes. For dataclass "
- "field() objects, use a lambda:."
- )
elif name not in dict_ and not (
"__table__" in dict_
and (obj.name or name) in dict_["__table__"].c
):
+ if obj.foreign_keys:
+ for fk in obj.foreign_keys:
+ if (
+ fk._table_column is not None
+ and fk._table_column.table is None
+ ):
+ raise exc.InvalidRequestError(
+ "Columns with foreign keys to "
+ "non-table-bound "
+ "columns must be declared as "
+ "@declared_attr callables "
+ "on declarative mixin classes. "
+ "For dataclass "
+ "field() objects, use a lambda:."
+ )
+
column_copies[obj] = copy_ = obj._copy()
- copy_._creation_order = obj._creation_order
+ collected_attributes[name] = copy_
+
setattr(cls, name, copy_)
- dict_[name] = copy_
def _extract_mappable_attributes(self):
cls = self.cls
- dict_ = self.dict_
+ collected_attributes = self.collected_attributes
our_stuff = self.properties
@@ -665,13 +772,17 @@ class _ClassScanMapperConfig(_MapperConfig):
cls, "_sa_decl_prepare_nocascade", strict=True
)
- for k in list(dict_):
+ for k in list(collected_attributes):
if k in ("__table__", "__tablename__", "__mapper_args__"):
continue
- value = dict_[k]
+ value = collected_attributes[k]
+
if _is_declarative_props(value):
+ # @declared_attr in collected_attributes only occurs here for a
+ # @declared_attr that's directly on the mapped class;
+ # for a mixin, these have already been evaluated
if value._cascading:
util.warn(
"Use of @declared_attr.cascading only applies to "
@@ -689,13 +800,13 @@ class _ClassScanMapperConfig(_MapperConfig):
):
# detect a QueryableAttribute that's already mapped being
# assigned elsewhere in userland, turn into a synonym()
- value = SynonymProperty(value.key)
+ value = Synonym(value.key)
setattr(cls, k, value)
if (
isinstance(value, tuple)
and len(value) == 1
- and isinstance(value[0], (Column, MapperProperty))
+ and isinstance(value[0], (Column, _MappedAttribute))
):
util.warn(
"Ignoring declarative-like tuple value of attribute "
@@ -703,12 +814,12 @@ class _ClassScanMapperConfig(_MapperConfig):
"accidentally placed at the end of the line?" % k
)
continue
- elif not isinstance(value, (Column, MapperProperty)):
+ elif not isinstance(value, (Column, MapperProperty, _MapsColumns)):
# using @declared_attr for some object that
- # isn't Column/MapperProperty; remove from the dict_
+ # isn't Column/MapperProperty; remove from the clsdict_view
# and place the evaluated value onto the class.
if not k.startswith("__"):
- dict_.pop(k)
+ collected_attributes.pop(k)
self._warn_for_decl_attributes(cls, k, value)
if not late_mapped:
setattr(cls, k, value)
@@ -722,27 +833,37 @@ class _ClassScanMapperConfig(_MapperConfig):
"for the MetaData instance when using a "
"declarative base class."
)
+ elif isinstance(value, _IntrospectsAnnotations):
+ annotation, is_dataclass = self.collected_annotations.get(
+ k, (None, None)
+ )
+ value.declarative_scan(
+ self.registry, cls, k, annotation, is_dataclass
+ )
our_stuff[k] = value
def _extract_declared_columns(self):
our_stuff = self.properties
- # set up attributes in the order they were created
- util.sort_dictionary(
- our_stuff, key=lambda key: our_stuff[key]._creation_order
- )
-
# extract columns from the class dict
declared_columns = self.declared_columns
name_to_prop_key = collections.defaultdict(set)
for key, c in list(our_stuff.items()):
- if isinstance(c, (ColumnProperty, CompositeProperty)):
- for col in c.columns:
- if isinstance(col, Column) and col.table is None:
- _undefer_column_name(key, col)
- if not isinstance(c, CompositeProperty):
- name_to_prop_key[col.name].add(key)
- declared_columns.add(col)
+ if isinstance(c, _MapsColumns):
+ for col in c.columns_to_assign:
+ if not isinstance(c, Composite):
+ name_to_prop_key[col.name].add(key)
+ declared_columns.add(col)
+
+ # remove object from the dictionary that will be passed
+ # as mapper(properties={...}) if it is not a MapperProperty
+ # (i.e. this currently means it's a MappedColumn)
+ mp_to_assign = c.mapper_property_to_assign
+ if mp_to_assign:
+ our_stuff[key] = mp_to_assign
+ else:
+ del our_stuff[key]
+
elif isinstance(c, Column):
_undefer_column_name(key, c)
name_to_prop_key[c.name].add(key)
@@ -769,16 +890,12 @@ class _ClassScanMapperConfig(_MapperConfig):
cls = self.cls
tablename = self.tablename
table_args = self.table_args
- dict_ = self.dict_
+ clsdict_view = self.clsdict_view
declared_columns = self.declared_columns
manager = attributes.manager_of_class(cls)
- declared_columns = self.declared_columns = sorted(
- declared_columns, key=lambda c: c._creation_order
- )
-
- if "__table__" not in dict_ and table is None:
+ if "__table__" not in clsdict_view and table is None:
if hasattr(cls, "__table_cls__"):
table_cls = util.unbound_method_to_callable(cls.__table_cls__)
else:
@@ -796,11 +913,11 @@ class _ClassScanMapperConfig(_MapperConfig):
else:
args = table_args
- autoload_with = dict_.get("__autoload_with__")
+ autoload_with = clsdict_view.get("__autoload_with__")
if autoload_with:
table_kw["autoload_with"] = autoload_with
- autoload = dict_.get("__autoload__")
+ autoload = clsdict_view.get("__autoload__")
if autoload:
table_kw["autoload"] = True
@@ -1095,18 +1212,21 @@ def _add_attribute(cls, key, value):
_undefer_column_name(key, value)
cls.__table__.append_column(value, replace_existing=True)
cls.__mapper__.add_property(key, value)
- elif isinstance(value, ColumnProperty):
- for col in value.columns:
- if isinstance(col, Column) and col.table is None:
- _undefer_column_name(key, col)
- cls.__table__.append_column(col, replace_existing=True)
- cls.__mapper__.add_property(key, value)
+ elif isinstance(value, _MapsColumns):
+ mp = value.mapper_property_to_assign
+ for col in value.columns_to_assign:
+ _undefer_column_name(key, col)
+ cls.__table__.append_column(col, replace_existing=True)
+ if not mp:
+ cls.__mapper__.add_property(key, col)
+ if mp:
+ cls.__mapper__.add_property(key, mp)
elif isinstance(value, MapperProperty):
cls.__mapper__.add_property(key, value)
elif isinstance(value, QueryableAttribute) and value.key != key:
# detect a QueryableAttribute that's already mapped being
# assigned elsewhere in userland, turn into a synonym()
- value = SynonymProperty(value.key)
+ value = Synonym(value.key)
cls.__mapper__.add_property(key, value)
else:
type.__setattr__(cls, key, value)
@@ -1124,7 +1244,7 @@ def _del_attribute(cls, key):
):
value = cls.__dict__[key]
if isinstance(
- value, (Column, ColumnProperty, MapperProperty, QueryableAttribute)
+ value, (Column, _MapsColumns, MapperProperty, QueryableAttribute)
):
raise NotImplementedError(
"Can't un-map individual mapped attributes on a mapped class."