diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-01-24 17:04:27 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-02-13 14:23:04 -0500 |
commit | e545298e35ea9f126054b337e4b5ba01988b29f7 (patch) | |
tree | e64aea159111d5921ff01f08b1c4efb667249dfe /lib/sqlalchemy/orm/decl_base.py | |
parent | f1da1623b800cd4de3b71fd1b2ad5ccfde286780 (diff) | |
download | sqlalchemy-e545298e35ea9f126054b337e4b5ba01988b29f7.tar.gz |
establish mypy / typing approach for v2.0
large patch to get ORM / typing efforts started.
this is to support adding new test cases to mypy,
support dropping sqlalchemy2-stubs entirely from the
test suite, validate major ORM typing reorganization
to eliminate the need for the mypy plugin.
* New declarative approach which uses annotation
introspection, fixes: #7535
* Mapped[] is now at the base of all ORM constructs
that find themselves in classes, to support direct
typing without plugins
* Mypy plugin updated for new typing structures
* Mypy test suite broken out into "plugin" tests vs.
"plain" tests, and enhanced to better support test
structures where we assert that various objects are
introspected by the type checker as we expect.
as we go forward with typing, we will
add new use cases to "plain" where we can assert that
types are introspected as we expect.
* For typing support, users will be much more exposed to the
class names of things. Add these all to "sqlalchemy" import
space.
* Column(ForeignKey()) no longer needs to be `@declared_attr`
if the FK refers to a remote table
* composite() attributes mapped to a dataclass no longer
need to implement a `__composite_values__()` method
* with_variant() accepts multiple dialect names
Change-Id: I22797c0be73a8fbbd2d6f5e0c0b7258b17fe145d
Fixes: #7535
Fixes: #7551
References: #6810
Diffstat (limited to 'lib/sqlalchemy/orm/decl_base.py')
-rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 288 |
1 files changed, 204 insertions, 84 deletions
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index fb736806c..342aa772b 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -5,23 +5,34 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php """Internal implementation for declarative.""" + +from __future__ import annotations + import collections +from typing import Any +from typing import Dict +from typing import Tuple import weakref -from sqlalchemy.orm import attributes -from sqlalchemy.orm import instrumentation +from . import attributes from . import clsregistry from . import exc as orm_exc +from . import instrumentation from . import mapperlib from .attributes import InstrumentedAttribute from .attributes import QueryableAttribute from .base import _is_mapped_class from .base import InspectionAttr -from .descriptor_props import CompositeProperty -from .descriptor_props import SynonymProperty +from .descriptor_props import Composite +from .descriptor_props import Synonym +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MappedAttribute +from .interfaces import _MapsColumns from .interfaces import MapperProperty from .mapper import Mapper as mapper from .properties import ColumnProperty +from .properties import MappedColumn +from .util import _is_mapped_annotation from .util import class_mapper from .. import event from .. import exc @@ -130,7 +141,7 @@ def _mapper(registry, cls, table, mapper_kw): @util.preload_module("sqlalchemy.orm.decl_api") -def _is_declarative_props(obj): +def _is_declarative_props(obj: Any) -> bool: declared_attr = util.preloaded.orm_decl_api.declared_attr return isinstance(obj, (declared_attr, util.classproperty)) @@ -208,7 +219,7 @@ class _MapperConfig: class _ImperativeMapperConfig(_MapperConfig): - __slots__ = ("dict_", "local_table", "inherits") + __slots__ = ("local_table", "inherits") def __init__( self, @@ -221,7 +232,6 @@ class _ImperativeMapperConfig(_MapperConfig): registry, cls_, mapper_kw ) - self.dict_ = {} self.local_table = self.set_cls_attribute("__table__", table) with mapperlib._CONFIGURE_MUTEX: @@ -277,7 +287,10 @@ class _ImperativeMapperConfig(_MapperConfig): class _ClassScanMapperConfig(_MapperConfig): __slots__ = ( - "dict_", + "registry", + "clsdict_view", + "collected_attributes", + "collected_annotations", "local_table", "persist_selectable", "declared_columns", @@ -299,11 +312,17 @@ class _ClassScanMapperConfig(_MapperConfig): ): super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw) - - self.dict_ = dict(dict_) if dict_ else {} + self.registry = registry self.persist_selectable = None - self.declared_columns = set() + + self.clsdict_view = ( + util.immutabledict(dict_) if dict_ else util.EMPTY_DICT + ) + self.collected_attributes = {} + self.collected_annotations: Dict[str, Tuple[Any, bool]] = {} + self.declared_columns = util.OrderedSet() self.column_copies = {} + self._setup_declared_events() self._scan_attributes() @@ -407,6 +426,19 @@ class _ClassScanMapperConfig(_MapperConfig): return attribute_is_overridden + _skip_attrs = frozenset( + [ + "__module__", + "__annotations__", + "__doc__", + "__dict__", + "__weakref__", + "_sa_class_manager", + "__dict__", + "__weakref__", + ] + ) + def _cls_attr_resolver(self, cls): """produce a function to iterate the "attributes" of a class, adjusting for SQLAlchemy fields embedded in dataclass fields. @@ -416,31 +448,52 @@ class _ClassScanMapperConfig(_MapperConfig): cls, "__sa_dataclass_metadata_key__", None ) + cls_annotations = util.get_annotations(cls) + + cls_vars = vars(cls) + + skip = self._skip_attrs + + names = util.merge_lists_w_ordering( + [n for n in cls_vars if n not in skip], list(cls_annotations) + ) if sa_dataclass_metadata_key is None: def local_attributes_for_class(): - for name, obj in vars(cls).items(): - yield name, obj, False + return ( + ( + name, + cls_vars.get(name), + cls_annotations.get(name), + False, + ) + for name in names + ) else: - field_names = set() + dataclass_fields = { + field.name: field for field in util.local_dataclass_fields(cls) + } def local_attributes_for_class(): - for field in util.local_dataclass_fields(cls): - if sa_dataclass_metadata_key in field.metadata: - field_names.add(field.name) + for name in names: + field = dataclass_fields.get(name, None) + if field and sa_dataclass_metadata_key in field.metadata: yield field.name, _as_dc_declaredattr( field.metadata, sa_dataclass_metadata_key - ), True - for name, obj in vars(cls).items(): - if name not in field_names: - yield name, obj, False + ), cls_annotations.get(field.name), True + else: + yield name, cls_vars.get(name), cls_annotations.get( + name + ), False return local_attributes_for_class def _scan_attributes(self): cls = self.cls - dict_ = self.dict_ + + clsdict_view = self.clsdict_view + collected_attributes = self.collected_attributes column_copies = self.column_copies mapper_args_fn = None table_args = inherited_table_args = None @@ -462,10 +515,16 @@ class _ClassScanMapperConfig(_MapperConfig): if not class_mapped and base is not cls: self._produce_column_copies( - local_attributes_for_class, attribute_is_overridden + local_attributes_for_class, + attribute_is_overridden, ) - for name, obj, is_dataclass in local_attributes_for_class(): + for ( + name, + obj, + annotation, + is_dataclass, + ) in local_attributes_for_class(): if name == "__mapper_args__": check_decl = _check_declared_props_nocascade( obj, name, cls @@ -514,7 +573,12 @@ class _ClassScanMapperConfig(_MapperConfig): elif base is not cls: # we're a mixin, abstract base, or something that is # acting like that for now. - if isinstance(obj, Column): + + if isinstance(obj, (Column, MappedColumn)): + self.collected_annotations[name] = ( + annotation, + False, + ) # already copied columns to the mapped class. continue elif isinstance(obj, MapperProperty): @@ -526,8 +590,12 @@ class _ClassScanMapperConfig(_MapperConfig): "field() objects, use a lambda:" ) elif _is_declarative_props(obj): + # tried to get overloads to tell this to + # pylance, no luck + assert obj is not None + if obj._cascading: - if name in dict_: + if name in clsdict_view: # unfortunately, while we can use the user- # defined attribute here to allow a clean # override, if there's another @@ -541,7 +609,7 @@ class _ClassScanMapperConfig(_MapperConfig): "@declared_attr.cascading; " "skipping" % (name, cls) ) - dict_[name] = column_copies[ + collected_attributes[name] = column_copies[ obj ] = ret = obj.__get__(obj, cls) setattr(cls, name, ret) @@ -579,19 +647,36 @@ class _ClassScanMapperConfig(_MapperConfig): ): ret = ret.descriptor - dict_[name] = column_copies[obj] = ret + collected_attributes[name] = column_copies[ + obj + ] = ret if ( isinstance(ret, (Column, MapperProperty)) and ret.doc is None ): ret.doc = obj.__doc__ - # here, the attribute is some other kind of property that - # we assume is not part of the declarative mapping. - # however, check for some more common mistakes + + self.collected_annotations[name] = ( + obj._collect_return_annotation(), + False, + ) + elif _is_mapped_annotation(annotation, cls): + self.collected_annotations[name] = ( + annotation, + is_dataclass, + ) + if obj is None: + collected_attributes[name] = MappedColumn() + else: + collected_attributes[name] = obj else: + # here, the attribute is some other kind of + # property that we assume is not part of the + # declarative mapping. however, check for some + # more common mistakes self._warn_for_decl_attributes(base, name, obj) elif is_dataclass and ( - name not in dict_ or dict_[name] is not obj + name not in clsdict_view or clsdict_view[name] is not obj ): # here, we are definitely looking at the target class # and not a superclass. this is currently a @@ -606,7 +691,20 @@ class _ClassScanMapperConfig(_MapperConfig): if _is_declarative_props(obj): obj = obj.fget() - dict_[name] = obj + collected_attributes[name] = obj + self.collected_annotations[name] = ( + annotation, + True, + ) + else: + self.collected_annotations[name] = ( + annotation, + False, + ) + if obj is None and _is_mapped_annotation(annotation, cls): + collected_attributes[name] = MappedColumn() + elif name in clsdict_view: + collected_attributes[name] = obj if inherited_table_args and not tablename: table_args = None @@ -618,46 +716,55 @@ class _ClassScanMapperConfig(_MapperConfig): def _warn_for_decl_attributes(self, cls, key, c): if isinstance(c, expression.ColumnClause): util.warn( - "Attribute '%s' on class %s appears to be a non-schema " - "'sqlalchemy.sql.column()' " + f"Attribute '{key}' on class {cls} appears to " + "be a non-schema 'sqlalchemy.sql.column()' " "object; this won't be part of the declarative mapping" - % (key, cls) ) def _produce_column_copies( self, attributes_for_class, attribute_is_overridden ): cls = self.cls - dict_ = self.dict_ + dict_ = self.clsdict_view + collected_attributes = self.collected_attributes column_copies = self.column_copies # copy mixin columns to the mapped class - for name, obj, is_dataclass in attributes_for_class(): - if isinstance(obj, Column): + for name, obj, annotation, is_dataclass in attributes_for_class(): + if isinstance(obj, (Column, MappedColumn)): if attribute_is_overridden(name, obj): # if column has been overridden # (like by the InstrumentedAttribute of the # superclass), skip continue - elif obj.foreign_keys: - raise exc.InvalidRequestError( - "Columns with foreign keys to other columns " - "must be declared as @declared_attr callables " - "on declarative mixin classes. For dataclass " - "field() objects, use a lambda:." - ) elif name not in dict_ and not ( "__table__" in dict_ and (obj.name or name) in dict_["__table__"].c ): + if obj.foreign_keys: + for fk in obj.foreign_keys: + if ( + fk._table_column is not None + and fk._table_column.table is None + ): + raise exc.InvalidRequestError( + "Columns with foreign keys to " + "non-table-bound " + "columns must be declared as " + "@declared_attr callables " + "on declarative mixin classes. " + "For dataclass " + "field() objects, use a lambda:." + ) + column_copies[obj] = copy_ = obj._copy() - copy_._creation_order = obj._creation_order + collected_attributes[name] = copy_ + setattr(cls, name, copy_) - dict_[name] = copy_ def _extract_mappable_attributes(self): cls = self.cls - dict_ = self.dict_ + collected_attributes = self.collected_attributes our_stuff = self.properties @@ -665,13 +772,17 @@ class _ClassScanMapperConfig(_MapperConfig): cls, "_sa_decl_prepare_nocascade", strict=True ) - for k in list(dict_): + for k in list(collected_attributes): if k in ("__table__", "__tablename__", "__mapper_args__"): continue - value = dict_[k] + value = collected_attributes[k] + if _is_declarative_props(value): + # @declared_attr in collected_attributes only occurs here for a + # @declared_attr that's directly on the mapped class; + # for a mixin, these have already been evaluated if value._cascading: util.warn( "Use of @declared_attr.cascading only applies to " @@ -689,13 +800,13 @@ class _ClassScanMapperConfig(_MapperConfig): ): # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() - value = SynonymProperty(value.key) + value = Synonym(value.key) setattr(cls, k, value) if ( isinstance(value, tuple) and len(value) == 1 - and isinstance(value[0], (Column, MapperProperty)) + and isinstance(value[0], (Column, _MappedAttribute)) ): util.warn( "Ignoring declarative-like tuple value of attribute " @@ -703,12 +814,12 @@ class _ClassScanMapperConfig(_MapperConfig): "accidentally placed at the end of the line?" % k ) continue - elif not isinstance(value, (Column, MapperProperty)): + elif not isinstance(value, (Column, MapperProperty, _MapsColumns)): # using @declared_attr for some object that - # isn't Column/MapperProperty; remove from the dict_ + # isn't Column/MapperProperty; remove from the clsdict_view # and place the evaluated value onto the class. if not k.startswith("__"): - dict_.pop(k) + collected_attributes.pop(k) self._warn_for_decl_attributes(cls, k, value) if not late_mapped: setattr(cls, k, value) @@ -722,27 +833,37 @@ class _ClassScanMapperConfig(_MapperConfig): "for the MetaData instance when using a " "declarative base class." ) + elif isinstance(value, _IntrospectsAnnotations): + annotation, is_dataclass = self.collected_annotations.get( + k, (None, None) + ) + value.declarative_scan( + self.registry, cls, k, annotation, is_dataclass + ) our_stuff[k] = value def _extract_declared_columns(self): our_stuff = self.properties - # set up attributes in the order they were created - util.sort_dictionary( - our_stuff, key=lambda key: our_stuff[key]._creation_order - ) - # extract columns from the class dict declared_columns = self.declared_columns name_to_prop_key = collections.defaultdict(set) for key, c in list(our_stuff.items()): - if isinstance(c, (ColumnProperty, CompositeProperty)): - for col in c.columns: - if isinstance(col, Column) and col.table is None: - _undefer_column_name(key, col) - if not isinstance(c, CompositeProperty): - name_to_prop_key[col.name].add(key) - declared_columns.add(col) + if isinstance(c, _MapsColumns): + for col in c.columns_to_assign: + if not isinstance(c, Composite): + name_to_prop_key[col.name].add(key) + declared_columns.add(col) + + # remove object from the dictionary that will be passed + # as mapper(properties={...}) if it is not a MapperProperty + # (i.e. this currently means it's a MappedColumn) + mp_to_assign = c.mapper_property_to_assign + if mp_to_assign: + our_stuff[key] = mp_to_assign + else: + del our_stuff[key] + elif isinstance(c, Column): _undefer_column_name(key, c) name_to_prop_key[c.name].add(key) @@ -769,16 +890,12 @@ class _ClassScanMapperConfig(_MapperConfig): cls = self.cls tablename = self.tablename table_args = self.table_args - dict_ = self.dict_ + clsdict_view = self.clsdict_view declared_columns = self.declared_columns manager = attributes.manager_of_class(cls) - declared_columns = self.declared_columns = sorted( - declared_columns, key=lambda c: c._creation_order - ) - - if "__table__" not in dict_ and table is None: + if "__table__" not in clsdict_view and table is None: if hasattr(cls, "__table_cls__"): table_cls = util.unbound_method_to_callable(cls.__table_cls__) else: @@ -796,11 +913,11 @@ class _ClassScanMapperConfig(_MapperConfig): else: args = table_args - autoload_with = dict_.get("__autoload_with__") + autoload_with = clsdict_view.get("__autoload_with__") if autoload_with: table_kw["autoload_with"] = autoload_with - autoload = dict_.get("__autoload__") + autoload = clsdict_view.get("__autoload__") if autoload: table_kw["autoload"] = True @@ -1095,18 +1212,21 @@ def _add_attribute(cls, key, value): _undefer_column_name(key, value) cls.__table__.append_column(value, replace_existing=True) cls.__mapper__.add_property(key, value) - elif isinstance(value, ColumnProperty): - for col in value.columns: - if isinstance(col, Column) and col.table is None: - _undefer_column_name(key, col) - cls.__table__.append_column(col, replace_existing=True) - cls.__mapper__.add_property(key, value) + elif isinstance(value, _MapsColumns): + mp = value.mapper_property_to_assign + for col in value.columns_to_assign: + _undefer_column_name(key, col) + cls.__table__.append_column(col, replace_existing=True) + if not mp: + cls.__mapper__.add_property(key, col) + if mp: + cls.__mapper__.add_property(key, mp) elif isinstance(value, MapperProperty): cls.__mapper__.add_property(key, value) elif isinstance(value, QueryableAttribute) and value.key != key: # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() - value = SynonymProperty(value.key) + value = Synonym(value.key) cls.__mapper__.add_property(key, value) else: type.__setattr__(cls, key, value) @@ -1124,7 +1244,7 @@ def _del_attribute(cls, key): ): value = cls.__dict__[key] if isinstance( - value, (Column, ColumnProperty, MapperProperty, QueryableAttribute) + value, (Column, _MapsColumns, MapperProperty, QueryableAttribute) ): raise NotImplementedError( "Can't un-map individual mapped attributes on a mapped class." |