summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/decl_base.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-01-24 17:04:27 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-02-13 14:23:04 -0500
commite545298e35ea9f126054b337e4b5ba01988b29f7 (patch)
treee64aea159111d5921ff01f08b1c4efb667249dfe /lib/sqlalchemy/orm/decl_base.py
parentf1da1623b800cd4de3b71fd1b2ad5ccfde286780 (diff)
downloadsqlalchemy-e545298e35ea9f126054b337e4b5ba01988b29f7.tar.gz
establish mypy / typing approach for v2.0
large patch to get ORM / typing efforts started. this is to support adding new test cases to mypy, support dropping sqlalchemy2-stubs entirely from the test suite, validate major ORM typing reorganization to eliminate the need for the mypy plugin. * New declarative approach which uses annotation introspection, fixes: #7535 * Mapped[] is now at the base of all ORM constructs that find themselves in classes, to support direct typing without plugins * Mypy plugin updated for new typing structures * Mypy test suite broken out into "plugin" tests vs. "plain" tests, and enhanced to better support test structures where we assert that various objects are introspected by the type checker as we expect. as we go forward with typing, we will add new use cases to "plain" where we can assert that types are introspected as we expect. * For typing support, users will be much more exposed to the class names of things. Add these all to "sqlalchemy" import space. * Column(ForeignKey()) no longer needs to be `@declared_attr` if the FK refers to a remote table * composite() attributes mapped to a dataclass no longer need to implement a `__composite_values__()` method * with_variant() accepts multiple dialect names Change-Id: I22797c0be73a8fbbd2d6f5e0c0b7258b17fe145d Fixes: #7535 Fixes: #7551 References: #6810
Diffstat (limited to 'lib/sqlalchemy/orm/decl_base.py')
-rw-r--r--lib/sqlalchemy/orm/decl_base.py288
1 files changed, 204 insertions, 84 deletions
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index fb736806c..342aa772b 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -5,23 +5,34 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Internal implementation for declarative."""
+
+from __future__ import annotations
+
import collections
+from typing import Any
+from typing import Dict
+from typing import Tuple
import weakref
-from sqlalchemy.orm import attributes
-from sqlalchemy.orm import instrumentation
+from . import attributes
from . import clsregistry
from . import exc as orm_exc
+from . import instrumentation
from . import mapperlib
from .attributes import InstrumentedAttribute
from .attributes import QueryableAttribute
from .base import _is_mapped_class
from .base import InspectionAttr
-from .descriptor_props import CompositeProperty
-from .descriptor_props import SynonymProperty
+from .descriptor_props import Composite
+from .descriptor_props import Synonym
+from .interfaces import _IntrospectsAnnotations
+from .interfaces import _MappedAttribute
+from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .mapper import Mapper as mapper
from .properties import ColumnProperty
+from .properties import MappedColumn
+from .util import _is_mapped_annotation
from .util import class_mapper
from .. import event
from .. import exc
@@ -130,7 +141,7 @@ def _mapper(registry, cls, table, mapper_kw):
@util.preload_module("sqlalchemy.orm.decl_api")
-def _is_declarative_props(obj):
+def _is_declarative_props(obj: Any) -> bool:
declared_attr = util.preloaded.orm_decl_api.declared_attr
return isinstance(obj, (declared_attr, util.classproperty))
@@ -208,7 +219,7 @@ class _MapperConfig:
class _ImperativeMapperConfig(_MapperConfig):
- __slots__ = ("dict_", "local_table", "inherits")
+ __slots__ = ("local_table", "inherits")
def __init__(
self,
@@ -221,7 +232,6 @@ class _ImperativeMapperConfig(_MapperConfig):
registry, cls_, mapper_kw
)
- self.dict_ = {}
self.local_table = self.set_cls_attribute("__table__", table)
with mapperlib._CONFIGURE_MUTEX:
@@ -277,7 +287,10 @@ class _ImperativeMapperConfig(_MapperConfig):
class _ClassScanMapperConfig(_MapperConfig):
__slots__ = (
- "dict_",
+ "registry",
+ "clsdict_view",
+ "collected_attributes",
+ "collected_annotations",
"local_table",
"persist_selectable",
"declared_columns",
@@ -299,11 +312,17 @@ class _ClassScanMapperConfig(_MapperConfig):
):
super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw)
-
- self.dict_ = dict(dict_) if dict_ else {}
+ self.registry = registry
self.persist_selectable = None
- self.declared_columns = set()
+
+ self.clsdict_view = (
+ util.immutabledict(dict_) if dict_ else util.EMPTY_DICT
+ )
+ self.collected_attributes = {}
+ self.collected_annotations: Dict[str, Tuple[Any, bool]] = {}
+ self.declared_columns = util.OrderedSet()
self.column_copies = {}
+
self._setup_declared_events()
self._scan_attributes()
@@ -407,6 +426,19 @@ class _ClassScanMapperConfig(_MapperConfig):
return attribute_is_overridden
+ _skip_attrs = frozenset(
+ [
+ "__module__",
+ "__annotations__",
+ "__doc__",
+ "__dict__",
+ "__weakref__",
+ "_sa_class_manager",
+ "__dict__",
+ "__weakref__",
+ ]
+ )
+
def _cls_attr_resolver(self, cls):
"""produce a function to iterate the "attributes" of a class,
adjusting for SQLAlchemy fields embedded in dataclass fields.
@@ -416,31 +448,52 @@ class _ClassScanMapperConfig(_MapperConfig):
cls, "__sa_dataclass_metadata_key__", None
)
+ cls_annotations = util.get_annotations(cls)
+
+ cls_vars = vars(cls)
+
+ skip = self._skip_attrs
+
+ names = util.merge_lists_w_ordering(
+ [n for n in cls_vars if n not in skip], list(cls_annotations)
+ )
if sa_dataclass_metadata_key is None:
def local_attributes_for_class():
- for name, obj in vars(cls).items():
- yield name, obj, False
+ return (
+ (
+ name,
+ cls_vars.get(name),
+ cls_annotations.get(name),
+ False,
+ )
+ for name in names
+ )
else:
- field_names = set()
+ dataclass_fields = {
+ field.name: field for field in util.local_dataclass_fields(cls)
+ }
def local_attributes_for_class():
- for field in util.local_dataclass_fields(cls):
- if sa_dataclass_metadata_key in field.metadata:
- field_names.add(field.name)
+ for name in names:
+ field = dataclass_fields.get(name, None)
+ if field and sa_dataclass_metadata_key in field.metadata:
yield field.name, _as_dc_declaredattr(
field.metadata, sa_dataclass_metadata_key
- ), True
- for name, obj in vars(cls).items():
- if name not in field_names:
- yield name, obj, False
+ ), cls_annotations.get(field.name), True
+ else:
+ yield name, cls_vars.get(name), cls_annotations.get(
+ name
+ ), False
return local_attributes_for_class
def _scan_attributes(self):
cls = self.cls
- dict_ = self.dict_
+
+ clsdict_view = self.clsdict_view
+ collected_attributes = self.collected_attributes
column_copies = self.column_copies
mapper_args_fn = None
table_args = inherited_table_args = None
@@ -462,10 +515,16 @@ class _ClassScanMapperConfig(_MapperConfig):
if not class_mapped and base is not cls:
self._produce_column_copies(
- local_attributes_for_class, attribute_is_overridden
+ local_attributes_for_class,
+ attribute_is_overridden,
)
- for name, obj, is_dataclass in local_attributes_for_class():
+ for (
+ name,
+ obj,
+ annotation,
+ is_dataclass,
+ ) in local_attributes_for_class():
if name == "__mapper_args__":
check_decl = _check_declared_props_nocascade(
obj, name, cls
@@ -514,7 +573,12 @@ class _ClassScanMapperConfig(_MapperConfig):
elif base is not cls:
# we're a mixin, abstract base, or something that is
# acting like that for now.
- if isinstance(obj, Column):
+
+ if isinstance(obj, (Column, MappedColumn)):
+ self.collected_annotations[name] = (
+ annotation,
+ False,
+ )
# already copied columns to the mapped class.
continue
elif isinstance(obj, MapperProperty):
@@ -526,8 +590,12 @@ class _ClassScanMapperConfig(_MapperConfig):
"field() objects, use a lambda:"
)
elif _is_declarative_props(obj):
+ # tried to get overloads to tell this to
+ # pylance, no luck
+ assert obj is not None
+
if obj._cascading:
- if name in dict_:
+ if name in clsdict_view:
# unfortunately, while we can use the user-
# defined attribute here to allow a clean
# override, if there's another
@@ -541,7 +609,7 @@ class _ClassScanMapperConfig(_MapperConfig):
"@declared_attr.cascading; "
"skipping" % (name, cls)
)
- dict_[name] = column_copies[
+ collected_attributes[name] = column_copies[
obj
] = ret = obj.__get__(obj, cls)
setattr(cls, name, ret)
@@ -579,19 +647,36 @@ class _ClassScanMapperConfig(_MapperConfig):
):
ret = ret.descriptor
- dict_[name] = column_copies[obj] = ret
+ collected_attributes[name] = column_copies[
+ obj
+ ] = ret
if (
isinstance(ret, (Column, MapperProperty))
and ret.doc is None
):
ret.doc = obj.__doc__
- # here, the attribute is some other kind of property that
- # we assume is not part of the declarative mapping.
- # however, check for some more common mistakes
+
+ self.collected_annotations[name] = (
+ obj._collect_return_annotation(),
+ False,
+ )
+ elif _is_mapped_annotation(annotation, cls):
+ self.collected_annotations[name] = (
+ annotation,
+ is_dataclass,
+ )
+ if obj is None:
+ collected_attributes[name] = MappedColumn()
+ else:
+ collected_attributes[name] = obj
else:
+ # here, the attribute is some other kind of
+ # property that we assume is not part of the
+ # declarative mapping. however, check for some
+ # more common mistakes
self._warn_for_decl_attributes(base, name, obj)
elif is_dataclass and (
- name not in dict_ or dict_[name] is not obj
+ name not in clsdict_view or clsdict_view[name] is not obj
):
# here, we are definitely looking at the target class
# and not a superclass. this is currently a
@@ -606,7 +691,20 @@ class _ClassScanMapperConfig(_MapperConfig):
if _is_declarative_props(obj):
obj = obj.fget()
- dict_[name] = obj
+ collected_attributes[name] = obj
+ self.collected_annotations[name] = (
+ annotation,
+ True,
+ )
+ else:
+ self.collected_annotations[name] = (
+ annotation,
+ False,
+ )
+ if obj is None and _is_mapped_annotation(annotation, cls):
+ collected_attributes[name] = MappedColumn()
+ elif name in clsdict_view:
+ collected_attributes[name] = obj
if inherited_table_args and not tablename:
table_args = None
@@ -618,46 +716,55 @@ class _ClassScanMapperConfig(_MapperConfig):
def _warn_for_decl_attributes(self, cls, key, c):
if isinstance(c, expression.ColumnClause):
util.warn(
- "Attribute '%s' on class %s appears to be a non-schema "
- "'sqlalchemy.sql.column()' "
+ f"Attribute '{key}' on class {cls} appears to "
+ "be a non-schema 'sqlalchemy.sql.column()' "
"object; this won't be part of the declarative mapping"
- % (key, cls)
)
def _produce_column_copies(
self, attributes_for_class, attribute_is_overridden
):
cls = self.cls
- dict_ = self.dict_
+ dict_ = self.clsdict_view
+ collected_attributes = self.collected_attributes
column_copies = self.column_copies
# copy mixin columns to the mapped class
- for name, obj, is_dataclass in attributes_for_class():
- if isinstance(obj, Column):
+ for name, obj, annotation, is_dataclass in attributes_for_class():
+ if isinstance(obj, (Column, MappedColumn)):
if attribute_is_overridden(name, obj):
# if column has been overridden
# (like by the InstrumentedAttribute of the
# superclass), skip
continue
- elif obj.foreign_keys:
- raise exc.InvalidRequestError(
- "Columns with foreign keys to other columns "
- "must be declared as @declared_attr callables "
- "on declarative mixin classes. For dataclass "
- "field() objects, use a lambda:."
- )
elif name not in dict_ and not (
"__table__" in dict_
and (obj.name or name) in dict_["__table__"].c
):
+ if obj.foreign_keys:
+ for fk in obj.foreign_keys:
+ if (
+ fk._table_column is not None
+ and fk._table_column.table is None
+ ):
+ raise exc.InvalidRequestError(
+ "Columns with foreign keys to "
+ "non-table-bound "
+ "columns must be declared as "
+ "@declared_attr callables "
+ "on declarative mixin classes. "
+ "For dataclass "
+ "field() objects, use a lambda:."
+ )
+
column_copies[obj] = copy_ = obj._copy()
- copy_._creation_order = obj._creation_order
+ collected_attributes[name] = copy_
+
setattr(cls, name, copy_)
- dict_[name] = copy_
def _extract_mappable_attributes(self):
cls = self.cls
- dict_ = self.dict_
+ collected_attributes = self.collected_attributes
our_stuff = self.properties
@@ -665,13 +772,17 @@ class _ClassScanMapperConfig(_MapperConfig):
cls, "_sa_decl_prepare_nocascade", strict=True
)
- for k in list(dict_):
+ for k in list(collected_attributes):
if k in ("__table__", "__tablename__", "__mapper_args__"):
continue
- value = dict_[k]
+ value = collected_attributes[k]
+
if _is_declarative_props(value):
+ # @declared_attr in collected_attributes only occurs here for a
+ # @declared_attr that's directly on the mapped class;
+ # for a mixin, these have already been evaluated
if value._cascading:
util.warn(
"Use of @declared_attr.cascading only applies to "
@@ -689,13 +800,13 @@ class _ClassScanMapperConfig(_MapperConfig):
):
# detect a QueryableAttribute that's already mapped being
# assigned elsewhere in userland, turn into a synonym()
- value = SynonymProperty(value.key)
+ value = Synonym(value.key)
setattr(cls, k, value)
if (
isinstance(value, tuple)
and len(value) == 1
- and isinstance(value[0], (Column, MapperProperty))
+ and isinstance(value[0], (Column, _MappedAttribute))
):
util.warn(
"Ignoring declarative-like tuple value of attribute "
@@ -703,12 +814,12 @@ class _ClassScanMapperConfig(_MapperConfig):
"accidentally placed at the end of the line?" % k
)
continue
- elif not isinstance(value, (Column, MapperProperty)):
+ elif not isinstance(value, (Column, MapperProperty, _MapsColumns)):
# using @declared_attr for some object that
- # isn't Column/MapperProperty; remove from the dict_
+ # isn't Column/MapperProperty; remove from the clsdict_view
# and place the evaluated value onto the class.
if not k.startswith("__"):
- dict_.pop(k)
+ collected_attributes.pop(k)
self._warn_for_decl_attributes(cls, k, value)
if not late_mapped:
setattr(cls, k, value)
@@ -722,27 +833,37 @@ class _ClassScanMapperConfig(_MapperConfig):
"for the MetaData instance when using a "
"declarative base class."
)
+ elif isinstance(value, _IntrospectsAnnotations):
+ annotation, is_dataclass = self.collected_annotations.get(
+ k, (None, None)
+ )
+ value.declarative_scan(
+ self.registry, cls, k, annotation, is_dataclass
+ )
our_stuff[k] = value
def _extract_declared_columns(self):
our_stuff = self.properties
- # set up attributes in the order they were created
- util.sort_dictionary(
- our_stuff, key=lambda key: our_stuff[key]._creation_order
- )
-
# extract columns from the class dict
declared_columns = self.declared_columns
name_to_prop_key = collections.defaultdict(set)
for key, c in list(our_stuff.items()):
- if isinstance(c, (ColumnProperty, CompositeProperty)):
- for col in c.columns:
- if isinstance(col, Column) and col.table is None:
- _undefer_column_name(key, col)
- if not isinstance(c, CompositeProperty):
- name_to_prop_key[col.name].add(key)
- declared_columns.add(col)
+ if isinstance(c, _MapsColumns):
+ for col in c.columns_to_assign:
+ if not isinstance(c, Composite):
+ name_to_prop_key[col.name].add(key)
+ declared_columns.add(col)
+
+ # remove object from the dictionary that will be passed
+ # as mapper(properties={...}) if it is not a MapperProperty
+ # (i.e. this currently means it's a MappedColumn)
+ mp_to_assign = c.mapper_property_to_assign
+ if mp_to_assign:
+ our_stuff[key] = mp_to_assign
+ else:
+ del our_stuff[key]
+
elif isinstance(c, Column):
_undefer_column_name(key, c)
name_to_prop_key[c.name].add(key)
@@ -769,16 +890,12 @@ class _ClassScanMapperConfig(_MapperConfig):
cls = self.cls
tablename = self.tablename
table_args = self.table_args
- dict_ = self.dict_
+ clsdict_view = self.clsdict_view
declared_columns = self.declared_columns
manager = attributes.manager_of_class(cls)
- declared_columns = self.declared_columns = sorted(
- declared_columns, key=lambda c: c._creation_order
- )
-
- if "__table__" not in dict_ and table is None:
+ if "__table__" not in clsdict_view and table is None:
if hasattr(cls, "__table_cls__"):
table_cls = util.unbound_method_to_callable(cls.__table_cls__)
else:
@@ -796,11 +913,11 @@ class _ClassScanMapperConfig(_MapperConfig):
else:
args = table_args
- autoload_with = dict_.get("__autoload_with__")
+ autoload_with = clsdict_view.get("__autoload_with__")
if autoload_with:
table_kw["autoload_with"] = autoload_with
- autoload = dict_.get("__autoload__")
+ autoload = clsdict_view.get("__autoload__")
if autoload:
table_kw["autoload"] = True
@@ -1095,18 +1212,21 @@ def _add_attribute(cls, key, value):
_undefer_column_name(key, value)
cls.__table__.append_column(value, replace_existing=True)
cls.__mapper__.add_property(key, value)
- elif isinstance(value, ColumnProperty):
- for col in value.columns:
- if isinstance(col, Column) and col.table is None:
- _undefer_column_name(key, col)
- cls.__table__.append_column(col, replace_existing=True)
- cls.__mapper__.add_property(key, value)
+ elif isinstance(value, _MapsColumns):
+ mp = value.mapper_property_to_assign
+ for col in value.columns_to_assign:
+ _undefer_column_name(key, col)
+ cls.__table__.append_column(col, replace_existing=True)
+ if not mp:
+ cls.__mapper__.add_property(key, col)
+ if mp:
+ cls.__mapper__.add_property(key, mp)
elif isinstance(value, MapperProperty):
cls.__mapper__.add_property(key, value)
elif isinstance(value, QueryableAttribute) and value.key != key:
# detect a QueryableAttribute that's already mapped being
# assigned elsewhere in userland, turn into a synonym()
- value = SynonymProperty(value.key)
+ value = Synonym(value.key)
cls.__mapper__.add_property(key, value)
else:
type.__setattr__(cls, key, value)
@@ -1124,7 +1244,7 @@ def _del_attribute(cls, key):
):
value = cls.__dict__[key]
if isinstance(
- value, (Column, ColumnProperty, MapperProperty, QueryableAttribute)
+ value, (Column, _MapsColumns, MapperProperty, QueryableAttribute)
):
raise NotImplementedError(
"Can't un-map individual mapped attributes on a mapped class."