diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2022-02-13 20:37:12 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-02-13 20:37:12 +0000 |
commit | d6b3c82b0c329730bcaff42b4bb39dba83acb536 (patch) | |
tree | d6b7f744a35c8d89615eeb0504ee7a4193f95642 | |
parent | 260ade78a70d51378de9e7b9456bfe6218859b6c (diff) | |
parent | e545298e35ea9f126054b337e4b5ba01988b29f7 (diff) | |
download | sqlalchemy-d6b3c82b0c329730bcaff42b4bb39dba83acb536.tar.gz |
Merge "establish mypy / typing approach for v2.0" into main
149 files changed, 5657 insertions, 2158 deletions
diff --git a/doc/build/changelog/migration_20.rst b/doc/build/changelog/migration_20.rst index 406f782e5..380651a30 100644 --- a/doc/build/changelog/migration_20.rst +++ b/doc/build/changelog/migration_20.rst @@ -182,7 +182,7 @@ and pylance. Given a program as below:: from sqlalchemy.dialects.mysql import VARCHAR - type_ = String(255).with_variant(VARCHAR(255, charset='utf8mb4'), "mysql") + type_ = String(255).with_variant(VARCHAR(255, charset='utf8mb4'), "mysql", "mariadb") if typing.TYPE_CHECKING: reveal_type(type_) @@ -191,6 +191,9 @@ A type checker like pyright will now report the type as:: info: Type of "type_" is "String" +In addition, as illustrated above, multiple dialect names may be passed for +single type, in particular this is helpful for the pair of ``"mysql"`` and +``"mariadb"`` dialects which are considered separately as of SQLAlchemy 1.4. :ticket:`6980` diff --git a/doc/build/changelog/unreleased_20/6980.rst b/doc/build/changelog/unreleased_20/6980.rst index d83599c48..90cf74044 100644 --- a/doc/build/changelog/unreleased_20/6980.rst +++ b/doc/build/changelog/unreleased_20/6980.rst @@ -10,6 +10,10 @@ behaviors, maintaining the original type allows for clearer type checking and debugging. + :meth:`_sqltypes.TypeEngine.with_variant` also accepts multiple dialect + names per call as well, in particular this is helpful for related + backend names such as ``"mysql", "mariadb"``. + .. seealso:: :ref:`change_6980` diff --git a/doc/build/changelog/unreleased_20/composite_dataclass.rst b/doc/build/changelog/unreleased_20/composite_dataclass.rst new file mode 100644 index 000000000..a7312b0bd --- /dev/null +++ b/doc/build/changelog/unreleased_20/composite_dataclass.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: feature, orm + + The :func:`_orm.composite` mapping construct now supports automatic + resolution of values when used with a Python ``dataclass``; the + ``__composite_values__()`` method no longer needs to be implemented as this + method is derived from inspection of the dataclass. + + See the new documentation at :ref:`mapper_composite` for examples.
\ No newline at end of file diff --git a/doc/build/changelog/unreleased_20/decl_fks.rst b/doc/build/changelog/unreleased_20/decl_fks.rst new file mode 100644 index 000000000..94de46eac --- /dev/null +++ b/doc/build/changelog/unreleased_20/decl_fks.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: feature, orm + + Declarative mixins which use :class:`_schema.Column` objects that contain + :class:`_schema.ForeignKey` references no longer need to use + :func:`_orm.declared_attr` to achieve this mapping; the + :class:`_schema.ForeignKey` object is copied along with the + :class:`_schema.Column` itself when the column is applied to the declared + mapping.
\ No newline at end of file diff --git a/doc/build/changelog/unreleased_20/prop_name.rst b/doc/build/changelog/unreleased_20/prop_name.rst new file mode 100644 index 000000000..d085d0ddc --- /dev/null +++ b/doc/build/changelog/unreleased_20/prop_name.rst @@ -0,0 +1,19 @@ +.. change:: + :tags: change, orm + + To better accommodate explicit typing, the names of some ORM constructs + that are typically constructed internally, but nonetheless are sometimes + visible in messaging as well as typing, have been changed to more succinct + names which also match the name of their constructing function (with + different casing), in all cases maintaining aliases to the old names for + the forseeable future: + + * :class:`_orm.RelationshipProperty` becomes an alias for the primary name + :class:`_orm.Relationship`, which is constructed as always from the + :func:`_orm.relationship` function + * :class:`_orm.SynonymProperty` becomes an alias for the primary name + :class:`_orm.Synonym`, constructed as always from the + :func:`_orm.synonym` function + * :class:`_orm.CompositeProperty` becomes an alias for the primary name + :class:`_orm.Composite`, constructed as always from the + :func:`_orm.composite` function
\ No newline at end of file diff --git a/doc/build/core/selectable.rst b/doc/build/core/selectable.rst index bad7dc809..812f6f99a 100644 --- a/doc/build/core/selectable.rst +++ b/doc/build/core/selectable.rst @@ -156,21 +156,13 @@ Label Style Constants Constants used with the :meth:`_sql.GenerativeSelect.set_label_style` method. -.. autodata:: LABEL_STYLE_DISAMBIGUATE_ONLY +.. autoclass:: SelectLabelStyle + :members: -.. autodata:: LABEL_STYLE_NONE -.. autodata:: LABEL_STYLE_TABLENAME_PLUS_COL +.. seealso:: -.. data:: LABEL_STYLE_DEFAULT + :meth:`_sql.Select.set_label_style` - The default label style, refers to :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`. - - .. versionadded:: 1.4 - - .. seealso:: - - :meth:`_sql.Select.set_label_style` - - :meth:`_sql.Select.get_label_style` + :meth:`_sql.Select.get_label_style` diff --git a/doc/build/orm/composites.rst b/doc/build/orm/composites.rst index 0628f56ae..463bb70bc 100644 --- a/doc/build/orm/composites.rst +++ b/doc/build/orm/composites.rst @@ -5,6 +5,11 @@ Composite Column Types ====================== +.. note:: + + This documentation is not yet updated to illustrate the new + typing-annotation syntax or direct support for dataclasses. + Sets of columns can be associated with a single user-defined datatype. The ORM provides a single attribute which represents the group of columns using the class you provide. diff --git a/doc/build/orm/declarative_mixins.rst b/doc/build/orm/declarative_mixins.rst index 9bb4c782e..e78b96698 100644 --- a/doc/build/orm/declarative_mixins.rst +++ b/doc/build/orm/declarative_mixins.rst @@ -125,47 +125,61 @@ for each separate destination class. To accomplish this, the declarative extension creates a **copy** of each :class:`_schema.Column` object encountered on a class that is detected as a mixin. -This copy mechanism is limited to simple columns that have no foreign -keys, as a :class:`_schema.ForeignKey` itself contains references to columns -which can't be properly recreated at this level. For columns that -have foreign keys, as well as for the variety of mapper-level constructs -that require destination-explicit context, the -:class:`_orm.declared_attr` decorator is provided so that -patterns common to many classes can be defined as callables:: +This copy mechanism is limited to :class:`_schema.Column` and +:class:`_orm.MappedColumn` constructs. For :class:`_schema.Column` and +:class:`_orm.MappedColumn` constructs that contain references to +:class:`_schema.ForeignKey` constructs, the copy mechanism is limited to +foreign key references to remote tables only. + +.. versionchanged:: 2.0 The declarative API can now accommodate + :class:`_schema.Column` objects which refer to :class:`_schema.ForeignKey` + constraints to remote tables without the need to use the + :class:`_orm.declared_attr` function decorator. + +For the variety of mapper-level constructs that require destination-explicit +context, including self-referential foreign keys and constructs like +:func:`_orm.deferred`, :func:`_orm.relationship`, etc, the +:class:`_orm.declared_attr` decorator is provided so that patterns common to +many classes can be defined as callables:: from sqlalchemy.orm import declared_attr @declarative_mixin - class ReferenceAddressMixin: + class HasRelatedDataMixin: @declared_attr - def address_id(cls): - return Column(Integer, ForeignKey('address.id')) + def related_data(cls): + return deferred(Column(Text()) - class User(ReferenceAddressMixin, Base): + class User(HasRelatedDataMixin, Base): __tablename__ = 'user' id = Column(Integer, primary_key=True) -Where above, the ``address_id`` class-level callable is executed at the +Where above, the ``related_data`` class-level callable is executed at the point at which the ``User`` class is constructed, and the declarative -extension can use the resulting :class:`_schema.Column` object as returned by +extension can use the resulting :func`_orm.deferred` object as returned by the method without the need to copy it. -Columns generated by :class:`_orm.declared_attr` can also be -referenced by ``__mapper_args__`` to a limited degree, currently -by ``polymorphic_on`` and ``version_id_col``; the declarative extension -will resolve them at class construction time:: +For a self-referential foreign key on a mixin, the referenced +:class:`_schema.Column` object may be referenced in terms of the class directly +within the :class:`_orm.declared_attr`:: - @declarative_mixin - class MyMixin: - @declared_attr - def type_(cls): - return Column(String(50)) + class SelfReferentialMixin: + id = Column(Integer, primary_key=True) - __mapper_args__= {'polymorphic_on':type_} + @declared_attr + def parent_id(cls): + return Column(Integer, ForeignKey(cls.id)) + + class A(SelfReferentialMixin, Base): + __tablename__ = 'a' - class MyModel(MyMixin, Base): - __tablename__='test' - id = Column(Integer, primary_key=True) + + class B(SelfReferentialMixin, Base): + __tablename__ = 'b' + +Above, both classes ``A`` and ``B`` will contain columns ``id`` and +``parent_id``, where ``parent_id`` refers to the ``id`` column local to the +corresponding table ('a' or 'b'). .. _orm_declarative_mixins_relationships: @@ -182,9 +196,7 @@ reference a common target class via many-to-one:: @declarative_mixin class RefTargetMixin: - @declared_attr - def target_id(cls): - return Column('target_id', ForeignKey('target.id')) + target_id = Column('target_id', ForeignKey('target.id')) @declared_attr def target(cls): diff --git a/doc/build/orm/internals.rst b/doc/build/orm/internals.rst index 8520fd07c..05cf83b39 100644 --- a/doc/build/orm/internals.rst +++ b/doc/build/orm/internals.rst @@ -32,9 +32,10 @@ sections, are listed here. :ref:`maptojoin` - usage example -.. autoclass:: CompositeProperty +.. autoclass:: Composite :members: +.. autodata:: CompositeProperty .. autoclass:: AttributeEvent :members: @@ -62,6 +63,8 @@ sections, are listed here. .. autoclass:: Mapped +.. autoclass:: MappedColumn + .. autoclass:: MapperProperty :members: @@ -98,14 +101,18 @@ sections, are listed here. :members: :inherited-members: -.. autoclass:: RelationshipProperty +.. autoclass:: Relationship :members: :inherited-members: -.. autoclass:: SynonymProperty +.. autodata:: RelationshipProperty + +.. autoclass:: Synonym :members: :inherited-members: +.. autodata:: SynonymProperty + .. autoclass:: QueryContext :members: diff --git a/doc/build/orm/loading_relationships.rst b/doc/build/orm/loading_relationships.rst index 2b93bc84a..773409f02 100644 --- a/doc/build/orm/loading_relationships.rst +++ b/doc/build/orm/loading_relationships.rst @@ -1261,8 +1261,6 @@ Relationship Loader API .. autofunction:: defaultload -.. autofunction:: eagerload - .. autofunction:: immediateload .. autofunction:: joinedload diff --git a/doc/build/orm/relationship_api.rst b/doc/build/orm/relationship_api.rst index 2766c4020..ac584627f 100644 --- a/doc/build/orm/relationship_api.rst +++ b/doc/build/orm/relationship_api.rst @@ -7,8 +7,6 @@ Relationships API .. autofunction:: backref -.. autofunction:: relation - .. autofunction:: dynamic_loader .. autofunction:: foreign diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index c8ec1d825..eadb427d0 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -6,10 +6,56 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php from . import util as _util +from .engine import AdaptedConnection as AdaptedConnection +from .engine import BaseCursorResult as BaseCursorResult +from .engine import BaseRow as BaseRow +from .engine import BindTyping as BindTyping +from .engine import BufferedColumnResultProxy as BufferedColumnResultProxy +from .engine import BufferedColumnRow as BufferedColumnRow +from .engine import BufferedRowResultProxy as BufferedRowResultProxy +from .engine import ChunkedIteratorResult as ChunkedIteratorResult +from .engine import Compiled as Compiled +from .engine import Connection as Connection from .engine import create_engine as create_engine from .engine import create_mock_engine as create_mock_engine +from .engine import CreateEnginePlugin as CreateEnginePlugin +from .engine import CursorResult as CursorResult +from .engine import Dialect as Dialect +from .engine import Engine as Engine from .engine import engine_from_config as engine_from_config +from .engine import ExceptionContext as ExceptionContext +from .engine import ExecutionContext as ExecutionContext +from .engine import FrozenResult as FrozenResult +from .engine import FullyBufferedResultProxy as FullyBufferedResultProxy +from .engine import Inspector as Inspector +from .engine import IteratorResult as IteratorResult +from .engine import make_url as make_url +from .engine import MappingResult as MappingResult +from .engine import MergedResult as MergedResult +from .engine import NestedTransaction as NestedTransaction +from .engine import Result as Result +from .engine import result_tuple as result_tuple +from .engine import ResultProxy as ResultProxy +from .engine import RootTransaction as RootTransaction +from .engine import Row as Row +from .engine import RowMapping as RowMapping +from .engine import ScalarResult as ScalarResult +from .engine import Transaction as Transaction +from .engine import TwoPhaseTransaction as TwoPhaseTransaction +from .engine import TypeCompiler as TypeCompiler +from .engine import URL as URL from .inspection import inspect as inspect +from .pool import AssertionPool as AssertionPool +from .pool import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool +from .pool import ( + FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool, +) +from .pool import NullPool as NullPool +from .pool import Pool as Pool +from .pool import PoolProxiedConnection as PoolProxiedConnection +from .pool import QueuePool as QueuePool +from .pool import SingletonThreadPool as SingleonThreadPool +from .pool import StaticPool as StaticPool from .schema import BLANK_SCHEMA as BLANK_SCHEMA from .schema import CheckConstraint as CheckConstraint from .schema import Column as Column @@ -28,67 +74,139 @@ from .schema import PrimaryKeyConstraint as PrimaryKeyConstraint from .schema import Sequence as Sequence from .schema import Table as Table from .schema import UniqueConstraint as UniqueConstraint -from .sql import alias as alias -from .sql import all_ as all_ -from .sql import and_ as and_ -from .sql import any_ as any_ -from .sql import asc as asc -from .sql import between as between -from .sql import bindparam as bindparam -from .sql import case as case -from .sql import cast as cast -from .sql import collate as collate -from .sql import column as column -from .sql import delete as delete -from .sql import desc as desc -from .sql import distinct as distinct -from .sql import except_ as except_ -from .sql import except_all as except_all -from .sql import exists as exists -from .sql import extract as extract -from .sql import false as false -from .sql import func as func -from .sql import funcfilter as funcfilter -from .sql import insert as insert -from .sql import intersect as intersect -from .sql import intersect_all as intersect_all -from .sql import join as join -from .sql import label as label -from .sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT -from .sql import ( +from .sql import SelectLabelStyle as SelectLabelStyle +from .sql.expression import Alias as Alias +from .sql.expression import alias as alias +from .sql.expression import AliasedReturnsRows as AliasedReturnsRows +from .sql.expression import all_ as all_ +from .sql.expression import and_ as and_ +from .sql.expression import any_ as any_ +from .sql.expression import asc as asc +from .sql.expression import between as between +from .sql.expression import BinaryExpression as BinaryExpression +from .sql.expression import bindparam as bindparam +from .sql.expression import BindParameter as BindParameter +from .sql.expression import BooleanClauseList as BooleanClauseList +from .sql.expression import CacheKey as CacheKey +from .sql.expression import Case as Case +from .sql.expression import case as case +from .sql.expression import Cast as Cast +from .sql.expression import cast as cast +from .sql.expression import ClauseElement as ClauseElement +from .sql.expression import ClauseList as ClauseList +from .sql.expression import collate as collate +from .sql.expression import CollectionAggregate as CollectionAggregate +from .sql.expression import column as column +from .sql.expression import ColumnClause as ColumnClause +from .sql.expression import ColumnCollection as ColumnCollection +from .sql.expression import ColumnElement as ColumnElement +from .sql.expression import ColumnOperators as ColumnOperators +from .sql.expression import CompoundSelect as CompoundSelect +from .sql.expression import CTE as CTE +from .sql.expression import cte as cte +from .sql.expression import custom_op as custom_op +from .sql.expression import Delete as Delete +from .sql.expression import delete as delete +from .sql.expression import desc as desc +from .sql.expression import distinct as distinct +from .sql.expression import except_ as except_ +from .sql.expression import except_all as except_all +from .sql.expression import Executable as Executable +from .sql.expression import Exists as Exists +from .sql.expression import exists as exists +from .sql.expression import Extract as Extract +from .sql.expression import extract as extract +from .sql.expression import false as false +from .sql.expression import False_ as False_ +from .sql.expression import FromClause as FromClause +from .sql.expression import FromGrouping as FromGrouping +from .sql.expression import func as func +from .sql.expression import funcfilter as funcfilter +from .sql.expression import Function as Function +from .sql.expression import FunctionElement as FunctionElement +from .sql.expression import FunctionFilter as FunctionFilter +from .sql.expression import GenerativeSelect as GenerativeSelect +from .sql.expression import Grouping as Grouping +from .sql.expression import HasCTE as HasCTE +from .sql.expression import HasPrefixes as HasPrefixes +from .sql.expression import HasSuffixes as HasSuffixes +from .sql.expression import Insert as Insert +from .sql.expression import insert as insert +from .sql.expression import intersect as intersect +from .sql.expression import intersect_all as intersect_all +from .sql.expression import Join as Join +from .sql.expression import join as join +from .sql.expression import Label as Label +from .sql.expression import label as label +from .sql.expression import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT +from .sql.expression import ( LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY, ) -from .sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE -from .sql import ( +from .sql.expression import LABEL_STYLE_NONE as LABEL_STYLE_NONE +from .sql.expression import ( LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, ) -from .sql import lambda_stmt as lambda_stmt -from .sql import lateral as lateral -from .sql import literal as literal -from .sql import literal_column as literal_column -from .sql import modifier as modifier -from .sql import not_ as not_ -from .sql import null as null -from .sql import nulls_first as nulls_first -from .sql import nulls_last as nulls_last -from .sql import nullsfirst as nullsfirst -from .sql import nullslast as nullslast -from .sql import or_ as or_ -from .sql import outerjoin as outerjoin -from .sql import outparam as outparam -from .sql import over as over -from .sql import select as select -from .sql import table as table -from .sql import tablesample as tablesample -from .sql import text as text -from .sql import true as true -from .sql import tuple_ as tuple_ -from .sql import type_coerce as type_coerce -from .sql import union as union -from .sql import union_all as union_all -from .sql import update as update -from .sql import values as values -from .sql import within_group as within_group +from .sql.expression import lambda_stmt as lambda_stmt +from .sql.expression import LambdaElement as LambdaElement +from .sql.expression import Lateral as Lateral +from .sql.expression import lateral as lateral +from .sql.expression import literal as literal +from .sql.expression import literal_column as literal_column +from .sql.expression import modifier as modifier +from .sql.expression import not_ as not_ +from .sql.expression import Null as Null +from .sql.expression import null as null +from .sql.expression import nulls_first as nulls_first +from .sql.expression import nulls_last as nulls_last +from .sql.expression import Operators as Operators +from .sql.expression import or_ as or_ +from .sql.expression import outerjoin as outerjoin +from .sql.expression import outparam as outparam +from .sql.expression import Over as Over +from .sql.expression import over as over +from .sql.expression import quoted_name as quoted_name +from .sql.expression import ReleaseSavepointClause as ReleaseSavepointClause +from .sql.expression import ReturnsRows as ReturnsRows +from .sql.expression import ( + RollbackToSavepointClause as RollbackToSavepointClause, +) +from .sql.expression import SavepointClause as SavepointClause +from .sql.expression import ScalarSelect as ScalarSelect +from .sql.expression import Select as Select +from .sql.expression import select as select +from .sql.expression import Selectable as Selectable +from .sql.expression import SelectBase as SelectBase +from .sql.expression import StatementLambdaElement as StatementLambdaElement +from .sql.expression import Subquery as Subquery +from .sql.expression import table as table +from .sql.expression import TableClause as TableClause +from .sql.expression import TableSample as TableSample +from .sql.expression import tablesample as tablesample +from .sql.expression import TableValuedAlias as TableValuedAlias +from .sql.expression import text as text +from .sql.expression import TextAsFrom as TextAsFrom +from .sql.expression import TextClause as TextClause +from .sql.expression import TextualSelect as TextualSelect +from .sql.expression import true as true +from .sql.expression import True_ as True_ +from .sql.expression import Tuple as Tuple +from .sql.expression import tuple_ as tuple_ +from .sql.expression import type_coerce as type_coerce +from .sql.expression import TypeClause as TypeClause +from .sql.expression import TypeCoerce as TypeCoerce +from .sql.expression import typing as typing +from .sql.expression import UnaryExpression as UnaryExpression +from .sql.expression import union as union +from .sql.expression import union_all as union_all +from .sql.expression import Update as Update +from .sql.expression import update as update +from .sql.expression import UpdateBase as UpdateBase +from .sql.expression import Values as Values +from .sql.expression import values as values +from .sql.expression import ValuesBase as ValuesBase +from .sql.expression import Visitable as Visitable +from .sql.expression import within_group as within_group +from .sql.expression import WithinGroup as WithinGroup from .types import ARRAY as ARRAY from .types import BIGINT as BIGINT from .types import BigInteger as BigInteger @@ -133,7 +251,6 @@ from .types import UnicodeText as UnicodeText from .types import VARBINARY as VARBINARY from .types import VARCHAR as VARCHAR - __version__ = "2.0.0b1" diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index e934f9f89..c6bc4b6aa 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -15,45 +15,45 @@ constructor ``create_engine()``. """ -from . import events -from . import util -from .base import Connection -from .base import Engine -from .base import NestedTransaction -from .base import RootTransaction -from .base import Transaction -from .base import TwoPhaseTransaction -from .create import create_engine -from .create import engine_from_config -from .cursor import BaseCursorResult -from .cursor import BufferedColumnResultProxy -from .cursor import BufferedColumnRow -from .cursor import BufferedRowResultProxy -from .cursor import CursorResult -from .cursor import FullyBufferedResultProxy -from .cursor import ResultProxy -from .interfaces import AdaptedConnection -from .interfaces import BindTyping -from .interfaces import Compiled -from .interfaces import CreateEnginePlugin -from .interfaces import Dialect -from .interfaces import ExceptionContext -from .interfaces import ExecutionContext -from .interfaces import TypeCompiler -from .mock import create_mock_engine -from .reflection import Inspector -from .result import ChunkedIteratorResult -from .result import FrozenResult -from .result import IteratorResult -from .result import MappingResult -from .result import MergedResult -from .result import Result -from .result import result_tuple -from .result import ScalarResult -from .row import BaseRow -from .row import Row -from .row import RowMapping -from .url import make_url -from .url import URL -from .util import connection_memoize -from ..sql import ddl +from . import events as events +from . import util as util +from .base import Connection as Connection +from .base import Engine as Engine +from .base import NestedTransaction as NestedTransaction +from .base import RootTransaction as RootTransaction +from .base import Transaction as Transaction +from .base import TwoPhaseTransaction as TwoPhaseTransaction +from .create import create_engine as create_engine +from .create import engine_from_config as engine_from_config +from .cursor import BaseCursorResult as BaseCursorResult +from .cursor import BufferedColumnResultProxy as BufferedColumnResultProxy +from .cursor import BufferedColumnRow as BufferedColumnRow +from .cursor import BufferedRowResultProxy as BufferedRowResultProxy +from .cursor import CursorResult as CursorResult +from .cursor import FullyBufferedResultProxy as FullyBufferedResultProxy +from .cursor import ResultProxy as ResultProxy +from .interfaces import AdaptedConnection as AdaptedConnection +from .interfaces import BindTyping as BindTyping +from .interfaces import Compiled as Compiled +from .interfaces import CreateEnginePlugin as CreateEnginePlugin +from .interfaces import Dialect as Dialect +from .interfaces import ExceptionContext as ExceptionContext +from .interfaces import ExecutionContext as ExecutionContext +from .interfaces import TypeCompiler as TypeCompiler +from .mock import create_mock_engine as create_mock_engine +from .reflection import Inspector as Inspector +from .result import ChunkedIteratorResult as ChunkedIteratorResult +from .result import FrozenResult as FrozenResult +from .result import IteratorResult as IteratorResult +from .result import MappingResult as MappingResult +from .result import MergedResult as MergedResult +from .result import Result as Result +from .result import result_tuple as result_tuple +from .result import ScalarResult as ScalarResult +from .row import BaseRow as BaseRow +from .row import Row as Row +from .row import RowMapping as RowMapping +from .url import make_url as make_url +from .url import URL as URL +from .util import connection_memoize as connection_memoize +from ..sql import ddl as ddl diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index 6fb827989..2f8ce17df 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -6,6 +6,7 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php from typing import Any +from typing import Union from . import base from . import url as _url @@ -41,7 +42,7 @@ from ..sql import compiler "is deprecated and will be removed in a future release. ", ), ) -def create_engine(url: "_url.URL", **kwargs: Any) -> "base.Engine": +def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": """Create a new :class:`_engine.Engine` instance. The standard calling form is to send the :ref:`URL <database_urls>` as the diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index df7a53ab7..882392e9c 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -24,11 +24,13 @@ methods such as get_table_names, get_columns, etc. use the key 'name'. So for most return values, each record will have a 'name' attribute.. """ - import contextlib +from typing import List +from typing import Optional from .base import Connection from .base import Engine +from .interfaces import ReflectedColumn from .. import exc from .. import inspection from .. import sql @@ -433,7 +435,9 @@ class Inspector(inspection.Inspectable["Inspector"]): conn, view_name, schema, info_cache=self.info_cache ) - def get_columns(self, table_name, schema=None, **kw): + def get_columns( + self, table_name: str, schema: Optional[str] = None, **kw + ) -> List[ReflectedColumn]: """Return information about columns in `table_name`. Given a string `table_name` and an optional string `schema`, return diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index e6a826c64..d5119907e 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -361,7 +361,7 @@ class AssociationProxyInstance: prop = orm.class_mapper(owning_class).get_property(target_collection) # this was never asserted before but this should be made clear. - if not isinstance(prop, orm.RelationshipProperty): + if not isinstance(prop, orm.Relationship): raise NotImplementedError( "association proxy to a non-relationship " "intermediary is not supported" @@ -717,8 +717,8 @@ class AssociationProxyInstance: """Produce a proxied 'any' expression using EXISTS. This expression will be a composed product - using the :meth:`.RelationshipProperty.Comparator.any` - and/or :meth:`.RelationshipProperty.Comparator.has` + using the :meth:`.Relationship.Comparator.any` + and/or :meth:`.Relationship.Comparator.has` operators of the underlying proxied attributes. """ @@ -737,8 +737,8 @@ class AssociationProxyInstance: """Produce a proxied 'has' expression using EXISTS. This expression will be a composed product - using the :meth:`.RelationshipProperty.Comparator.any` - and/or :meth:`.RelationshipProperty.Comparator.has` + using the :meth:`.Relationship.Comparator.any` + and/or :meth:`.Relationship.Comparator.has` operators of the underlying proxied attributes. """ @@ -859,9 +859,9 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): """Produce a proxied 'contains' expression using EXISTS. This expression will be a composed product - using the :meth:`.RelationshipProperty.Comparator.any`, - :meth:`.RelationshipProperty.Comparator.has`, - and/or :meth:`.RelationshipProperty.Comparator.contains` + using the :meth:`.Relationship.Comparator.any`, + :meth:`.Relationship.Comparator.has`, + and/or :meth:`.Relationship.Comparator.contains` operators of the underlying proxied attributes. """ diff --git a/lib/sqlalchemy/ext/declarative/extensions.py b/lib/sqlalchemy/ext/declarative/extensions.py index 5aff4dfe2..470ff6ad8 100644 --- a/lib/sqlalchemy/ext/declarative/extensions.py +++ b/lib/sqlalchemy/ext/declarative/extensions.py @@ -378,7 +378,7 @@ class DeferredReflection: metadata = mapper.class_.metadata for rel in mapper._props.values(): if ( - isinstance(rel, relationships.RelationshipProperty) + isinstance(rel, relationships.Relationship) and rel.secondary is not None ): if isinstance(rel.secondary, Table): diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py index 99be194cd..4e244b5b9 100644 --- a/lib/sqlalchemy/ext/mypy/apply.py +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -36,6 +36,7 @@ from mypy.types import UnionType from . import infer from . import util +from .names import expr_to_mapped_constructor from .names import NAMED_TYPE_SQLA_MAPPED @@ -117,6 +118,7 @@ def re_apply_declarative_assignments( ): left_node = stmt.lvalues[0].node + python_type_for_type = mapped_attr_lookup[ stmt.lvalues[0].name ].type @@ -142,7 +144,7 @@ def re_apply_declarative_assignments( ) ): - python_type_for_type = ( + new_python_type_for_type = ( infer.infer_type_from_right_hand_nameexpr( api, stmt, @@ -152,19 +154,27 @@ def re_apply_declarative_assignments( ) ) - if python_type_for_type is None or isinstance( - python_type_for_type, UnboundType + if new_python_type_for_type is not None and not isinstance( + new_python_type_for_type, UnboundType ): - continue + python_type_for_type = new_python_type_for_type - # update the SQLAlchemyAttribute with the better information - mapped_attr_lookup[ - stmt.lvalues[0].name - ].type = python_type_for_type + # update the SQLAlchemyAttribute with the better + # information + mapped_attr_lookup[ + stmt.lvalues[0].name + ].type = python_type_for_type - update_cls_metadata = True + update_cls_metadata = True - if python_type_for_type is not None: + # for some reason if you have a Mapped type explicitly annotated, + # and here you set it again, mypy forgets how to do descriptors. + # no idea. 100% feeling around in the dark to see what sticks + if ( + not isinstance(left_node.type, Instance) + or left_node.type.type.fullname != NAMED_TYPE_SQLA_MAPPED + ): + assert python_type_for_type is not None left_node.type = api.named_type( NAMED_TYPE_SQLA_MAPPED, [python_type_for_type] ) @@ -202,6 +212,7 @@ def apply_type_to_mapped_statement( assert isinstance(left_node, Var) if left_hand_explicit_type is not None: + lvalue.is_inferred_def = False left_node.type = api.named_type( NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type] ) @@ -224,7 +235,7 @@ def apply_type_to_mapped_statement( # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>) # the original right-hand side is maintained so it gets type checked # internally - stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue) + stmt.rvalue = expr_to_mapped_constructor(stmt.rvalue) def add_additional_orm_attributes( diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py index c33c30e25..bd6c6f41e 100644 --- a/lib/sqlalchemy/ext/mypy/decl_class.py +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -337,7 +337,7 @@ def _scan_declarative_decorator_stmt( # <attr> : Mapped[<typ>] = # _sa_Mapped._empty_constructor(lambda: <function body>) # the function body is maintained so it gets type checked internally - rvalue = util.expr_to_mapped_constructor( + rvalue = names.expr_to_mapped_constructor( LambdaExpr(stmt.func.arguments, stmt.func.body) ) diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py index 3cd946e04..6a5e99e48 100644 --- a/lib/sqlalchemy/ext/mypy/infer.py +++ b/lib/sqlalchemy/ext/mypy/infer.py @@ -42,11 +42,13 @@ def infer_type_from_right_hand_nameexpr( left_hand_explicit_type: Optional[ProperType], infer_from_right_side: RefExpr, ) -> Optional[ProperType]: - type_id = names.type_id_for_callee(infer_from_right_side) - if type_id is None: return None + elif type_id is names.MAPPED: + python_type_for_type = _infer_type_from_mapped( + api, stmt, node, left_hand_explicit_type, infer_from_right_side + ) elif type_id is names.COLUMN: python_type_for_type = _infer_type_from_decl_column( api, stmt, node, left_hand_explicit_type @@ -245,7 +247,7 @@ def _infer_type_from_decl_composite_property( node: Var, left_hand_explicit_type: Optional[ProperType], ) -> Optional[ProperType]: - """Infer the type of mapping from a CompositeProperty.""" + """Infer the type of mapping from a Composite.""" assert isinstance(stmt.rvalue, CallExpr) target_cls_arg = stmt.rvalue.args[0] @@ -271,6 +273,38 @@ def _infer_type_from_decl_composite_property( return python_type_for_type +def _infer_type_from_mapped( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + infer_from_right_side: RefExpr, +) -> Optional[ProperType]: + """Infer the type of mapping from a right side expression + that returns Mapped. + + + """ + assert isinstance(stmt.rvalue, CallExpr) + + # (Pdb) print(stmt.rvalue.callee) + # NameExpr(query_expression [sqlalchemy.orm._orm_constructors.query_expression]) # noqa: E501 + # (Pdb) stmt.rvalue.callee.node + # <mypy.nodes.FuncDef object at 0x7f8d92fb5940> + # (Pdb) stmt.rvalue.callee.node.type + # def [_T] (default_expr: sqlalchemy.sql.elements.ColumnElement[_T`-1] =) -> sqlalchemy.orm.base.Mapped[_T`-1] # noqa: E501 + # sqlalchemy.orm.base.Mapped[_T`-1] + # the_mapped_type = stmt.rvalue.callee.node.type.ret_type + + # TODO: look at generic ref and either use that, + # or reconcile w/ what's present, etc. + the_mapped_type = util.type_for_callee(infer_from_right_side) # noqa + + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + def _infer_type_from_decl_column_property( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py index b6f911979..ad4449e5b 100644 --- a/lib/sqlalchemy/ext/mypy/names.py +++ b/lib/sqlalchemy/ext/mypy/names.py @@ -12,11 +12,14 @@ from typing import Set from typing import Tuple from typing import Union +from mypy.nodes import ARG_POS +from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import Expression from mypy.nodes import FuncDef from mypy.nodes import MemberExpr from mypy.nodes import NameExpr +from mypy.nodes import OverloadedFuncDef from mypy.nodes import SymbolNode from mypy.nodes import TypeAlias from mypy.nodes import TypeInfo @@ -51,7 +54,7 @@ QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION") # type: ignore NAMED_TYPE_BUILTINS_OBJECT = "builtins.object" NAMED_TYPE_BUILTINS_STR = "builtins.str" NAMED_TYPE_BUILTINS_LIST = "builtins.list" -NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.attributes.Mapped" +NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped" _lookup: Dict[str, Tuple[int, Set[str]]] = { "Column": ( @@ -61,11 +64,11 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { "sqlalchemy.sql.Column", }, ), - "RelationshipProperty": ( + "Relationship": ( RELATIONSHIP, { - "sqlalchemy.orm.relationships.RelationshipProperty", - "sqlalchemy.orm.RelationshipProperty", + "sqlalchemy.orm.relationships.Relationship", + "sqlalchemy.orm.Relationship", }, ), "registry": ( @@ -82,18 +85,18 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { "sqlalchemy.orm.ColumnProperty", }, ), - "SynonymProperty": ( + "Synonym": ( SYNONYM_PROPERTY, { - "sqlalchemy.orm.descriptor_props.SynonymProperty", - "sqlalchemy.orm.SynonymProperty", + "sqlalchemy.orm.descriptor_props.Synonym", + "sqlalchemy.orm.Synonym", }, ), - "CompositeProperty": ( + "Composite": ( COMPOSITE_PROPERTY, { - "sqlalchemy.orm.descriptor_props.CompositeProperty", - "sqlalchemy.orm.CompositeProperty", + "sqlalchemy.orm.descriptor_props.Composite", + "sqlalchemy.orm.Composite", }, ), "MapperProperty": ( @@ -159,7 +162,10 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { ), "query_expression": ( QUERY_EXPRESSION, - {"sqlalchemy.orm.query_expression"}, + { + "sqlalchemy.orm.query_expression", + "sqlalchemy.orm._orm_constructors.query_expression", + }, ), } @@ -209,7 +215,19 @@ def type_id_for_unbound_type( def type_id_for_callee(callee: Expression) -> Optional[int]: if isinstance(callee, (MemberExpr, NameExpr)): - if isinstance(callee.node, FuncDef): + if isinstance(callee.node, OverloadedFuncDef): + if ( + callee.node.impl + and callee.node.impl.type + and isinstance(callee.node.impl.type, CallableType) + ): + ret_type = get_proper_type(callee.node.impl.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None + elif isinstance(callee.node, FuncDef): if callee.node.type and isinstance(callee.node.type, CallableType): ret_type = get_proper_type(callee.node.type.ret_type) @@ -251,3 +269,15 @@ def type_id_for_fullname(fullname: str) -> Optional[int]: return type_id else: return None + + +def expr_to_mapped_constructor(expr: Expression) -> CallExpr: + column_descriptor = NameExpr("__sa_Mapped") + column_descriptor.fullname = NAMED_TYPE_SQLA_MAPPED + member_expr = MemberExpr(column_descriptor, "_empty_constructor") + return CallExpr( + member_expr, + [expr], + [ARG_POS], + ["arg1"], + ) diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py index 0a21feb51..c9520fef3 100644 --- a/lib/sqlalchemy/ext/mypy/plugin.py +++ b/lib/sqlalchemy/ext/mypy/plugin.py @@ -40,6 +40,19 @@ from . import decl_class from . import names from . import util +try: + import sqlalchemy_stubs # noqa +except ImportError: + pass +else: + import sqlalchemy + + raise ImportError( + f"The SQLAlchemy mypy plugin in SQLAlchemy " + f"{sqlalchemy.__version__} does not work with sqlalchemy-stubs or " + "sqlalchemy2-stubs installed" + ) + class SQLAlchemyPlugin(Plugin): def get_dynamic_class_hook( diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index fa42074c3..741772eac 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -10,24 +10,27 @@ from typing import Type as TypingType from typing import TypeVar from typing import Union -from mypy.nodes import ARG_POS from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import CLASSDEF_NO_INFO from mypy.nodes import Context from mypy.nodes import Expression +from mypy.nodes import FuncDef from mypy.nodes import IfStmt from mypy.nodes import JsonDict from mypy.nodes import MemberExpr from mypy.nodes import NameExpr from mypy.nodes import Statement from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeAlias from mypy.nodes import TypeInfo from mypy.plugin import ClassDefContext from mypy.plugin import DynamicClassDefContext from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.plugins.common import deserialize_and_fixup_type from mypy.typeops import map_type_from_supertype +from mypy.types import CallableType +from mypy.types import get_proper_type from mypy.types import Instance from mypy.types import NoneType from mypy.types import Type @@ -231,6 +234,25 @@ def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: yield stmt +def type_for_callee(callee: Expression) -> Optional[Union[Instance, TypeInfo]]: + if isinstance(callee, (MemberExpr, NameExpr)): + if isinstance(callee.node, FuncDef): + if callee.node.type and isinstance(callee.node.type, CallableType): + ret_type = get_proper_type(callee.node.type.ret_type) + + if isinstance(ret_type, Instance): + return ret_type + + return None + elif isinstance(callee.node, TypeAlias): + target_type = get_proper_type(callee.node.target) + if isinstance(target_type, Instance): + return target_type + elif isinstance(callee.node, TypeInfo): + return callee.node + return None + + def unbound_to_instance( api: SemanticAnalyzerPluginInterface, typ: Type ) -> Type: @@ -290,15 +312,3 @@ def info_for_cls( return sym.node return cls.info - - -def expr_to_mapped_constructor(expr: Expression) -> CallExpr: - column_descriptor = NameExpr("__sa_Mapped") - column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped" - member_expr = MemberExpr(column_descriptor, "_empty_constructor") - return CallExpr( - member_expr, - [expr], - [ARG_POS], - ["arg1"], - ) diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index 5a327d1a5..5384851b1 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -119,14 +119,28 @@ start numbering at 1 or some other integer, provide ``count_from=1``. """ +from typing import Callable +from typing import List +from typing import Optional +from typing import Sequence +from typing import TypeVar + from ..orm.collections import collection from ..orm.collections import collection_adapter +_T = TypeVar("_T") +OrderingFunc = Callable[[int, Sequence[_T]], int] + __all__ = ["ordering_list"] -def ordering_list(attr, count_from=None, **kw): +def ordering_list( + attr: str, + count_from: Optional[int] = None, + ordering_func: Optional[OrderingFunc] = None, + reorder_on_append: bool = False, +) -> Callable[[], "OrderingList"]: """Prepares an :class:`OrderingList` factory for use in mapper definitions. Returns an object suitable for use as an argument to a Mapper @@ -157,7 +171,11 @@ def ordering_list(attr, count_from=None, **kw): """ - kw = _unsugar_count_from(count_from=count_from, **kw) + kw = _unsugar_count_from( + count_from=count_from, + ordering_func=ordering_func, + reorder_on_append=reorder_on_append, + ) return lambda: OrderingList(attr, **kw) @@ -207,7 +225,7 @@ def _unsugar_count_from(**kw): return kw -class OrderingList(list): +class OrderingList(List[_T]): """A custom list that manages position information for its children. The :class:`.OrderingList` object is normally set up using the @@ -216,8 +234,15 @@ class OrderingList(list): """ + ordering_attr: str + ordering_func: OrderingFunc + reorder_on_append: bool + def __init__( - self, ordering_attr=None, ordering_func=None, reorder_on_append=False + self, + ordering_attr: Optional[str] = None, + ordering_func: Optional[OrderingFunc] = None, + reorder_on_append: bool = False, ): """A custom list that manages position information for its children. @@ -282,7 +307,7 @@ class OrderingList(list): def _set_order_value(self, entity, value): setattr(entity, self.ordering_attr, value) - def reorder(self): + def reorder(self) -> None: """Synchronize ordering for the entire collection. Sweeps through the list and ensures that each object has accurate diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index 885163ecb..c6a8b6ea7 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -74,6 +74,8 @@ def class_logger(cls: Type[_IT]) -> Type[_IT]: class Identified: + __slots__ = () + logging_name: Optional[str] = None logger: Union[logging.Logger, "InstanceLogger"] @@ -116,6 +118,8 @@ class InstanceLogger: _echo: _EchoFlagType + __slots__ = ("echo", "logger") + def __init__(self, echo: _EchoFlagType, name: str): self.echo = echo self.logger = logging.getLogger(name) diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 55f2f3100..bbed93310 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -17,19 +17,27 @@ from . import exc as exc from . import mapper as mapperlib from . import strategy_options as strategy_options from ._orm_constructors import _mapper_fn as mapper +from ._orm_constructors import aliased as aliased from ._orm_constructors import backref as backref from ._orm_constructors import clear_mappers as clear_mappers from ._orm_constructors import column_property as column_property from ._orm_constructors import composite as composite +from ._orm_constructors import CompositeProperty as CompositeProperty from ._orm_constructors import contains_alias as contains_alias from ._orm_constructors import create_session as create_session from ._orm_constructors import deferred as deferred from ._orm_constructors import dynamic_loader as dynamic_loader +from ._orm_constructors import join as join from ._orm_constructors import mapped_column as mapped_column +from ._orm_constructors import MappedColumn as MappedColumn +from ._orm_constructors import outerjoin as outerjoin from ._orm_constructors import query_expression as query_expression from ._orm_constructors import relationship as relationship +from ._orm_constructors import RelationshipProperty as RelationshipProperty from ._orm_constructors import synonym as synonym +from ._orm_constructors import SynonymProperty as SynonymProperty from ._orm_constructors import with_loader_criteria as with_loader_criteria +from ._orm_constructors import with_polymorphic as with_polymorphic from .attributes import AttributeEvent as AttributeEvent from .attributes import InstrumentedAttribute as InstrumentedAttribute from .attributes import QueryableAttribute as QueryableAttribute @@ -46,8 +54,8 @@ from .decl_api import declared_attr as declared_attr from .decl_api import has_inherited_table as has_inherited_table from .decl_api import registry as registry from .decl_api import synonym_for as synonym_for -from .descriptor_props import CompositeProperty as CompositeProperty -from .descriptor_props import SynonymProperty as SynonymProperty +from .descriptor_props import Composite as Composite +from .descriptor_props import Synonym as Synonym from .dynamic import AppenderQuery as AppenderQuery from .events import AttributeEvents as AttributeEvents from .events import InstanceEvents as InstanceEvents @@ -81,7 +89,7 @@ from .query import AliasOption as AliasOption from .query import FromStatement as FromStatement from .query import Query as Query from .relationships import foreign as foreign -from .relationships import RelationshipProperty as RelationshipProperty +from .relationships import Relationship as Relationship from .relationships import remote as remote from .scoping import scoped_session as scoped_session from .session import close_all_sessions as close_all_sessions @@ -111,17 +119,13 @@ from .strategy_options import undefer as undefer from .strategy_options import undefer_group as undefer_group from .strategy_options import with_expression as with_expression from .unitofwork import UOWTransaction as UOWTransaction -from .util import aliased as aliased from .util import Bundle as Bundle from .util import CascadeOptions as CascadeOptions -from .util import join as join from .util import LoaderCriteriaOption as LoaderCriteriaOption from .util import object_mapper as object_mapper -from .util import outerjoin as outerjoin from .util import polymorphic_union as polymorphic_union from .util import was_deleted as was_deleted from .util import with_parent as with_parent -from .util import with_polymorphic as with_polymorphic from .. import util as _sa_util diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 80607670e..a1f1faa05 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -7,35 +7,52 @@ import typing from typing import Any -from typing import Callable from typing import Collection +from typing import List +from typing import Mapping from typing import Optional from typing import overload +from typing import Set from typing import Type from typing import Union from . import mapper as mapperlib from .base import Mapped -from .descriptor_props import CompositeProperty -from .descriptor_props import SynonymProperty +from .descriptor_props import Composite +from .descriptor_props import Synonym +from .mapper import Mapper from .properties import ColumnProperty +from .properties import MappedColumn from .query import AliasOption -from .relationships import RelationshipProperty +from .relationships import _RelationshipArgumentType +from .relationships import Relationship from .session import Session +from .util import _ORMJoin +from .util import AliasedClass +from .util import AliasedInsp from .util import LoaderCriteriaOption from .. import sql from .. import util from ..exc import InvalidRequestError -from ..sql.schema import Column -from ..sql.schema import SchemaEventTarget +from ..sql.base import SchemaEventTarget +from ..sql.selectable import Alias +from ..sql.selectable import FromClause from ..sql.type_api import TypeEngine from ..util.typing import Literal - -_RC = typing.TypeVar("_RC") _T = typing.TypeVar("_T") +CompositeProperty = Composite +"""Alias for :class:`_orm.Composite`.""" + +RelationshipProperty = Relationship +"""Alias for :class:`_orm.Relationship`.""" + +SynonymProperty = Synonym +"""Alias for :class:`_orm.Synonym`.""" + + @util.deprecated( "1.4", "The :class:`.AliasOption` object is not necessary " @@ -51,35 +68,45 @@ def contains_alias(alias) -> "AliasOption": return AliasOption(alias) +# see test/ext/mypy/plain_files/mapped_column.py for mapped column +# typing tests + + @overload def mapped_column( + __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: bool = ..., + nullable: Literal[None] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped": +) -> "MappedColumn[Any]": ... @overload def mapped_column( + __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[True]] = ..., - primary_key: Union[Literal[None], Literal[False]] = ..., + nullable: Literal[None] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[Optional[_T]]": +) -> "MappedColumn[Any]": ... @overload def mapped_column( + __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[True]] = ..., - primary_key: Union[Literal[None], Literal[False]] = ..., + nullable: Literal[True] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[Optional[_T]]": +) -> "MappedColumn[Optional[_T]]": ... @@ -87,45 +114,48 @@ def mapped_column( def mapped_column( __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[False]] = ..., - primary_key: Literal[True] = True, + nullable: Literal[True] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[_T]": +) -> "MappedColumn[Optional[_T]]": ... @overload def mapped_column( + __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, nullable: Literal[False] = ..., - primary_key: bool = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[_T]": +) -> "MappedColumn[_T]": ... @overload def mapped_column( - __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[True]] = ..., - primary_key: Union[Literal[None], Literal[False]] = ..., + nullable: Literal[False] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[Optional[_T]]": +) -> "MappedColumn[_T]": ... @overload def mapped_column( - __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[True]] = ..., - primary_key: Union[Literal[None], Literal[False]] = ..., + nullable: bool = ..., + primary_key: Literal[True] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[Optional[_T]]": +) -> "MappedColumn[_T]": ... @@ -134,55 +164,209 @@ def mapped_column( __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[False]] = ..., - primary_key: Literal[True] = True, + nullable: bool = ..., + primary_key: Literal[True] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[_T]": +) -> "MappedColumn[_T]": ... @overload def mapped_column( __name: str, - __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Literal[False] = ..., + nullable: bool = ..., primary_key: bool = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[_T]": +) -> "MappedColumn[Any]": ... -def mapped_column(*args, **kw) -> "Mapped": - """construct a new ORM-mapped :class:`_schema.Column` construct. +@overload +def mapped_column( + *args: SchemaEventTarget, + nullable: bool = ..., + primary_key: bool = ..., + deferred: bool = ..., + **kw: Any, +) -> "MappedColumn[Any]": + ... + + +def mapped_column(*args: Any, **kw: Any) -> "MappedColumn[Any]": + r"""construct a new ORM-mapped :class:`_schema.Column` construct. + + The :func:`_orm.mapped_column` function provides an ORM-aware and + Python-typing-compatible construct which is used with + :ref:`declarative <orm_declarative_mapping>` mappings to indicate an + attribute that's mapped to a Core :class:`_schema.Column` object. It + provides the equivalent feature as mapping an attribute to a + :class:`_schema.Column` object directly when using declarative. + + .. versionadded:: 2.0 - The :func:`_orm.mapped_column` function is shorthand for the construction - of a Core :class:`_schema.Column` object delivered within a - :func:`_orm.column_property` construct, which provides for consistent - typing information to be delivered to the class so that it works under - static type checkers such as mypy and delivers useful information in - IDE related type checkers such as pylance. The function can be used - in declarative mappings anywhere that :class:`_schema.Column` is normally - used:: + :func:`_orm.mapped_column` is normally used with explicit typing along with + the :class:`_orm.Mapped` mapped attribute type, where it can derive the SQL + type and nullability for the column automatically, such as:: + from typing import Optional + + from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column class User(Base): __tablename__ = 'user' - id = mapped_column(Integer) - name = mapped_column(String) + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + options: Mapped[Optional[str]] = mapped_column() + + In the above example, the ``int`` and ``str`` types are inferred by the + Declarative mapping system to indicate use of the :class:`_types.Integer` + and :class:`_types.String` datatypes, and the presence of ``Optional`` or + not indicates whether or not each non-primary-key column is to be + ``nullable=True`` or ``nullable=False``. + + The above example, when interpreted within a Declarative class, will result + in a table named ``"user"`` which is equivalent to the following:: + + from sqlalchemy import Integer + from sqlalchemy import String + from sqlalchemy import Table + + Table( + 'user', + Base.metadata, + Column("id", Integer, primary_key=True), + Column("name", String, nullable=False), + Column("options", String, nullable=True), + ) + The :func:`_orm.mapped_column` construct accepts the same arguments as + that of :class:`_schema.Column` directly, including optional "name" + and "type" fields, so the above mapping can be stated more explicitly + as:: - .. versionadded:: 2.0 + from typing import Optional + + from sqlalchemy import Integer + from sqlalchemy import String + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + class User(Base): + __tablename__ = 'user' + + id: Mapped[int] = mapped_column("id", Integer, primary_key=True) + name: Mapped[str] = mapped_column("name", String, nullable=False) + options: Mapped[Optional[str]] = mapped_column( + "name", String, nullable=True + ) + + Arguments passed to :func:`_orm.mapped_column` always supersede those which + would be derived from the type annotation and/or attribute name. To state + the above mapping with more specific datatypes for ``id`` and ``options``, + and a different column name for ``name``, looks like:: + + from sqlalchemy import BigInteger + + class User(Base): + __tablename__ = 'user' + + id: Mapped[int] = mapped_column("id", BigInteger, primary_key=True) + name: Mapped[str] = mapped_column("user_name") + options: Mapped[Optional[str]] = mapped_column(String(50)) + + Where again, datatypes and nullable parameters that can be automatically + derived may be omitted. + + The datatypes passed to :class:`_orm.Mapped` are mapped to SQL + :class:`_types.TypeEngine` types with the following default mapping:: + + _type_map = { + int: Integer(), + float: Float(), + bool: Boolean(), + decimal.Decimal: Numeric(), + dt.date: Date(), + dt.datetime: DateTime(), + dt.time: Time(), + dt.timedelta: Interval(), + util.NoneType: NULLTYPE, + bytes: LargeBinary(), + str: String(), + } + + The above mapping may be expanded to include any combination of Python + datatypes to SQL types by using the + :paramref:`_orm.registry.type_annotation_map` parameter to + :class:`_orm.registry`, or as the attribute ``type_annotation_map`` upon + the :class:`_orm.DeclarativeBase` base class. + + Finally, :func:`_orm.mapped_column` is implicitly used by the Declarative + mapping system for any :class:`_orm.Mapped` annotation that has no + attribute value set up. This is much in the way that Python dataclasses + allow the ``field()`` construct to be optional, only needed when additional + parameters should be associated with the field. Using this functionality, + our original mapping can be stated even more succinctly as:: + + from typing import Optional + + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + class User(Base): + __tablename__ = 'user' + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + options: Mapped[Optional[str]] + + Above, the ``name`` and ``options`` columns will be evaluated as + ``Column("name", String, nullable=False)`` and + ``Column("options", String, nullable=True)``, respectively. + + :param __name: String name to give to the :class:`_schema.Column`. This + is an optional, positional only argument that if present must be the + first positional argument passed. If omitted, the attribute name to + which the :func:`_orm.mapped_column` is mapped will be used as the SQL + column name. + :param __type: :class:`_types.TypeEngine` type or instance which will + indicate the datatype to be associated with the :class:`_schema.Column`. + This is an optional, positional-only argument that if present must + immediately follow the ``__name`` parameter if present also, or otherwise + be the first positional parameter. If omitted, the ultimate type for + the column may be derived either from the annotated type, or if a + :class:`_schema.ForeignKey` is present, from the datatype of the + referenced column. + :param \*args: Additional positional arguments include constructs such + as :class:`_schema.ForeignKey`, :class:`_schema.CheckConstraint`, + and :class:`_schema.Identity`, which are passed through to the constructed + :class:`_schema.Column`. + :param nullable: Optional bool, whether the column should be "NULL" or + "NOT NULL". If omitted, the nullability is derived from the type + annotation based on whether or not ``typing.Optional`` is present. + ``nullable`` defaults to ``True`` otherwise for non-primary key columns, + and ``False`` or primary key columns. + :param primary_key: optional bool, indicates the :class:`_schema.Column` + would be part of the table's primary key or not. + :param deferred: Optional bool - this keyword argument is consumed by the + ORM declarative process, and is not part of the :class:`_schema.Column` + itself; instead, it indicates that this column should be "deferred" for + loading as though mapped by :func:`_orm.deferred`. + :param \**kw: All remaining keyword argments are passed through to the + constructor for the :class:`_schema.Column`. """ - return column_property(Column(*args, **kw)) + + return MappedColumn(*args, **kw) def column_property( column: sql.ColumnElement[_T], *additional_columns, **kwargs -) -> "Mapped[_T]": +) -> "ColumnProperty[_T]": r"""Provide a column-level property for use with a mapping. Column-based properties can normally be applied to the mapper's @@ -269,22 +453,49 @@ def column_property( return ColumnProperty(column, *additional_columns, **kwargs) -def composite(class_: Type[_T], *attrs, **kwargs) -> "Mapped[_T]": +@overload +def composite( + class_: Type[_T], + *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + **kwargs: Any, +) -> "Composite[_T]": + ... + + +@overload +def composite( + *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + **kwargs: Any, +) -> "Composite[Any]": + ... + + +def composite( + class_: Any = None, + *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + **kwargs: Any, +) -> "Composite[Any]": r"""Return a composite column-based property for use with a Mapper. See the mapping documentation section :ref:`mapper_composite` for a full usage example. The :class:`.MapperProperty` returned by :func:`.composite` - is the :class:`.CompositeProperty`. + is the :class:`.Composite`. :param class\_: The "composite type" class, or any classmethod or callable which will produce a new instance of the composite object given the column values in order. - :param \*cols: - List of Column objects to be mapped. + :param \*attrs: + List of elements to be mapped, which may include: + + * :class:`_schema.Column` objects + * :func:`_orm.mapped_column` constructs + * string names of other attributes on the mapped class, which may be + any other SQL or object-mapped attribute. This can for + example allow a composite that refers to a many-to-one relationship :param active_history=False: When ``True``, indicates that the "previous" value for a @@ -301,7 +512,7 @@ def composite(class_: Type[_T], *attrs, **kwargs) -> "Mapped[_T]": :func:`~sqlalchemy.orm.deferred`. :param comparator_factory: a class which extends - :class:`.CompositeProperty.Comparator` which provides custom SQL + :class:`.Composite.Comparator` which provides custom SQL clause generation for comparison operations. :param doc: @@ -312,7 +523,7 @@ def composite(class_: Type[_T], *attrs, **kwargs) -> "Mapped[_T]": :attr:`.MapperProperty.info` attribute of this object. """ - return CompositeProperty(class_, *attrs, **kwargs) + return Composite(class_, *attrs, **kwargs) def with_loader_criteria( @@ -500,143 +711,140 @@ def with_loader_criteria( @overload def relationship( - argument: Union[str, Type[_RC], Callable[[], Type[_RC]]], + argument: Optional[_RelationshipArgumentType[_T]], + secondary=None, + *, + uselist: Literal[False] = None, + collection_class: Literal[None] = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[_T]: + ... + + +@overload +def relationship( + argument: Optional[_RelationshipArgumentType[_T]], secondary=None, *, uselist: Literal[True] = None, + collection_class: Literal[None] = None, primaryjoin=None, secondaryjoin=None, - foreign_keys=None, - order_by=False, - backref=None, back_populates=None, - overlaps=None, - post_update=False, - cascade=False, - viewonly=False, - lazy="select", - collection_class=None, - passive_deletes=RelationshipProperty._persistence_only["passive_deletes"], - passive_updates=RelationshipProperty._persistence_only["passive_updates"], - remote_side=None, - enable_typechecks=RelationshipProperty._persistence_only[ - "enable_typechecks" - ], - join_depth=None, - comparator_factory=None, - single_parent=False, - innerjoin=False, - distinct_target_key=None, - doc=None, - active_history=RelationshipProperty._persistence_only["active_history"], - cascade_backrefs=RelationshipProperty._persistence_only[ - "cascade_backrefs" - ], - load_on_pending=False, - bake_queries=True, - _local_remote_pairs=None, - query_class=None, - info=None, - omit_join=None, - sync_backref=None, - _legacy_inactive_history_style=False, -) -> Mapped[Collection[_RC]]: + **kw: Any, +) -> Relationship[List[_T]]: ... @overload def relationship( - argument: Union[str, Type[_RC], Callable[[], Type[_RC]]], + argument: Optional[_RelationshipArgumentType[_T]], secondary=None, *, - uselist: Optional[bool] = None, + uselist: Union[Literal[None], Literal[True]] = None, + collection_class: Type[List] = None, primaryjoin=None, secondaryjoin=None, - foreign_keys=None, - order_by=False, - backref=None, back_populates=None, - overlaps=None, - post_update=False, - cascade=False, - viewonly=False, - lazy="select", - collection_class=None, - passive_deletes=RelationshipProperty._persistence_only["passive_deletes"], - passive_updates=RelationshipProperty._persistence_only["passive_updates"], - remote_side=None, - enable_typechecks=RelationshipProperty._persistence_only[ - "enable_typechecks" - ], - join_depth=None, - comparator_factory=None, - single_parent=False, - innerjoin=False, - distinct_target_key=None, - doc=None, - active_history=RelationshipProperty._persistence_only["active_history"], - cascade_backrefs=RelationshipProperty._persistence_only[ - "cascade_backrefs" - ], - load_on_pending=False, - bake_queries=True, - _local_remote_pairs=None, - query_class=None, - info=None, - omit_join=None, - sync_backref=None, - _legacy_inactive_history_style=False, -) -> Mapped[_RC]: + **kw: Any, +) -> Relationship[List[_T]]: ... +@overload def relationship( - argument: Union[str, Type[_RC], Callable[[], Type[_RC]]], + argument: Optional[_RelationshipArgumentType[_T]], secondary=None, *, + uselist: Union[Literal[None], Literal[True]] = None, + collection_class: Type[Set] = None, primaryjoin=None, secondaryjoin=None, - foreign_keys=None, + back_populates=None, + **kw: Any, +) -> Relationship[Set[_T]]: + ... + + +@overload +def relationship( + argument: Optional[_RelationshipArgumentType[_T]], + secondary=None, + *, + uselist: Union[Literal[None], Literal[True]] = None, + collection_class: Type[Mapping[Any, Any]] = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[Mapping[Any, _T]]: + ... + + +@overload +def relationship( + argument: _RelationshipArgumentType[_T], + secondary=None, + *, + uselist: Literal[None] = None, + collection_class: Literal[None] = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[Any]: + ... + + +@overload +def relationship( + argument: Optional[_RelationshipArgumentType[_T]] = None, + secondary=None, + *, + uselist: Literal[True] = None, + collection_class: Any = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[Any]: + ... + + +@overload +def relationship( + argument: Literal[None] = None, + secondary=None, + *, uselist: Optional[bool] = None, - order_by=False, - backref=None, + collection_class: Any = None, + primaryjoin=None, + secondaryjoin=None, back_populates=None, - overlaps=None, - post_update=False, - cascade=False, - viewonly=False, - lazy="select", - collection_class=None, - passive_deletes=RelationshipProperty._persistence_only["passive_deletes"], - passive_updates=RelationshipProperty._persistence_only["passive_updates"], - remote_side=None, - enable_typechecks=RelationshipProperty._persistence_only[ - "enable_typechecks" - ], - join_depth=None, - comparator_factory=None, - single_parent=False, - innerjoin=False, - distinct_target_key=None, - doc=None, - active_history=RelationshipProperty._persistence_only["active_history"], - cascade_backrefs=RelationshipProperty._persistence_only[ - "cascade_backrefs" - ], - load_on_pending=False, - bake_queries=True, - _local_remote_pairs=None, - query_class=None, - info=None, - omit_join=None, - sync_backref=None, - _legacy_inactive_history_style=False, -) -> Mapped[_RC]: + **kw: Any, +) -> Relationship[Any]: + ... + + +def relationship( + argument: Optional[_RelationshipArgumentType[_T]] = None, + secondary=None, + *, + uselist: Optional[bool] = None, + collection_class: Optional[Type[Collection]] = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[Any]: """Provide a relationship between two mapped classes. This corresponds to a parent-child or associative table relationship. The constructed class is an instance of - :class:`.RelationshipProperty`. + :class:`.Relationship`. A typical :func:`_orm.relationship`, used in a classical mapping:: @@ -897,7 +1105,7 @@ def relationship( examples. :param comparator_factory: - A class which extends :class:`.RelationshipProperty.Comparator` + A class which extends :class:`.Relationship.Comparator` which provides custom SQL clause generation for comparison operations. @@ -1447,42 +1655,15 @@ def relationship( """ - return RelationshipProperty( + return Relationship( argument, - secondary, - primaryjoin, - secondaryjoin, - foreign_keys, - uselist, - order_by, - backref, - back_populates, - overlaps, - post_update, - cascade, - viewonly, - lazy, - collection_class, - passive_deletes, - passive_updates, - remote_side, - enable_typechecks, - join_depth, - comparator_factory, - single_parent, - innerjoin, - distinct_target_key, - doc, - active_history, - cascade_backrefs, - load_on_pending, - bake_queries, - _local_remote_pairs, - query_class, - info, - omit_join, - sync_backref, - _legacy_inactive_history_style, + secondary=secondary, + uselist=uselist, + collection_class=collection_class, + primaryjoin=primaryjoin, + secondaryjoin=secondaryjoin, + back_populates=back_populates, + **kw, ) @@ -1493,7 +1674,7 @@ def synonym( comparator_factory=None, doc=None, info=None, -) -> "Mapped": +) -> "Synonym[Any]": """Denote an attribute name as a synonym to a mapped property, in that the attribute will mirror the value and expression behavior of another attribute. @@ -1597,9 +1778,7 @@ def synonym( than can be achieved with synonyms. """ - return SynonymProperty( - name, map_column, descriptor, comparator_factory, doc, info - ) + return Synonym(name, map_column, descriptor, comparator_factory, doc, info) def create_session(bind=None, **kwargs): @@ -1733,7 +1912,9 @@ def deferred(*columns, **kw): return ColumnProperty(deferred=True, *columns, **kw) -def query_expression(default_expr=sql.null()): +def query_expression( + default_expr: sql.ColumnElement[_T] = sql.null(), +) -> "Mapped[_T]": """Indicate an attribute that populates from a query-time SQL expression. :param default_expr: Optional SQL expression object that will be used in @@ -1787,3 +1968,273 @@ def clear_mappers(): """ mapperlib._dispose_registries(mapperlib._all_registries(), False) + + +@overload +def aliased( + element: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"], + alias=None, + name=None, + flat=False, + adapt_on_names=False, +) -> "AliasedClass[_T]": + ... + + +@overload +def aliased( + element: "FromClause", + alias=None, + name=None, + flat=False, + adapt_on_names=False, +) -> "Alias": + ... + + +def aliased( + element: Union[Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]"], + alias=None, + name=None, + flat=False, + adapt_on_names=False, +) -> Union["AliasedClass[_T]", "Alias"]: + """Produce an alias of the given element, usually an :class:`.AliasedClass` + instance. + + E.g.:: + + my_alias = aliased(MyClass) + + session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id) + + The :func:`.aliased` function is used to create an ad-hoc mapping of a + mapped class to a new selectable. By default, a selectable is generated + from the normally mapped selectable (typically a :class:`_schema.Table` + ) using the + :meth:`_expression.FromClause.alias` method. However, :func:`.aliased` + can also be + used to link the class to a new :func:`_expression.select` statement. + Also, the :func:`.with_polymorphic` function is a variant of + :func:`.aliased` that is intended to specify a so-called "polymorphic + selectable", that corresponds to the union of several joined-inheritance + subclasses at once. + + For convenience, the :func:`.aliased` function also accepts plain + :class:`_expression.FromClause` constructs, such as a + :class:`_schema.Table` or + :func:`_expression.select` construct. In those cases, the + :meth:`_expression.FromClause.alias` + method is called on the object and the new + :class:`_expression.Alias` object returned. The returned + :class:`_expression.Alias` is not + ORM-mapped in this case. + + .. seealso:: + + :ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial` + + :ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel` + + :ref:`ormtutorial_aliases` - in the legacy :ref:`ormtutorial_toplevel` + + :param element: element to be aliased. Is normally a mapped class, + but for convenience can also be a :class:`_expression.FromClause` + element. + + :param alias: Optional selectable unit to map the element to. This is + usually used to link the object to a subquery, and should be an aliased + select construct as one would produce from the + :meth:`_query.Query.subquery` method or + the :meth:`_expression.Select.subquery` or + :meth:`_expression.Select.alias` methods of the :func:`_expression.select` + construct. + + :param name: optional string name to use for the alias, if not specified + by the ``alias`` parameter. The name, among other things, forms the + attribute name that will be accessible via tuples returned by a + :class:`_query.Query` object. Not supported when creating aliases + of :class:`_sql.Join` objects. + + :param flat: Boolean, will be passed through to the + :meth:`_expression.FromClause.alias` call so that aliases of + :class:`_expression.Join` objects will alias the individual tables + inside the join, rather than creating a subquery. This is generally + supported by all modern databases with regards to right-nested joins + and generally produces more efficient queries. + + :param adapt_on_names: if True, more liberal "matching" will be used when + mapping the mapped columns of the ORM entity to those of the + given selectable - a name-based match will be performed if the + given selectable doesn't otherwise have a column that corresponds + to one on the entity. The use case for this is when associating + an entity with some derived selectable such as one that uses + aggregate functions:: + + class UnitPrice(Base): + __tablename__ = 'unit_price' + ... + unit_id = Column(Integer) + price = Column(Numeric) + + aggregated_unit_price = Session.query( + func.sum(UnitPrice.price).label('price') + ).group_by(UnitPrice.unit_id).subquery() + + aggregated_unit_price = aliased(UnitPrice, + alias=aggregated_unit_price, adapt_on_names=True) + + Above, functions on ``aggregated_unit_price`` which refer to + ``.price`` will return the + ``func.sum(UnitPrice.price).label('price')`` column, as it is + matched on the name "price". Ordinarily, the "price" function + wouldn't have any "column correspondence" to the actual + ``UnitPrice.price`` column as it is not a proxy of the original. + + """ + return AliasedInsp._alias_factory( + element, + alias=alias, + name=name, + flat=flat, + adapt_on_names=adapt_on_names, + ) + + +def with_polymorphic( + base, + classes, + selectable=False, + flat=False, + polymorphic_on=None, + aliased=False, + innerjoin=False, + _use_mapper_path=False, +): + """Produce an :class:`.AliasedClass` construct which specifies + columns for descendant mappers of the given base. + + Using this method will ensure that each descendant mapper's + tables are included in the FROM clause, and will allow filter() + criterion to be used against those tables. The resulting + instances will also have those columns already loaded so that + no "post fetch" of those columns will be required. + + .. seealso:: + + :ref:`with_polymorphic` - full discussion of + :func:`_orm.with_polymorphic`. + + :param base: Base class to be aliased. + + :param classes: a single class or mapper, or list of + class/mappers, which inherit from the base class. + Alternatively, it may also be the string ``'*'``, in which case + all descending mapped classes will be added to the FROM clause. + + :param aliased: when True, the selectable will be aliased. For a + JOIN, this means the JOIN will be SELECTed from inside of a subquery + unless the :paramref:`_orm.with_polymorphic.flat` flag is set to + True, which is recommended for simpler use cases. + + :param flat: Boolean, will be passed through to the + :meth:`_expression.FromClause.alias` call so that aliases of + :class:`_expression.Join` objects will alias the individual tables + inside the join, rather than creating a subquery. This is generally + supported by all modern databases with regards to right-nested joins + and generally produces more efficient queries. Setting this flag is + recommended as long as the resulting SQL is functional. + + :param selectable: a table or subquery that will + be used in place of the generated FROM clause. This argument is + required if any of the desired classes use concrete table + inheritance, since SQLAlchemy currently cannot generate UNIONs + among tables automatically. If used, the ``selectable`` argument + must represent the full set of tables and columns mapped by every + mapped class. Otherwise, the unaccounted mapped columns will + result in their table being appended directly to the FROM clause + which will usually lead to incorrect results. + + When left at its default value of ``False``, the polymorphic + selectable assigned to the base mapper is used for selecting rows. + However, it may also be passed as ``None``, which will bypass the + configured polymorphic selectable and instead construct an ad-hoc + selectable for the target classes given; for joined table inheritance + this will be a join that includes all target mappers and their + subclasses. + + :param polymorphic_on: a column to be used as the "discriminator" + column for the given selectable. If not given, the polymorphic_on + attribute of the base classes' mapper will be used, if any. This + is useful for mappings that don't have polymorphic loading + behavior by default. + + :param innerjoin: if True, an INNER JOIN will be used. This should + only be specified if querying for one specific subtype only + """ + return AliasedInsp._with_polymorphic_factory( + base, + classes, + selectable=selectable, + flat=flat, + polymorphic_on=polymorphic_on, + aliased=aliased, + innerjoin=innerjoin, + _use_mapper_path=_use_mapper_path, + ) + + +def join( + left, right, onclause=None, isouter=False, full=False, join_to_left=None +): + r"""Produce an inner join between left and right clauses. + + :func:`_orm.join` is an extension to the core join interface + provided by :func:`_expression.join()`, where the + left and right selectables may be not only core selectable + objects such as :class:`_schema.Table`, but also mapped classes or + :class:`.AliasedClass` instances. The "on" clause can + be a SQL expression, or an attribute or string name + referencing a configured :func:`_orm.relationship`. + + :func:`_orm.join` is not commonly needed in modern usage, + as its functionality is encapsulated within that of the + :meth:`_query.Query.join` method, which features a + significant amount of automation beyond :func:`_orm.join` + by itself. Explicit usage of :func:`_orm.join` + with :class:`_query.Query` involves usage of the + :meth:`_query.Query.select_from` method, as in:: + + from sqlalchemy.orm import join + session.query(User).\ + select_from(join(User, Address, User.addresses)).\ + filter(Address.email_address=='foo@bar.com') + + In modern SQLAlchemy the above join can be written more + succinctly as:: + + session.query(User).\ + join(User.addresses).\ + filter(Address.email_address=='foo@bar.com') + + See :meth:`_query.Query.join` for information on modern usage + of ORM level joins. + + .. deprecated:: 0.8 + + the ``join_to_left`` parameter is deprecated, and will be removed + in a future release. The parameter has no effect. + + """ + return _ORMJoin(left, right, onclause, isouter, full) + + +def outerjoin(left, right, onclause=None, full=False, join_to_left=None): + """Produce a left outer join between left and right clauses. + + This is the "outer join" version of the :func:`_orm.join` function, + featuring the same behavior except that an OUTER JOIN is generated. + See that function's documentation for other usage details. + + """ + return _ORMJoin(left, right, onclause, True, full) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 5a605b7c6..fbfb2b2ee 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -35,6 +35,7 @@ from .base import instance_state from .base import instance_str from .base import LOAD_AGAINST_COMMITTED from .base import manager_of_class +from .base import Mapped as Mapped # noqa from .base import NEVER_SET # noqa from .base import NO_AUTOFLUSH from .base import NO_CHANGE # noqa @@ -79,6 +80,7 @@ class QueryableAttribute( traversals.HasCopyInternals, roles.JoinTargetRole, roles.OnClauseRole, + roles.ColumnsClauseRole, sql_base.Immutable, sql_base.MemoizedHasCacheKey, ): @@ -190,7 +192,7 @@ class QueryableAttribute( construct has defined one). * If the attribute refers to any other kind of - :class:`.MapperProperty`, including :class:`.RelationshipProperty`, + :class:`.MapperProperty`, including :class:`.Relationship`, the attribute will refer to the :attr:`.MapperProperty.info` dictionary associated with that :class:`.MapperProperty`. @@ -352,7 +354,7 @@ class QueryableAttribute( Return values here will commonly be instances of - :class:`.ColumnProperty` or :class:`.RelationshipProperty`. + :class:`.ColumnProperty` or :class:`.Relationship`. """ diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 7ab4b7737..e6d4a6729 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -12,8 +12,11 @@ import operator import typing from typing import Any +from typing import Callable from typing import Generic +from typing import Optional from typing import overload +from typing import Tuple from typing import TypeVar from typing import Union @@ -22,8 +25,9 @@ from .. import exc as sa_exc from .. import inspection from .. import util from ..sql.elements import SQLCoreOperations -from ..util import typing as compat_typing from ..util.langhelpers import TypingOnly +from ..util.typing import Concatenate +from ..util.typing import ParamSpec if typing.TYPE_CHECKING: @@ -32,6 +36,9 @@ if typing.TYPE_CHECKING: _T = TypeVar("_T", bound=Any) +_IdentityKeyType = Tuple[type, Tuple[Any, ...], Optional[str]] + + PASSIVE_NO_RESULT = util.symbol( "PASSIVE_NO_RESULT", """Symbol returned by a loader callable or other attribute/history @@ -236,16 +243,16 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") _RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE") -_Fn = typing.TypeVar("_Fn", bound=typing.Callable) -_Args = compat_typing.ParamSpec("_Args") -_Self = typing.TypeVar("_Self") +_Fn = TypeVar("_Fn", bound=Callable) +_Args = ParamSpec("_Args") +_Self = TypeVar("_Self") def _assertions( - *assertions, -) -> typing.Callable[ - [typing.Callable[compat_typing.Concatenate[_Fn, _Args], _Self]], - typing.Callable[compat_typing.Concatenate[_Fn, _Args], _Self], + *assertions: Any, +) -> Callable[ + [Callable[Concatenate[_Self, _Fn, _Args], _Self]], + Callable[Concatenate[_Self, _Fn, _Args], _Self], ]: @util.decorator def generate( @@ -605,8 +612,8 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly): ... -class Mapped(Generic[_T], util.TypingOnly): - """Represent an ORM mapped attribute for typing purposes. +class Mapped(Generic[_T], TypingOnly): + """Represent an ORM mapped attribute on a mapped class. This class represents the complete descriptor interface for any class attribute that will have been :term:`instrumented` by the ORM @@ -650,7 +657,7 @@ class Mapped(Generic[_T], util.TypingOnly): ... @classmethod - def _empty_constructor(cls, arg1: Any) -> "SQLORMOperations[_T]": + def _empty_constructor(cls, arg1: Any) -> "Mapped[_T]": ... @overload diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index ac6b0fd4c..037b70257 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -10,11 +10,14 @@ This system allows specification of classes and expressions used in :func:`_orm.relationship` using strings. """ +import re +from typing import MutableMapping +from typing import Union import weakref from . import attributes from . import interfaces -from .descriptor_props import SynonymProperty +from .descriptor_props import Synonym from .properties import ColumnProperty from .util import class_mapper from .. import exc @@ -22,6 +25,8 @@ from .. import inspection from .. import util from ..sql.schema import _get_table_key +_ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]] + # strong references to registries which we place in # the _decl_class_registry, which is usually weak referencing. # the internal registries here link to classes with weakrefs and remove @@ -118,7 +123,13 @@ def _key_is_empty(key, decl_class_registry, test): return not test(thing) -class _MultipleClassMarker: +class ClsRegistryToken: + """an object that can be in the registry._class_registry as a value.""" + + __slots__ = () + + +class _MultipleClassMarker(ClsRegistryToken): """refers to multiple classes of the same name within _decl_class_registry. @@ -182,7 +193,7 @@ class _MultipleClassMarker: self.contents.add(weakref.ref(item, self._remove_item)) -class _ModuleMarker: +class _ModuleMarker(ClsRegistryToken): """Refers to a module name within _decl_class_registry. @@ -281,7 +292,7 @@ class _GetColumns: desc = mp.all_orm_descriptors[key] if desc.extension_type is interfaces.NOT_EXTENSION: prop = desc.property - if isinstance(prop, SynonymProperty): + if isinstance(prop, Synonym): key = prop.name elif not isinstance(prop, ColumnProperty): raise exc.InvalidRequestError( @@ -372,13 +383,26 @@ class _class_resolver: return self.fallback[key] def _raise_for_name(self, name, err): - raise exc.InvalidRequestError( - "When initializing mapper %s, expression %r failed to " - "locate a name (%r). If this is a class name, consider " - "adding this relationship() to the %r class after " - "both dependent classes have been defined." - % (self.prop.parent, self.arg, name, self.cls) - ) from err + generic_match = re.match(r"(.+)\[(.+)\]", name) + + if generic_match: + raise exc.InvalidRequestError( + f"When initializing mapper {self.prop.parent}, " + f'expression "relationship({self.arg!r})" seems to be ' + "using a generic class as the argument to relationship(); " + "please state the generic argument " + "using an annotation, e.g. " + f'"{self.prop.key}: Mapped[{generic_match.group(1)}' + f'[{generic_match.group(2)}]] = relationship()"' + ) from err + else: + raise exc.InvalidRequestError( + "When initializing mapper %s, expression %r failed to " + "locate a name (%r). If this is a class name, consider " + "adding this relationship() to the %r class after " + "both dependent classes have been defined." + % (self.prop.parent, self.arg, name, self.cls) + ) from err def _resolve_name(self): name = self.arg diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 75ce8216f..ba4225563 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -102,18 +102,20 @@ The owning object and :class:`.CollectionAttributeImpl` are also reachable through the adapter, allowing for some very sophisticated behavior. """ - import operator import threading +import typing import weakref -from sqlalchemy.util.compat import inspect_getfullargspec -from . import base from .. import exc as sa_exc from .. import util -from ..sql import coercions -from ..sql import expression -from ..sql import roles +from ..util.compat import inspect_getfullargspec + +if typing.TYPE_CHECKING: + from .mapped_collection import attribute_mapped_collection + from .mapped_collection import column_mapped_collection + from .mapped_collection import mapped_collection + from .mapped_collection import MappedCollection # noqa: F401 __all__ = [ "collection", @@ -126,180 +128,6 @@ __all__ = [ __instrumentation_mutex = threading.Lock() -class _PlainColumnGetter: - """Plain column getter, stores collection of Column objects - directly. - - Serializes to a :class:`._SerializableColumnGetterV2` - which has more expensive __call__() performance - and some rare caveats. - - """ - - def __init__(self, cols): - self.cols = cols - self.composite = len(cols) > 1 - - def __reduce__(self): - return _SerializableColumnGetterV2._reduce_from_cols(self.cols) - - def _cols(self, mapper): - return self.cols - - def __call__(self, value): - state = base.instance_state(value) - m = base._state_mapper(state) - - key = [ - m._get_state_attr_by_column(state, state.dict, col) - for col in self._cols(m) - ] - - if self.composite: - return tuple(key) - else: - return key[0] - - -class _SerializableColumnGetter: - """Column-based getter used in version 0.7.6 only. - - Remains here for pickle compatibility with 0.7.6. - - """ - - def __init__(self, colkeys): - self.colkeys = colkeys - self.composite = len(colkeys) > 1 - - def __reduce__(self): - return _SerializableColumnGetter, (self.colkeys,) - - def __call__(self, value): - state = base.instance_state(value) - m = base._state_mapper(state) - key = [ - m._get_state_attr_by_column( - state, state.dict, m.mapped_table.columns[k] - ) - for k in self.colkeys - ] - if self.composite: - return tuple(key) - else: - return key[0] - - -class _SerializableColumnGetterV2(_PlainColumnGetter): - """Updated serializable getter which deals with - multi-table mapped classes. - - Two extremely unusual cases are not supported. - Mappings which have tables across multiple metadata - objects, or which are mapped to non-Table selectables - linked across inheriting mappers may fail to function - here. - - """ - - def __init__(self, colkeys): - self.colkeys = colkeys - self.composite = len(colkeys) > 1 - - def __reduce__(self): - return self.__class__, (self.colkeys,) - - @classmethod - def _reduce_from_cols(cls, cols): - def _table_key(c): - if not isinstance(c.table, expression.TableClause): - return None - else: - return c.table.key - - colkeys = [(c.key, _table_key(c)) for c in cols] - return _SerializableColumnGetterV2, (colkeys,) - - def _cols(self, mapper): - cols = [] - metadata = getattr(mapper.local_table, "metadata", None) - for (ckey, tkey) in self.colkeys: - if tkey is None or metadata is None or tkey not in metadata: - cols.append(mapper.local_table.c[ckey]) - else: - cols.append(metadata.tables[tkey].c[ckey]) - return cols - - -def column_mapped_collection(mapping_spec): - """A dictionary-based collection type with column-based keying. - - Returns a :class:`.MappedCollection` factory with a keying function - generated from mapping_spec, which may be a Column or a sequence - of Columns. - - The key value must be immutable for the lifetime of the object. You - can not, for example, map on foreign key values if those key values will - change during the session, i.e. from None to a database-assigned integer - after a session flush. - - """ - cols = [ - coercions.expect(roles.ColumnArgumentRole, q, argname="mapping_spec") - for q in util.to_list(mapping_spec) - ] - keyfunc = _PlainColumnGetter(cols) - return lambda: MappedCollection(keyfunc) - - -class _SerializableAttrGetter: - def __init__(self, name): - self.name = name - self.getter = operator.attrgetter(name) - - def __call__(self, target): - return self.getter(target) - - def __reduce__(self): - return _SerializableAttrGetter, (self.name,) - - -def attribute_mapped_collection(attr_name): - """A dictionary-based collection type with attribute-based keying. - - Returns a :class:`.MappedCollection` factory with a keying based on the - 'attr_name' attribute of entities in the collection, where ``attr_name`` - is the string name of the attribute. - - .. warning:: the key value must be assigned to its final value - **before** it is accessed by the attribute mapped collection. - Additionally, changes to the key attribute are **not tracked** - automatically, which means the key in the dictionary is not - automatically synchronized with the key value on the target object - itself. See the section :ref:`key_collections_mutations` - for an example. - - """ - getter = _SerializableAttrGetter(attr_name) - return lambda: MappedCollection(getter) - - -def mapped_collection(keyfunc): - """A dictionary-based collection type with arbitrary keying. - - Returns a :class:`.MappedCollection` factory with a keying function - generated from keyfunc, a callable that takes an entity and returns a - key value. - - The key value must be immutable for the lifetime of the object. You - can not, for example, map on foreign key values if those key values will - change during the session, i.e. from None to a database-assigned integer - after a session flush. - - """ - return lambda: MappedCollection(keyfunc) - - class collection: """Decorators for entity collection classes. @@ -1620,63 +1448,24 @@ __interfaces = { } -class MappedCollection(dict): - """A basic dictionary-based collection class. - - Extends dict with the minimal bag semantics that collection - classes require. ``set`` and ``remove`` are implemented in terms - of a keying function: any callable that takes an object and - returns an object for use as a dictionary key. - - """ - - def __init__(self, keyfunc): - """Create a new collection with keying provided by keyfunc. +def __go(lcls): - keyfunc may be any callable that takes an object and returns an object - for use as a dictionary key. + global mapped_collection, column_mapped_collection + global attribute_mapped_collection, MappedCollection - The keyfunc will be called every time the ORM needs to add a member by - value-only (such as when loading instances from the database) or - remove a member. The usual cautions about dictionary keying apply- - ``keyfunc(object)`` should return the same output for the life of the - collection. Keying based on mutable properties can result in - unreachable instances "lost" in the collection. + from .mapped_collection import mapped_collection + from .mapped_collection import column_mapped_collection + from .mapped_collection import attribute_mapped_collection + from .mapped_collection import MappedCollection - """ - self.keyfunc = keyfunc - - @collection.appender - @collection.internally_instrumented - def set(self, value, _sa_initiator=None): - """Add an item by value, consulting the keyfunc for the key.""" - - key = self.keyfunc(value) - self.__setitem__(key, value, _sa_initiator) - - @collection.remover - @collection.internally_instrumented - def remove(self, value, _sa_initiator=None): - """Remove an item by value, consulting the keyfunc for the key.""" - - key = self.keyfunc(value) - # Let self[key] raise if key is not in this collection - # testlib.pragma exempt:__ne__ - if self[key] != value: - raise sa_exc.InvalidRequestError( - "Can not remove '%s': collection holds '%s' for key '%s'. " - "Possible cause: is the MappedCollection key function " - "based on mutable properties or properties that only obtain " - "values after flush?" % (value, self[key], key) - ) - self.__delitem__(key, _sa_initiator) + # ensure instrumentation is associated with + # these built-in classes; if a user-defined class + # subclasses these and uses @internally_instrumented, + # the superclass is otherwise not instrumented. + # see [ticket:2406]. + _instrument_class(InstrumentedList) + _instrument_class(InstrumentedSet) + _instrument_class(MappedCollection) -# ensure instrumentation is associated with -# these built-in classes; if a user-defined class -# subclasses these and uses @internally_instrumented, -# the superclass is otherwise not instrumented. -# see [ticket:2406]. -_instrument_class(MappedCollection) -_instrument_class(InstrumentedList) -_instrument_class(InstrumentedSet) +__go(locals()) diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 8e9cf66e2..34f291864 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -5,16 +5,18 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php import itertools +from typing import List from . import attributes from . import interfaces from . import loading from .base import _is_aliased_class +from .interfaces import ORMColumnDescription from .interfaces import ORMColumnsClauseRole from .path_registry import PathRegistry from .util import _entity_corresponds_to from .util import _ORMJoin -from .util import aliased +from .util import AliasedClass from .util import Bundle from .util import ORMAdapter from .. import exc as sa_exc @@ -1570,7 +1572,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # when we are here, it means join() was called with an indicator # as to an exact left side, which means a path to a - # RelationshipProperty was given, e.g.: + # Relationship was given, e.g.: # # join(RightEntity, LeftEntity.right) # @@ -1725,7 +1727,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): need_adapter = True # make the right hand side target into an ORM entity - right = aliased(right_mapper, right_selectable) + right = AliasedClass(right_mapper, right_selectable) util.warn_deprecated( "An alias is being generated automatically against " @@ -1750,7 +1752,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # test/orm/inheritance/test_relationships.py. There are also # general overlap cases with many-to-many tables where automatic # aliasing is desirable. - right = aliased(right, flat=True) + right = AliasedClass(right, flat=True) need_adapter = True util.warn( @@ -1910,7 +1912,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def _column_descriptions( query_or_select_stmt, compile_state=None, legacy=False -): +) -> List[ORMColumnDescription]: if compile_state is None: compile_state = ORMSelectCompileState._create_entities_collection( query_or_select_stmt, legacy=legacy diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 59fabb9b6..5ac9966dd 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -11,7 +11,9 @@ import typing from typing import Any from typing import Callable from typing import ClassVar +from typing import Mapping from typing import Optional +from typing import Type from typing import TypeVar from typing import Union import weakref @@ -31,7 +33,7 @@ from .decl_base import _declarative_constructor from .decl_base import _DeferredMapperConfig from .decl_base import _del_attribute from .decl_base import _mapper -from .descriptor_props import SynonymProperty as _orm_synonym +from .descriptor_props import Synonym as _orm_synonym from .mapper import Mapper from .. import exc from .. import inspection @@ -39,14 +41,18 @@ from .. import util from ..sql.elements import SQLCoreOperations from ..sql.schema import MetaData from ..sql.selectable import FromClause +from ..sql.type_api import TypeEngine from ..util import hybridmethod from ..util import hybridproperty +from ..util import typing as compat_typing if typing.TYPE_CHECKING: from .state import InstanceState # noqa _T = TypeVar("_T", bound=Any) +_TypeAnnotationMapType = Mapping[Type, Union[Type[TypeEngine], TypeEngine]] + def has_inherited_table(cls): """Given a class, return True if any of the classes it inherits from has a @@ -67,8 +73,22 @@ def has_inherited_table(cls): return False +class _DynamicAttributesType(type): + def __setattr__(cls, key, value): + if "__mapper__" in cls.__dict__: + _add_attribute(cls, key, value) + else: + type.__setattr__(cls, key, value) + + def __delattr__(cls, key): + if "__mapper__" in cls.__dict__: + _del_attribute(cls, key) + else: + type.__delattr__(cls, key) + + class DeclarativeAttributeIntercept( - type, inspection.Inspectable["Mapper[Any]"] + _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"] ): """Metaclass that may be used in conjunction with the :class:`_orm.DeclarativeBase` class to support addition of class @@ -76,15 +96,16 @@ class DeclarativeAttributeIntercept( """ - def __setattr__(cls, key, value): - _add_attribute(cls, key, value) - - def __delattr__(cls, key): - _del_attribute(cls, key) +class DeclarativeMeta( + _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"] +): + metadata: MetaData + registry: "RegistryType" -class DeclarativeMeta(type, inspection.Inspectable["Mapper[Any]"]): - def __init__(cls, classname, bases, dict_, **kw): + def __init__( + cls, classname: Any, bases: Any, dict_: Any, **kw: Any + ) -> None: # early-consume registry from the initial declarative base, # assign privately to not conflict with subclass attributes named # "registry" @@ -103,12 +124,6 @@ class DeclarativeMeta(type, inspection.Inspectable["Mapper[Any]"]): _as_declarative(reg, cls, dict_) type.__init__(cls, classname, bases, dict_) - def __setattr__(cls, key, value): - _add_attribute(cls, key, value) - - def __delattr__(cls, key): - _del_attribute(cls, key) - def synonym_for(name, map_column=False): """Decorator that produces an :func:`_orm.synonym` @@ -250,6 +265,9 @@ class declared_attr(interfaces._MappedAttribute[_T]): self._cascading = cascading self.__doc__ = fn.__doc__ + def _collect_return_annotation(self) -> Optional[Type[Any]]: + return util.get_annotations(self.fget).get("return") + def __get__(self, instance, owner) -> InstrumentedAttribute[_T]: # the declared_attr needs to make use of a cache that exists # for the span of the declarative scan_attributes() phase. @@ -409,6 +427,11 @@ def _setup_declarative_base(cls): else: metadata = None + if "type_annotation_map" in cls.__dict__: + type_annotation_map = cls.__dict__["type_annotation_map"] + else: + type_annotation_map = None + reg = cls.__dict__.get("registry", None) if reg is not None: if not isinstance(reg, registry): @@ -416,8 +439,18 @@ def _setup_declarative_base(cls): "Declarative base class has a 'registry' attribute that is " "not an instance of sqlalchemy.orm.registry()" ) + elif type_annotation_map is not None: + raise exc.InvalidRequestError( + "Declarative base class has both a 'registry' attribute and a " + "type_annotation_map entry. Per-base type_annotation_maps " + "are not supported. Please apply the type_annotation_map " + "to this registry directly." + ) + else: - reg = registry(metadata=metadata) + reg = registry( + metadata=metadata, type_annotation_map=type_annotation_map + ) cls.registry = reg cls._sa_registry = reg @@ -476,6 +509,44 @@ class DeclarativeBase( mappings. The superclass makes use of the ``__init_subclass__()`` method to set up new classes and metaclasses aren't used. + When first used, the :class:`_orm.DeclarativeBase` class instantiates a new + :class:`_orm.registry` to be used with the base, assuming one was not + provided explicitly. The :class:`_orm.DeclarativeBase` class supports + class-level attributes which act as parameters for the construction of this + registry; such as to indicate a specific :class:`_schema.MetaData` + collection as well as a specific value for + :paramref:`_orm.registry.type_annotation_map`:: + + from typing import Annotation + + from sqlalchemy import BigInteger + from sqlalchemy import MetaData + from sqlalchemy import String + from sqlalchemy.orm import DeclarativeBase + + bigint = Annotation(int, "bigint") + my_metadata = MetaData() + + class Base(DeclarativeBase): + metadata = my_metadata + type_annotation_map = { + str: String().with_variant(String(255), "mysql", "mariadb"), + bigint: BigInteger() + } + + Class-level attributes which may be specified include: + + :param metadata: optional :class:`_schema.MetaData` collection. + If a :class:`_orm.registry` is constructed automatically, this + :class:`_schema.MetaData` collection will be used to construct it. + Otherwise, the local :class:`_schema.MetaData` collection will supercede + that used by an existing :class:`_orm.registry` passed using the + :paramref:`_orm.DeclarativeBase.registry` parameter. + :param type_annotation_map: optional type annotation map that will be + passed to the :class:`_orm.registry` as + :paramref:`_orm.registry.type_annotation_map`. + :param registry: supply a pre-existing :class:`_orm.registry` directly. + .. versionadded:: 2.0 """ @@ -516,12 +587,13 @@ def add_mapped_attribute(target, key, attr): def declarative_base( - metadata=None, + metadata: Optional[MetaData] = None, mapper=None, cls=object, name="Base", - constructor=_declarative_constructor, - class_registry=None, + class_registry: Optional[clsregistry._ClsRegistryType] = None, + type_annotation_map: Optional[_TypeAnnotationMapType] = None, + constructor: Callable[..., None] = _declarative_constructor, metaclass=DeclarativeMeta, ) -> Any: r"""Construct a base class for declarative class definitions. @@ -593,6 +665,14 @@ def declarative_base( to share the same registry of class names for simplified inter-base relationships. + :param type_annotation_map: optional dictionary of Python types to + SQLAlchemy :class:`_types.TypeEngine` classes or instances. This + is used exclusively by the :class:`_orm.MappedColumn` construct + to produce column types based on annotations within the + :class:`_orm.Mapped` type. + + .. versionadded:: 2.0 + :param metaclass: Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__ compatible callable to use as the meta type of the generated @@ -608,6 +688,7 @@ def declarative_base( metadata=metadata, class_registry=class_registry, constructor=constructor, + type_annotation_map=type_annotation_map, ).generate_base( mapper=mapper, cls=cls, @@ -651,9 +732,10 @@ class registry: def __init__( self, - metadata=None, - class_registry=None, - constructor=_declarative_constructor, + metadata: Optional[MetaData] = None, + class_registry: Optional[clsregistry._ClsRegistryType] = None, + type_annotation_map: Optional[_TypeAnnotationMapType] = None, + constructor: Callable[..., None] = _declarative_constructor, ): r"""Construct a new :class:`_orm.registry` @@ -679,6 +761,14 @@ class registry: to share the same registry of class names for simplified inter-base relationships. + :param type_annotation_map: optional dictionary of Python types to + SQLAlchemy :class:`_types.TypeEngine` classes or instances. This + is used exclusively by the :class:`_orm.MappedColumn` construct + to produce column types based on annotations within the + :class:`_orm.Mapped` type. + + .. versionadded:: 2.0 + """ lcl_metadata = metadata or MetaData() @@ -690,7 +780,9 @@ class registry: self._non_primary_mappers = weakref.WeakKeyDictionary() self.metadata = lcl_metadata self.constructor = constructor - + self.type_annotation_map = {} + if type_annotation_map is not None: + self.update_type_annotation_map(type_annotation_map) self._dependents = set() self._dependencies = set() @@ -699,6 +791,25 @@ class registry: with mapperlib._CONFIGURE_MUTEX: mapperlib._mapper_registries[self] = True + def update_type_annotation_map( + self, + type_annotation_map: Mapping[ + Type, Union[Type[TypeEngine], TypeEngine] + ], + ) -> None: + """update the :paramref:`_orm.registry.type_annotation_map` with new + values.""" + + self.type_annotation_map.update( + { + sub_type: sqltype + for typ, sqltype in type_annotation_map.items() + for sub_type in compat_typing.expand_unions( + typ, include_union=True, discard_none=True + ) + } + ) + @property def mappers(self): """read only collection of all :class:`_orm.Mapper` objects.""" @@ -1131,6 +1242,9 @@ class registry: return _mapper(self, class_, local_table, kw) +RegistryType = registry + + def as_declarative(**kw): """ Class decorator which will adapt a given class into a diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index fb736806c..342aa772b 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -5,23 +5,34 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php """Internal implementation for declarative.""" + +from __future__ import annotations + import collections +from typing import Any +from typing import Dict +from typing import Tuple import weakref -from sqlalchemy.orm import attributes -from sqlalchemy.orm import instrumentation +from . import attributes from . import clsregistry from . import exc as orm_exc +from . import instrumentation from . import mapperlib from .attributes import InstrumentedAttribute from .attributes import QueryableAttribute from .base import _is_mapped_class from .base import InspectionAttr -from .descriptor_props import CompositeProperty -from .descriptor_props import SynonymProperty +from .descriptor_props import Composite +from .descriptor_props import Synonym +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MappedAttribute +from .interfaces import _MapsColumns from .interfaces import MapperProperty from .mapper import Mapper as mapper from .properties import ColumnProperty +from .properties import MappedColumn +from .util import _is_mapped_annotation from .util import class_mapper from .. import event from .. import exc @@ -130,7 +141,7 @@ def _mapper(registry, cls, table, mapper_kw): @util.preload_module("sqlalchemy.orm.decl_api") -def _is_declarative_props(obj): +def _is_declarative_props(obj: Any) -> bool: declared_attr = util.preloaded.orm_decl_api.declared_attr return isinstance(obj, (declared_attr, util.classproperty)) @@ -208,7 +219,7 @@ class _MapperConfig: class _ImperativeMapperConfig(_MapperConfig): - __slots__ = ("dict_", "local_table", "inherits") + __slots__ = ("local_table", "inherits") def __init__( self, @@ -221,7 +232,6 @@ class _ImperativeMapperConfig(_MapperConfig): registry, cls_, mapper_kw ) - self.dict_ = {} self.local_table = self.set_cls_attribute("__table__", table) with mapperlib._CONFIGURE_MUTEX: @@ -277,7 +287,10 @@ class _ImperativeMapperConfig(_MapperConfig): class _ClassScanMapperConfig(_MapperConfig): __slots__ = ( - "dict_", + "registry", + "clsdict_view", + "collected_attributes", + "collected_annotations", "local_table", "persist_selectable", "declared_columns", @@ -299,11 +312,17 @@ class _ClassScanMapperConfig(_MapperConfig): ): super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw) - - self.dict_ = dict(dict_) if dict_ else {} + self.registry = registry self.persist_selectable = None - self.declared_columns = set() + + self.clsdict_view = ( + util.immutabledict(dict_) if dict_ else util.EMPTY_DICT + ) + self.collected_attributes = {} + self.collected_annotations: Dict[str, Tuple[Any, bool]] = {} + self.declared_columns = util.OrderedSet() self.column_copies = {} + self._setup_declared_events() self._scan_attributes() @@ -407,6 +426,19 @@ class _ClassScanMapperConfig(_MapperConfig): return attribute_is_overridden + _skip_attrs = frozenset( + [ + "__module__", + "__annotations__", + "__doc__", + "__dict__", + "__weakref__", + "_sa_class_manager", + "__dict__", + "__weakref__", + ] + ) + def _cls_attr_resolver(self, cls): """produce a function to iterate the "attributes" of a class, adjusting for SQLAlchemy fields embedded in dataclass fields. @@ -416,31 +448,52 @@ class _ClassScanMapperConfig(_MapperConfig): cls, "__sa_dataclass_metadata_key__", None ) + cls_annotations = util.get_annotations(cls) + + cls_vars = vars(cls) + + skip = self._skip_attrs + + names = util.merge_lists_w_ordering( + [n for n in cls_vars if n not in skip], list(cls_annotations) + ) if sa_dataclass_metadata_key is None: def local_attributes_for_class(): - for name, obj in vars(cls).items(): - yield name, obj, False + return ( + ( + name, + cls_vars.get(name), + cls_annotations.get(name), + False, + ) + for name in names + ) else: - field_names = set() + dataclass_fields = { + field.name: field for field in util.local_dataclass_fields(cls) + } def local_attributes_for_class(): - for field in util.local_dataclass_fields(cls): - if sa_dataclass_metadata_key in field.metadata: - field_names.add(field.name) + for name in names: + field = dataclass_fields.get(name, None) + if field and sa_dataclass_metadata_key in field.metadata: yield field.name, _as_dc_declaredattr( field.metadata, sa_dataclass_metadata_key - ), True - for name, obj in vars(cls).items(): - if name not in field_names: - yield name, obj, False + ), cls_annotations.get(field.name), True + else: + yield name, cls_vars.get(name), cls_annotations.get( + name + ), False return local_attributes_for_class def _scan_attributes(self): cls = self.cls - dict_ = self.dict_ + + clsdict_view = self.clsdict_view + collected_attributes = self.collected_attributes column_copies = self.column_copies mapper_args_fn = None table_args = inherited_table_args = None @@ -462,10 +515,16 @@ class _ClassScanMapperConfig(_MapperConfig): if not class_mapped and base is not cls: self._produce_column_copies( - local_attributes_for_class, attribute_is_overridden + local_attributes_for_class, + attribute_is_overridden, ) - for name, obj, is_dataclass in local_attributes_for_class(): + for ( + name, + obj, + annotation, + is_dataclass, + ) in local_attributes_for_class(): if name == "__mapper_args__": check_decl = _check_declared_props_nocascade( obj, name, cls @@ -514,7 +573,12 @@ class _ClassScanMapperConfig(_MapperConfig): elif base is not cls: # we're a mixin, abstract base, or something that is # acting like that for now. - if isinstance(obj, Column): + + if isinstance(obj, (Column, MappedColumn)): + self.collected_annotations[name] = ( + annotation, + False, + ) # already copied columns to the mapped class. continue elif isinstance(obj, MapperProperty): @@ -526,8 +590,12 @@ class _ClassScanMapperConfig(_MapperConfig): "field() objects, use a lambda:" ) elif _is_declarative_props(obj): + # tried to get overloads to tell this to + # pylance, no luck + assert obj is not None + if obj._cascading: - if name in dict_: + if name in clsdict_view: # unfortunately, while we can use the user- # defined attribute here to allow a clean # override, if there's another @@ -541,7 +609,7 @@ class _ClassScanMapperConfig(_MapperConfig): "@declared_attr.cascading; " "skipping" % (name, cls) ) - dict_[name] = column_copies[ + collected_attributes[name] = column_copies[ obj ] = ret = obj.__get__(obj, cls) setattr(cls, name, ret) @@ -579,19 +647,36 @@ class _ClassScanMapperConfig(_MapperConfig): ): ret = ret.descriptor - dict_[name] = column_copies[obj] = ret + collected_attributes[name] = column_copies[ + obj + ] = ret if ( isinstance(ret, (Column, MapperProperty)) and ret.doc is None ): ret.doc = obj.__doc__ - # here, the attribute is some other kind of property that - # we assume is not part of the declarative mapping. - # however, check for some more common mistakes + + self.collected_annotations[name] = ( + obj._collect_return_annotation(), + False, + ) + elif _is_mapped_annotation(annotation, cls): + self.collected_annotations[name] = ( + annotation, + is_dataclass, + ) + if obj is None: + collected_attributes[name] = MappedColumn() + else: + collected_attributes[name] = obj else: + # here, the attribute is some other kind of + # property that we assume is not part of the + # declarative mapping. however, check for some + # more common mistakes self._warn_for_decl_attributes(base, name, obj) elif is_dataclass and ( - name not in dict_ or dict_[name] is not obj + name not in clsdict_view or clsdict_view[name] is not obj ): # here, we are definitely looking at the target class # and not a superclass. this is currently a @@ -606,7 +691,20 @@ class _ClassScanMapperConfig(_MapperConfig): if _is_declarative_props(obj): obj = obj.fget() - dict_[name] = obj + collected_attributes[name] = obj + self.collected_annotations[name] = ( + annotation, + True, + ) + else: + self.collected_annotations[name] = ( + annotation, + False, + ) + if obj is None and _is_mapped_annotation(annotation, cls): + collected_attributes[name] = MappedColumn() + elif name in clsdict_view: + collected_attributes[name] = obj if inherited_table_args and not tablename: table_args = None @@ -618,46 +716,55 @@ class _ClassScanMapperConfig(_MapperConfig): def _warn_for_decl_attributes(self, cls, key, c): if isinstance(c, expression.ColumnClause): util.warn( - "Attribute '%s' on class %s appears to be a non-schema " - "'sqlalchemy.sql.column()' " + f"Attribute '{key}' on class {cls} appears to " + "be a non-schema 'sqlalchemy.sql.column()' " "object; this won't be part of the declarative mapping" - % (key, cls) ) def _produce_column_copies( self, attributes_for_class, attribute_is_overridden ): cls = self.cls - dict_ = self.dict_ + dict_ = self.clsdict_view + collected_attributes = self.collected_attributes column_copies = self.column_copies # copy mixin columns to the mapped class - for name, obj, is_dataclass in attributes_for_class(): - if isinstance(obj, Column): + for name, obj, annotation, is_dataclass in attributes_for_class(): + if isinstance(obj, (Column, MappedColumn)): if attribute_is_overridden(name, obj): # if column has been overridden # (like by the InstrumentedAttribute of the # superclass), skip continue - elif obj.foreign_keys: - raise exc.InvalidRequestError( - "Columns with foreign keys to other columns " - "must be declared as @declared_attr callables " - "on declarative mixin classes. For dataclass " - "field() objects, use a lambda:." - ) elif name not in dict_ and not ( "__table__" in dict_ and (obj.name or name) in dict_["__table__"].c ): + if obj.foreign_keys: + for fk in obj.foreign_keys: + if ( + fk._table_column is not None + and fk._table_column.table is None + ): + raise exc.InvalidRequestError( + "Columns with foreign keys to " + "non-table-bound " + "columns must be declared as " + "@declared_attr callables " + "on declarative mixin classes. " + "For dataclass " + "field() objects, use a lambda:." + ) + column_copies[obj] = copy_ = obj._copy() - copy_._creation_order = obj._creation_order + collected_attributes[name] = copy_ + setattr(cls, name, copy_) - dict_[name] = copy_ def _extract_mappable_attributes(self): cls = self.cls - dict_ = self.dict_ + collected_attributes = self.collected_attributes our_stuff = self.properties @@ -665,13 +772,17 @@ class _ClassScanMapperConfig(_MapperConfig): cls, "_sa_decl_prepare_nocascade", strict=True ) - for k in list(dict_): + for k in list(collected_attributes): if k in ("__table__", "__tablename__", "__mapper_args__"): continue - value = dict_[k] + value = collected_attributes[k] + if _is_declarative_props(value): + # @declared_attr in collected_attributes only occurs here for a + # @declared_attr that's directly on the mapped class; + # for a mixin, these have already been evaluated if value._cascading: util.warn( "Use of @declared_attr.cascading only applies to " @@ -689,13 +800,13 @@ class _ClassScanMapperConfig(_MapperConfig): ): # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() - value = SynonymProperty(value.key) + value = Synonym(value.key) setattr(cls, k, value) if ( isinstance(value, tuple) and len(value) == 1 - and isinstance(value[0], (Column, MapperProperty)) + and isinstance(value[0], (Column, _MappedAttribute)) ): util.warn( "Ignoring declarative-like tuple value of attribute " @@ -703,12 +814,12 @@ class _ClassScanMapperConfig(_MapperConfig): "accidentally placed at the end of the line?" % k ) continue - elif not isinstance(value, (Column, MapperProperty)): + elif not isinstance(value, (Column, MapperProperty, _MapsColumns)): # using @declared_attr for some object that - # isn't Column/MapperProperty; remove from the dict_ + # isn't Column/MapperProperty; remove from the clsdict_view # and place the evaluated value onto the class. if not k.startswith("__"): - dict_.pop(k) + collected_attributes.pop(k) self._warn_for_decl_attributes(cls, k, value) if not late_mapped: setattr(cls, k, value) @@ -722,27 +833,37 @@ class _ClassScanMapperConfig(_MapperConfig): "for the MetaData instance when using a " "declarative base class." ) + elif isinstance(value, _IntrospectsAnnotations): + annotation, is_dataclass = self.collected_annotations.get( + k, (None, None) + ) + value.declarative_scan( + self.registry, cls, k, annotation, is_dataclass + ) our_stuff[k] = value def _extract_declared_columns(self): our_stuff = self.properties - # set up attributes in the order they were created - util.sort_dictionary( - our_stuff, key=lambda key: our_stuff[key]._creation_order - ) - # extract columns from the class dict declared_columns = self.declared_columns name_to_prop_key = collections.defaultdict(set) for key, c in list(our_stuff.items()): - if isinstance(c, (ColumnProperty, CompositeProperty)): - for col in c.columns: - if isinstance(col, Column) and col.table is None: - _undefer_column_name(key, col) - if not isinstance(c, CompositeProperty): - name_to_prop_key[col.name].add(key) - declared_columns.add(col) + if isinstance(c, _MapsColumns): + for col in c.columns_to_assign: + if not isinstance(c, Composite): + name_to_prop_key[col.name].add(key) + declared_columns.add(col) + + # remove object from the dictionary that will be passed + # as mapper(properties={...}) if it is not a MapperProperty + # (i.e. this currently means it's a MappedColumn) + mp_to_assign = c.mapper_property_to_assign + if mp_to_assign: + our_stuff[key] = mp_to_assign + else: + del our_stuff[key] + elif isinstance(c, Column): _undefer_column_name(key, c) name_to_prop_key[c.name].add(key) @@ -769,16 +890,12 @@ class _ClassScanMapperConfig(_MapperConfig): cls = self.cls tablename = self.tablename table_args = self.table_args - dict_ = self.dict_ + clsdict_view = self.clsdict_view declared_columns = self.declared_columns manager = attributes.manager_of_class(cls) - declared_columns = self.declared_columns = sorted( - declared_columns, key=lambda c: c._creation_order - ) - - if "__table__" not in dict_ and table is None: + if "__table__" not in clsdict_view and table is None: if hasattr(cls, "__table_cls__"): table_cls = util.unbound_method_to_callable(cls.__table_cls__) else: @@ -796,11 +913,11 @@ class _ClassScanMapperConfig(_MapperConfig): else: args = table_args - autoload_with = dict_.get("__autoload_with__") + autoload_with = clsdict_view.get("__autoload_with__") if autoload_with: table_kw["autoload_with"] = autoload_with - autoload = dict_.get("__autoload__") + autoload = clsdict_view.get("__autoload__") if autoload: table_kw["autoload"] = True @@ -1095,18 +1212,21 @@ def _add_attribute(cls, key, value): _undefer_column_name(key, value) cls.__table__.append_column(value, replace_existing=True) cls.__mapper__.add_property(key, value) - elif isinstance(value, ColumnProperty): - for col in value.columns: - if isinstance(col, Column) and col.table is None: - _undefer_column_name(key, col) - cls.__table__.append_column(col, replace_existing=True) - cls.__mapper__.add_property(key, value) + elif isinstance(value, _MapsColumns): + mp = value.mapper_property_to_assign + for col in value.columns_to_assign: + _undefer_column_name(key, col) + cls.__table__.append_column(col, replace_existing=True) + if not mp: + cls.__mapper__.add_property(key, col) + if mp: + cls.__mapper__.add_property(key, mp) elif isinstance(value, MapperProperty): cls.__mapper__.add_property(key, value) elif isinstance(value, QueryableAttribute) and value.key != key: # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() - value = SynonymProperty(value.key) + value = Synonym(value.key) cls.__mapper__.add_property(key, value) else: type.__setattr__(cls, key, value) @@ -1124,7 +1244,7 @@ def _del_attribute(cls, key): ): value = cls.__dict__[key] if isinstance( - value, (Column, ColumnProperty, MapperProperty, QueryableAttribute) + value, (Column, _MapsColumns, MapperProperty, QueryableAttribute) ): raise NotImplementedError( "Can't un-map individual mapped attributes on a mapped class." diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 5e67b64cd..4526a8b33 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -10,14 +10,26 @@ that exist as configurational elements, but don't participate as actively in the load/persist ORM loop. """ +import inspect +import itertools +import operator +import typing from typing import Any -from typing import Type +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple from typing import TypeVar +from typing import Union from . import attributes from . import util as orm_util +from .base import Mapped +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MapsColumns from .interfaces import MapperProperty from .interfaces import PropComparator +from .util import _extract_mapped_subtype from .util import _none_set from .. import event from .. import exc as sa_exc @@ -27,6 +39,9 @@ from .. import util from ..sql import expression from ..sql import operators +if typing.TYPE_CHECKING: + from .properties import MappedColumn + _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) @@ -92,30 +107,48 @@ class DescriptorProperty(MapperProperty[_T]): mapper.class_manager.instrument_attribute(self.key, proxy_attr) -class CompositeProperty(DescriptorProperty[_T]): +class Composite( + _MapsColumns[_T], _IntrospectsAnnotations, DescriptorProperty[_T] +): """Defines a "composite" mapped attribute, representing a collection of columns as one attribute. - :class:`.CompositeProperty` is constructed using the :func:`.composite` + :class:`.Composite` is constructed using the :func:`.composite` function. + .. versionchanged:: 2.0 Renamed :class:`_orm.CompositeProperty` + to :class:`_orm.Composite`. The old name + :class:`_orm.CompositeProperty` remains as an alias. + .. seealso:: :ref:`mapper_composite` """ - def __init__(self, class_: Type[_T], *attrs, **kwargs): - super(CompositeProperty, self).__init__() + composite_class: Union[type, Callable[..., type]] + attrs: Tuple[ + Union[sql.ColumnElement[Any], "MappedColumn", str, Mapped[Any]], ... + ] + + def __init__(self, class_=None, *attrs, **kwargs): + super().__init__() + + if isinstance(class_, (Mapped, str, sql.ColumnElement)): + self.attrs = (class_,) + attrs + # will initialize within declarative_scan + self.composite_class = None # type: ignore + else: + self.composite_class = class_ + self.attrs = attrs - self.attrs = attrs - self.composite_class = class_ self.active_history = kwargs.get("active_history", False) self.deferred = kwargs.get("deferred", False) self.group = kwargs.get("group", None) self.comparator_factory = kwargs.pop( "comparator_factory", self.__class__.Comparator ) + self._generated_composite_accessor = None if "info" in kwargs: self.info = kwargs.pop("info") @@ -123,11 +156,26 @@ class CompositeProperty(DescriptorProperty[_T]): self._create_descriptor() def instrument_class(self, mapper): - super(CompositeProperty, self).instrument_class(mapper) + super().instrument_class(mapper) self._setup_event_handlers() + def _composite_values_from_instance(self, value): + if self._generated_composite_accessor: + return self._generated_composite_accessor(value) + else: + try: + accessor = value.__composite_values__ + except AttributeError as ae: + raise sa_exc.InvalidRequestError( + f"Composite class {self.composite_class.__name__} is not " + f"a dataclass and does not define a __composite_values__()" + " method; can't get state" + ) from ae + else: + return accessor() + def do_init(self): - """Initialization which occurs after the :class:`.CompositeProperty` + """Initialization which occurs after the :class:`.Composite` has been associated with its parent mapper. """ @@ -181,7 +229,8 @@ class CompositeProperty(DescriptorProperty[_T]): setattr(instance, key, None) else: for key, value in zip( - self._attribute_keys, value.__composite_values__() + self._attribute_keys, + self._composite_values_from_instance(value), ): setattr(instance, key, value) @@ -196,18 +245,74 @@ class CompositeProperty(DescriptorProperty[_T]): self.descriptor = property(fget, fset, fdel) + @util.preload_module("sqlalchemy.orm.properties") + @util.preload_module("sqlalchemy.orm.decl_base") + def declarative_scan( + self, registry, cls, key, annotation, is_dataclass_field + ): + MappedColumn = util.preloaded.orm_properties.MappedColumn + decl_base = util.preloaded.orm_decl_base + + argument = _extract_mapped_subtype( + annotation, + cls, + key, + MappedColumn, + self.composite_class is None, + is_dataclass_field, + ) + + if argument and self.composite_class is None: + if isinstance(argument, str) or hasattr( + argument, "__forward_arg__" + ): + raise sa_exc.ArgumentError( + f"Can't use forward ref {argument} for composite " + f"class argument" + ) + self.composite_class = argument + insp = inspect.signature(self.composite_class) + for param, attr in itertools.zip_longest( + insp.parameters.values(), self.attrs + ): + if param is None or attr is None: + raise sa_exc.ArgumentError( + f"number of arguments to {self.composite_class.__name__} " + f"class and number of attributes don't match" + ) + if isinstance(attr, MappedColumn): + attr.declarative_scan_for_composite( + registry, cls, key, param.name, param.annotation + ) + elif isinstance(attr, schema.Column): + decl_base._undefer_column_name(param.name, attr) + + if not hasattr(cls, "__composite_values__"): + getter = operator.attrgetter( + *[p.name for p in insp.parameters.values()] + ) + if len(insp.parameters) == 1: + self._generated_composite_accessor = lambda obj: (getter(obj),) + else: + self._generated_composite_accessor = getter + @util.memoized_property def _comparable_elements(self): return [getattr(self.parent.class_, prop.key) for prop in self.props] @util.memoized_property + @util.preload_module("orm.properties") def props(self): props = [] + MappedColumn = util.preloaded.orm_properties.MappedColumn + for attr in self.attrs: if isinstance(attr, str): prop = self.parent.get_property(attr, _configure_mappers=False) elif isinstance(attr, schema.Column): prop = self.parent._columntoproperty[attr] + elif isinstance(attr, MappedColumn): + prop = self.parent._columntoproperty[attr.column] elif isinstance(attr, attributes.InstrumentedAttribute): prop = attr.property else: @@ -220,8 +325,22 @@ class CompositeProperty(DescriptorProperty[_T]): return props @property + @util.preload_module("orm.properties") def columns(self): - return [a for a in self.attrs if isinstance(a, schema.Column)] + MappedColumn = util.preloaded.orm_properties.MappedColumn + return [ + a.column if isinstance(a, MappedColumn) else a + for a in self.attrs + if isinstance(a, (schema.Column, MappedColumn)) + ] + + @property + def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + return self + + @property + def columns_to_assign(self) -> List[schema.Column]: + return [c for c in self.columns if c.table is None] def _setup_arguments_on_columns(self): """Propagate configuration arguments made on this composite @@ -351,9 +470,7 @@ class CompositeProperty(DescriptorProperty[_T]): class CompositeBundle(orm_util.Bundle): def __init__(self, property_, expr): self.property = property_ - super(CompositeProperty.CompositeBundle, self).__init__( - property_.key, *expr - ) + super().__init__(property_.key, *expr) def create_row_processor(self, query, procs, labels): def proc(row): @@ -365,7 +482,7 @@ class CompositeProperty(DescriptorProperty[_T]): class Comparator(PropComparator[_PT]): """Produce boolean, comparison, and other operators for - :class:`.CompositeProperty` attributes. + :class:`.Composite` attributes. See the example in :ref:`composite_operations` for an overview of usage , as well as the documentation for :class:`.PropComparator`. @@ -402,7 +519,7 @@ class CompositeProperty(DescriptorProperty[_T]): "proxy_key": self.prop.key, } ) - return CompositeProperty.CompositeBundle(self.prop, clauses) + return Composite.CompositeBundle(self.prop, clauses) def _bulk_update_tuples(self, value): if isinstance(value, sql.elements.BindParameter): @@ -411,7 +528,7 @@ class CompositeProperty(DescriptorProperty[_T]): if value is None: values = [None for key in self.prop._attribute_keys] elif isinstance(value, self.prop.composite_class): - values = value.__composite_values__() + values = self.prop._composite_values_from_instance(value) else: raise sa_exc.ArgumentError( "Can't UPDATE composite attribute %s to %r" @@ -434,7 +551,7 @@ class CompositeProperty(DescriptorProperty[_T]): if other is None: values = [None] * len(self.prop._comparable_elements) else: - values = other.__composite_values__() + values = self.prop._composite_values_from_instance(other) comparisons = [ a == b for a, b in zip(self.prop._comparable_elements, values) ] @@ -477,7 +594,7 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]): return comparator_callable def __init__(self): - super(ConcreteInheritedProperty, self).__init__() + super().__init__() def warn(): raise AttributeError( @@ -502,7 +619,24 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]): self.descriptor = NoninheritedConcreteProp() -class SynonymProperty(DescriptorProperty[_T]): +class Synonym(DescriptorProperty[_T]): + """Denote an attribute name as a synonym to a mapped property, + in that the attribute will mirror the value and expression behavior + of another attribute. + + :class:`.Synonym` is constructed using the :func:`_orm.synonym` + function. + + .. versionchanged:: 2.0 Renamed :class:`_orm.SynonymProperty` + to :class:`_orm.Synonym`. The old name + :class:`_orm.SynonymProperty` remains as an alias. + + .. seealso:: + + :ref:`synonyms` - Overview of synonyms + + """ + def __init__( self, name, @@ -512,7 +646,7 @@ class SynonymProperty(DescriptorProperty[_T]): doc=None, info=None, ): - super(SynonymProperty, self).__init__() + super().__init__() self.name = name self.map_column = map_column diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index ade47480d..3d9c61c20 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -28,7 +28,7 @@ from ..engine import result @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="dynamic") +@relationships.Relationship.strategy_for(lazy="dynamic") class DynaLoader(strategies.AbstractRelationshipLoader): def init_class_attribute(self, mapper): self.is_class_level = True diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index b9a5aaf51..1f9ec78f7 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -20,7 +20,12 @@ import collections import typing from typing import Any from typing import cast +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type from typing import TypeVar +from typing import Union from . import exc as orm_exc from . import path_registry @@ -41,8 +46,15 @@ from .. import util from ..sql import operators from ..sql import roles from ..sql import visitors +from ..sql._typing import _ColumnsClauseElement from ..sql.base import ExecutableOption from ..sql.cache_key import HasCacheKey +from ..sql.schema import Column +from ..sql.type_api import TypeEngine +from ..util.typing import TypedDict + +if typing.TYPE_CHECKING: + from .decl_api import RegistryType _T = TypeVar("_T", bound=Any) @@ -85,6 +97,54 @@ class ORMFromClauseRole(roles.StrictFromClauseRole): _role_name = "ORM mapped entity, aliased entity, or FROM expression" +class ORMColumnDescription(TypedDict): + name: str + type: Union[Type, TypeEngine] + aliased: bool + expr: _ColumnsClauseElement + entity: Optional[_ColumnsClauseElement] + + +class _IntrospectsAnnotations: + __slots__ = () + + def declarative_scan( + self, + registry: "RegistryType", + cls: type, + key: str, + annotation: Optional[type], + is_dataclass_field: Optional[bool], + ) -> None: + """Perform class-specific initializaton at early declarative scanning + time. + + .. versionadded:: 2.0 + + """ + + +class _MapsColumns(_MappedAttribute[_T]): + """interface for declarative-capable construct that delivers one or more + Column objects to the declarative process to be part of a Table. + """ + + __slots__ = () + + @property + def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + """return a MapperProperty to be assigned to the declarative mapping""" + raise NotImplementedError() + + @property + def columns_to_assign(self) -> List[Column]: + """A list of Column objects that should be declaratively added to the + new Table object. + + """ + raise NotImplementedError() + + @inspection._self_inspects class MapperProperty( HasCacheKey, _MappedAttribute[_T], InspectionAttr, util.MemoizedSlots @@ -96,7 +156,7 @@ class MapperProperty( an instance of :class:`.ColumnProperty`, and a reference to another class produced by :func:`_orm.relationship`, represented in the mapping as an instance of - :class:`.RelationshipProperty`. + :class:`.Relationship`. """ @@ -118,7 +178,7 @@ class MapperProperty( This collection is checked before the 'cascade_iterator' method is called. - The collection typically only applies to a RelationshipProperty. + The collection typically only applies to a Relationship. """ @@ -132,7 +192,7 @@ class MapperProperty( def _links_to_entity(self): """True if this MapperProperty refers to a mapped entity. - Should only be True for RelationshipProperty, False for all others. + Should only be True for Relationship, False for all others. """ raise NotImplementedError() @@ -189,7 +249,7 @@ class MapperProperty( Note that the 'cascade' collection on this MapperProperty is checked first for the given type before cascade_iterator is called. - This method typically only applies to RelationshipProperty. + This method typically only applies to Relationship. """ @@ -323,7 +383,7 @@ class PropComparator( be redefined at both the Core and ORM level. :class:`.PropComparator` is the base class of operator redefinition for ORM-level operations, including those of :class:`.ColumnProperty`, - :class:`.RelationshipProperty`, and :class:`.CompositeProperty`. + :class:`.Relationship`, and :class:`.Composite`. User-defined subclasses of :class:`.PropComparator` may be created. The built-in Python comparison and math operator methods, such as @@ -339,19 +399,19 @@ class PropComparator( from sqlalchemy.orm.properties import \ ColumnProperty,\ - CompositeProperty,\ - RelationshipProperty + Composite,\ + Relationship class MyColumnComparator(ColumnProperty.Comparator): def __eq__(self, other): return self.__clause_element__() == other - class MyRelationshipComparator(RelationshipProperty.Comparator): + class MyRelationshipComparator(Relationship.Comparator): def any(self, expression): "define the 'any' operation" # ... - class MyCompositeComparator(CompositeProperty.Comparator): + class MyCompositeComparator(Composite.Comparator): def __gt__(self, other): "redefine the 'greater than' operation" @@ -386,9 +446,9 @@ class PropComparator( :class:`.ColumnProperty.Comparator` - :class:`.RelationshipProperty.Comparator` + :class:`.Relationship.Comparator` - :class:`.CompositeProperty.Comparator` + :class:`.Composite.Comparator` :class:`.ColumnOperators` @@ -552,7 +612,7 @@ class PropComparator( given criterion. The usual implementation of ``any()`` is - :meth:`.RelationshipProperty.Comparator.any`. + :meth:`.Relationship.Comparator.any`. :param criterion: an optional ClauseElement formulated against the member class' table or attributes. @@ -570,7 +630,7 @@ class PropComparator( given criterion. The usual implementation of ``has()`` is - :meth:`.RelationshipProperty.Comparator.has`. + :meth:`.Relationship.Comparator.has`. :param criterion: an optional ClauseElement formulated against the member class' table or attributes. @@ -606,10 +666,13 @@ class StrategizedProperty(MapperProperty[_T]): "strategy", "_wildcard_token", "_default_path_loader_key", + "strategy_key", ) inherit_cache = True strategy_wildcard_key = None + strategy_key: Tuple[Any, ...] + def _memoized_attr__wildcard_token(self): return ( f"{self.strategy_wildcard_key}:{path_registry._WILDCARD_TOKEN}", diff --git a/lib/sqlalchemy/orm/mapped_collection.py b/lib/sqlalchemy/orm/mapped_collection.py new file mode 100644 index 000000000..75abeef4c --- /dev/null +++ b/lib/sqlalchemy/orm/mapped_collection.py @@ -0,0 +1,232 @@ +# orm/collections.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import operator +from typing import Any +from typing import Callable +from typing import Dict +from typing import Type +from typing import TypeVar + +from . import base +from .collections import collection +from .. import exc as sa_exc +from .. import util +from ..sql import coercions +from ..sql import expression +from ..sql import roles + +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + + +class _PlainColumnGetter: + """Plain column getter, stores collection of Column objects + directly. + + Serializes to a :class:`._SerializableColumnGetterV2` + which has more expensive __call__() performance + and some rare caveats. + + """ + + __slots__ = ("cols", "composite") + + def __init__(self, cols): + self.cols = cols + self.composite = len(cols) > 1 + + def __reduce__(self): + return _SerializableColumnGetterV2._reduce_from_cols(self.cols) + + def _cols(self, mapper): + return self.cols + + def __call__(self, value): + state = base.instance_state(value) + m = base._state_mapper(state) + + key = [ + m._get_state_attr_by_column(state, state.dict, col) + for col in self._cols(m) + ] + + if self.composite: + return tuple(key) + else: + return key[0] + + +class _SerializableColumnGetterV2(_PlainColumnGetter): + """Updated serializable getter which deals with + multi-table mapped classes. + + Two extremely unusual cases are not supported. + Mappings which have tables across multiple metadata + objects, or which are mapped to non-Table selectables + linked across inheriting mappers may fail to function + here. + + """ + + __slots__ = ("colkeys",) + + def __init__(self, colkeys): + self.colkeys = colkeys + self.composite = len(colkeys) > 1 + + def __reduce__(self): + return self.__class__, (self.colkeys,) + + @classmethod + def _reduce_from_cols(cls, cols): + def _table_key(c): + if not isinstance(c.table, expression.TableClause): + return None + else: + return c.table.key + + colkeys = [(c.key, _table_key(c)) for c in cols] + return _SerializableColumnGetterV2, (colkeys,) + + def _cols(self, mapper): + cols = [] + metadata = getattr(mapper.local_table, "metadata", None) + for (ckey, tkey) in self.colkeys: + if tkey is None or metadata is None or tkey not in metadata: + cols.append(mapper.local_table.c[ckey]) + else: + cols.append(metadata.tables[tkey].c[ckey]) + return cols + + +def column_mapped_collection(mapping_spec): + """A dictionary-based collection type with column-based keying. + + Returns a :class:`.MappedCollection` factory with a keying function + generated from mapping_spec, which may be a Column or a sequence + of Columns. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + + """ + cols = [ + coercions.expect(roles.ColumnArgumentRole, q, argname="mapping_spec") + for q in util.to_list(mapping_spec) + ] + keyfunc = _PlainColumnGetter(cols) + return _mapped_collection_cls(keyfunc) + + +def attribute_mapped_collection(attr_name: str) -> Type["MappedCollection"]: + """A dictionary-based collection type with attribute-based keying. + + Returns a :class:`.MappedCollection` factory with a keying based on the + 'attr_name' attribute of entities in the collection, where ``attr_name`` + is the string name of the attribute. + + .. warning:: the key value must be assigned to its final value + **before** it is accessed by the attribute mapped collection. + Additionally, changes to the key attribute are **not tracked** + automatically, which means the key in the dictionary is not + automatically synchronized with the key value on the target object + itself. See the section :ref:`key_collections_mutations` + for an example. + + """ + getter = operator.attrgetter(attr_name) + return _mapped_collection_cls(getter) + + +def mapped_collection( + keyfunc: Callable[[Any], _KT] +) -> Type["MappedCollection[_KT, Any]"]: + """A dictionary-based collection type with arbitrary keying. + + Returns a :class:`.MappedCollection` factory with a keying function + generated from keyfunc, a callable that takes an entity and returns a + key value. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + + """ + return _mapped_collection_cls(keyfunc) + + +class MappedCollection(Dict[_KT, _VT]): + """A basic dictionary-based collection class. + + Extends dict with the minimal bag semantics that collection + classes require. ``set`` and ``remove`` are implemented in terms + of a keying function: any callable that takes an object and + returns an object for use as a dictionary key. + + """ + + def __init__(self, keyfunc): + """Create a new collection with keying provided by keyfunc. + + keyfunc may be any callable that takes an object and returns an object + for use as a dictionary key. + + The keyfunc will be called every time the ORM needs to add a member by + value-only (such as when loading instances from the database) or + remove a member. The usual cautions about dictionary keying apply- + ``keyfunc(object)`` should return the same output for the life of the + collection. Keying based on mutable properties can result in + unreachable instances "lost" in the collection. + + """ + self.keyfunc = keyfunc + + @classmethod + def _unreduce(cls, keyfunc, values): + mp = MappedCollection(keyfunc) + mp.update(values) + return mp + + def __reduce__(self): + return (MappedCollection._unreduce, (self.keyfunc, dict(self))) + + @collection.appender + @collection.internally_instrumented + def set(self, value, _sa_initiator=None): + """Add an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + self.__setitem__(key, value, _sa_initiator) + + @collection.remover + @collection.internally_instrumented + def remove(self, value, _sa_initiator=None): + """Remove an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + # Let self[key] raise if key is not in this collection + # testlib.pragma exempt:__ne__ + if self[key] != value: + raise sa_exc.InvalidRequestError( + "Can not remove '%s': collection holds '%s' for key '%s'. " + "Possible cause: is the MappedCollection key function " + "based on mutable properties or properties that only obtain " + "values after flush?" % (value, self[key], key) + ) + self.__delitem__(key, _sa_initiator) + + +def _mapped_collection_cls(keyfunc): + class _MKeyfuncMapped(MappedCollection): + def __init__(self): + super().__init__(keyfunc) + + return _MKeyfuncMapped diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index fdf065488..cd0d1e820 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -580,7 +580,16 @@ class Mapper( self.version_id_prop = version_id_col self.version_id_col = None else: - self.version_id_col = version_id_col + self.version_id_col = ( + coercions.expect( + roles.ColumnArgumentOrKeyRole, + version_id_col, + argname="version_id_col", + ) + if version_id_col is not None + else None + ) + if version_id_generator is False: self.version_id_generator = False elif version_id_generator is None: @@ -2473,7 +2482,7 @@ class Mapper( @HasMemoized.memoized_attribute @util.preload_module("sqlalchemy.orm.descriptor_props") def synonyms(self): - """Return a namespace of all :class:`.SynonymProperty` + """Return a namespace of all :class:`.Synonym` properties maintained by this :class:`_orm.Mapper`. .. seealso:: @@ -2485,7 +2494,7 @@ class Mapper( """ descriptor_props = util.preloaded.orm_descriptor_props - return self._filter_properties(descriptor_props.SynonymProperty) + return self._filter_properties(descriptor_props.Synonym) @property def entity_namespace(self): @@ -2508,7 +2517,7 @@ class Mapper( @util.preload_module("sqlalchemy.orm.relationships") @HasMemoized.memoized_attribute def relationships(self): - """A namespace of all :class:`.RelationshipProperty` properties + """A namespace of all :class:`.Relationship` properties maintained by this :class:`_orm.Mapper`. .. warning:: @@ -2531,13 +2540,13 @@ class Mapper( """ return self._filter_properties( - util.preloaded.orm_relationships.RelationshipProperty + util.preloaded.orm_relationships.Relationship ) @HasMemoized.memoized_attribute @util.preload_module("sqlalchemy.orm.descriptor_props") def composites(self): - """Return a namespace of all :class:`.CompositeProperty` + """Return a namespace of all :class:`.Composite` properties maintained by this :class:`_orm.Mapper`. .. seealso:: @@ -2548,7 +2557,7 @@ class Mapper( """ return self._filter_properties( - util.preloaded.orm_descriptor_props.CompositeProperty + util.preloaded.orm_descriptor_props.Composite ) def _filter_properties(self, type_): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index b035dbef2..f28c45fab 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -13,37 +13,60 @@ mapped attributes. """ from typing import Any +from typing import cast +from typing import List +from typing import Optional +from typing import Set from typing import TypeVar from . import attributes from . import strategy_options -from .descriptor_props import CompositeProperty +from .base import SQLCoreOperations +from .descriptor_props import Composite from .descriptor_props import ConcreteInheritedProperty -from .descriptor_props import SynonymProperty +from .descriptor_props import Synonym +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MapsColumns +from .interfaces import MapperProperty from .interfaces import PropComparator from .interfaces import StrategizedProperty -from .relationships import RelationshipProperty +from .relationships import Relationship +from .util import _extract_mapped_subtype from .util import _orm_full_deannotate +from .. import exc as sa_exc +from .. import ForeignKey from .. import log from .. import sql from .. import util from ..sql import coercions +from ..sql import operators from ..sql import roles +from ..sql import sqltypes +from ..sql.schema import Column +from ..util.typing import de_optionalize_union_types +from ..util.typing import de_stringify_annotation +from ..util.typing import is_fwd_ref +from ..util.typing import NoneType _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) __all__ = [ "ColumnProperty", - "CompositeProperty", + "Composite", "ConcreteInheritedProperty", - "RelationshipProperty", - "SynonymProperty", + "Relationship", + "Synonym", ] @log.class_logger -class ColumnProperty(StrategizedProperty[_T]): +class ColumnProperty( + _MapsColumns[_T], + StrategizedProperty[_T], + _IntrospectsAnnotations, + log.Identified, +): """Describes an object attribute that corresponds to a table column. Public constructor is the :func:`_orm.column_property` function. @@ -65,7 +88,6 @@ class ColumnProperty(StrategizedProperty[_T]): "active_history", "expire_on_flush", "doc", - "strategy_key", "_creation_order", "_is_polymorphic_discriminator", "_mapped_by_synonym", @@ -84,8 +106,8 @@ class ColumnProperty(StrategizedProperty[_T]): coercions.expect(roles.LabeledColumnExprRole, c) for c in columns ] self.columns = [ - coercions.expect( - roles.LabeledColumnExprRole, _orm_full_deannotate(c) + _orm_full_deannotate( + coercions.expect(roles.LabeledColumnExprRole, c) ) for c in columns ] @@ -130,6 +152,27 @@ class ColumnProperty(StrategizedProperty[_T]): if self.raiseload: self.strategy_key += (("raiseload", True),) + def declarative_scan( + self, registry, cls, key, annotation, is_dataclass_field + ): + column = self.columns[0] + if column.key is None: + column.key = key + if column.name is None: + column.name = key + + @property + def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + return self + + @property + def columns_to_assign(self) -> List[Column]: + return [ + c + for c in self.columns + if isinstance(c, Column) and c.table is None + ] + def _memoized_attr__renders_in_subqueries(self): return ("deferred", True) not in self.strategy_key or ( self not in self.parent._readonly_props @@ -197,7 +240,7 @@ class ColumnProperty(StrategizedProperty[_T]): ) def do_init(self): - super(ColumnProperty, self).do_init() + super().do_init() if len(self.columns) > 1 and set(self.parent.primary_key).issuperset( self.columns @@ -364,3 +407,135 @@ class ColumnProperty(StrategizedProperty[_T]): if not self.parent or not self.key: return object.__repr__(self) return str(self.parent.class_.__name__) + "." + self.key + + +class MappedColumn( + SQLCoreOperations[_T], + operators.ColumnOperators[SQLCoreOperations], + _IntrospectsAnnotations, + _MapsColumns[_T], +): + """Maps a single :class:`_schema.Column` on a class. + + :class:`_orm.MappedColumn` is a specialization of the + :class:`_orm.ColumnProperty` class and is oriented towards declarative + configuration. + + To construct :class:`_orm.MappedColumn` objects, use the + :func:`_orm.mapped_column` constructor function. + + .. versionadded:: 2.0 + + + """ + + __slots__ = ( + "column", + "_creation_order", + "foreign_keys", + "_has_nullable", + "deferred", + ) + + deferred: bool + column: Column[_T] + foreign_keys: Optional[Set[ForeignKey]] + + def __init__(self, *arg, **kw): + self.deferred = kw.pop("deferred", False) + self.column = cast("Column[_T]", Column(*arg, **kw)) + self.foreign_keys = self.column.foreign_keys + self._has_nullable = "nullable" in kw + util.set_creation_order(self) + + def _copy(self, **kw): + new = self.__class__.__new__(self.__class__) + new.column = self.column._copy(**kw) + new.deferred = self.deferred + new.foreign_keys = new.column.foreign_keys + new._has_nullable = self._has_nullable + util.set_creation_order(new) + return new + + @property + def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + if self.deferred: + return ColumnProperty(self.column, deferred=True) + else: + return None + + @property + def columns_to_assign(self) -> List[Column]: + return [self.column] + + def __clause_element__(self): + return self.column + + def operate(self, op, *other, **kwargs): + return op(self.__clause_element__(), *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + col = self.__clause_element__() + return op(col._bind_param(op, other), col, **kwargs) + + def declarative_scan( + self, registry, cls, key, annotation, is_dataclass_field + ): + column = self.column + if column.key is None: + column.key = key + if column.name is None: + column.name = key + + sqltype = column.type + + argument = _extract_mapped_subtype( + annotation, + cls, + key, + MappedColumn, + sqltype._isnull and not self.column.foreign_keys, + is_dataclass_field, + ) + if argument is None: + return + + self._init_column_for_annotation(cls, registry, argument) + + @util.preload_module("sqlalchemy.orm.decl_base") + def declarative_scan_for_composite( + self, registry, cls, key, param_name, param_annotation + ): + decl_base = util.preloaded.orm_decl_base + decl_base._undefer_column_name(param_name, self.column) + self._init_column_for_annotation(cls, registry, param_annotation) + + def _init_column_for_annotation(self, cls, registry, argument): + sqltype = self.column.type + + nullable = False + + if hasattr(argument, "__origin__"): + nullable = NoneType in argument.__args__ + + if not self._has_nullable: + self.column.nullable = nullable + + if sqltype._isnull and not self.column.foreign_keys: + sqltype = None + our_type = de_optionalize_union_types(argument) + + if is_fwd_ref(our_type): + our_type = de_stringify_annotation(cls, our_type) + + if registry.type_annotation_map: + sqltype = registry.type_annotation_map.get(our_type) + if sqltype is None: + sqltype = sqltypes._type_map_get(our_type) + + if sqltype is None: + raise sa_exc.ArgumentError( + f"Could not locate SQLAlchemy Core " + f"type for Python type: {our_type}" + ) + self.column.type = sqltype diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 15259f130..61174487a 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -21,7 +21,12 @@ database to return iterable result sets. import collections.abc as collections_abc import itertools import operator -import typing +from typing import Any +from typing import Generic +from typing import Iterable +from typing import List +from typing import Optional +from typing import TypeVar from . import exc as orm_exc from . import interfaces @@ -35,8 +40,9 @@ from .context import LABEL_STYLE_LEGACY_ORM from .context import ORMCompileState from .context import ORMFromStatementCompileState from .context import QueryContext +from .interfaces import ORMColumnDescription from .interfaces import ORMColumnsClauseRole -from .util import aliased +from .util import AliasedClass from .util import object_mapper from .util import with_parent from .. import exc as sa_exc @@ -45,16 +51,19 @@ from .. import inspection from .. import log from .. import sql from .. import util +from ..engine import Result from ..sql import coercions from ..sql import expression from ..sql import roles from ..sql import Select from ..sql import util as sql_util from ..sql import visitors +from ..sql._typing import _FromClauseElement from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import _entity_namespace_key from ..sql.base import _generative from ..sql.base import Executable +from ..sql.expression import Exists from ..sql.selectable import _MemoizedSelectEntities from ..sql.selectable import _SelectFromElements from ..sql.selectable import ForUpdateArg @@ -67,9 +76,12 @@ from ..sql.selectable import SelectBase from ..sql.selectable import SelectStatementGrouping from ..sql.visitors import InternalTraversal -__all__ = ["Query", "QueryContext", "aliased"] -SelfQuery = typing.TypeVar("SelfQuery", bound="Query") +__all__ = ["Query", "QueryContext"] + +_T = TypeVar("_T", bound=Any) + +SelfQuery = TypeVar("SelfQuery", bound="Query") @inspection._self_inspects @@ -80,7 +92,9 @@ class Query( HasPrefixes, HasSuffixes, HasHints, + log.Identified, Executable, + Generic[_T], ): """ORM-level SQL construction object. @@ -1040,7 +1054,7 @@ class Query( for prop in mapper.iterate_properties: if ( - isinstance(prop, relationships.RelationshipProperty) + isinstance(prop, relationships.Relationship) and prop.mapper is entity_zero.mapper ): property = prop # noqa @@ -1064,7 +1078,7 @@ class Query( if alias is not None: # TODO: deprecate - entity = aliased(entity, alias) + entity = AliasedClass(entity, alias) self._raw_columns = list(self._raw_columns) @@ -1992,7 +2006,9 @@ class Query( @_generative @_assertions(_no_clauseelement_condition) - def select_from(self: SelfQuery, *from_obj) -> SelfQuery: + def select_from( + self: SelfQuery, *from_obj: _FromClauseElement + ) -> SelfQuery: r"""Set the FROM clause of this :class:`.Query` explicitly. :meth:`.Query.select_from` is often used in conjunction with @@ -2144,7 +2160,7 @@ class Query( self._distinct = True return self - def all(self): + def all(self) -> List[_T]: """Return the results represented by this :class:`_query.Query` as a list. @@ -2183,7 +2199,7 @@ class Query( self._statement = statement return self - def first(self): + def first(self) -> Optional[_T]: """Return the first result of this ``Query`` or None if the result doesn't contain any row. @@ -2209,7 +2225,7 @@ class Query( else: return self.limit(1)._iter().first() - def one_or_none(self): + def one_or_none(self) -> Optional[_T]: """Return at most one result or raise an exception. Returns ``None`` if the query selects @@ -2235,7 +2251,7 @@ class Query( """ return self._iter().one_or_none() - def one(self): + def one(self) -> _T: """Return exactly one result or raise an exception. Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects @@ -2255,7 +2271,7 @@ class Query( """ return self._iter().one() - def scalar(self): + def scalar(self) -> Any: """Return the first element of the first result or None if no rows present. If multiple rows are returned, raises MultipleResultsFound. @@ -2283,7 +2299,7 @@ class Query( except orm_exc.NoResultFound: return None - def __iter__(self): + def __iter__(self) -> Iterable[_T]: return self._iter().__iter__() def _iter(self): @@ -2309,7 +2325,7 @@ class Query( return result - def __str__(self): + def __str__(self) -> str: statement = self._statement_20() try: @@ -2327,7 +2343,7 @@ class Query( return fn(clause=statement, **kw) @property - def column_descriptions(self): + def column_descriptions(self) -> List[ORMColumnDescription]: """Return metadata about the columns which would be returned by this :class:`_query.Query`. @@ -2368,7 +2384,7 @@ class Query( return _column_descriptions(self, legacy=True) - def instances(self, result_proxy, context=None): + def instances(self, result_proxy: Result, context=None) -> Any: """Return an ORM result given a :class:`_engine.CursorResult` and :class:`.QueryContext`. @@ -2400,6 +2416,7 @@ class Query( if result._attributes.get("filtered", False): result = result.unique() + # TODO: isn't this supposed to be a list? return result @util.became_legacy_20( @@ -2436,7 +2453,7 @@ class Query( return loading.merge_result(self, iterator, load) - def exists(self): + def exists(self) -> Exists: """A convenience method that turns a query into an EXISTS subquery of the form EXISTS (SELECT 1 FROM ... WHERE ...). diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index c5ea07051..1b8f778c0 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -13,10 +13,15 @@ SQL annotation and aliasing behavior focused on the `primaryjoin` and `secondaryjoin` aspects of :func:`_orm.relationship`. """ +from __future__ import annotations + import collections +from collections import abc import re +import typing from typing import Any from typing import Callable +from typing import Optional from typing import Type from typing import TypeVar from typing import Union @@ -26,11 +31,13 @@ from . import attributes from . import strategy_options from .base import _is_mapped_class from .base import state_str +from .interfaces import _IntrospectsAnnotations from .interfaces import MANYTOMANY from .interfaces import MANYTOONE from .interfaces import ONETOMANY from .interfaces import PropComparator from .interfaces import StrategizedProperty +from .util import _extract_mapped_subtype from .util import _orm_annotate from .util import _orm_deannotate from .util import CascadeOptions @@ -53,10 +60,26 @@ from ..sql.util import join_condition from ..sql.util import selectables_overlap from ..sql.util import visit_binary_product +if typing.TYPE_CHECKING: + from .mapper import Mapper + from .util import AliasedClass + from .util import AliasedInsp + _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) +_RelationshipArgumentType = Union[ + str, + Type[_T], + Callable[[], Type[_T]], + "Mapper[_T]", + "AliasedClass[_T]", + Callable[[], "Mapper[_T]"], + Callable[[], "AliasedClass[_T]"], +] + + def remote(expr): """Annotate a portion of a primaryjoin expression with a 'remote' annotation. @@ -97,7 +120,9 @@ def foreign(expr): @log.class_logger -class RelationshipProperty(StrategizedProperty[_T]): +class Relationship( + _IntrospectsAnnotations, StrategizedProperty[_T], log.Identified +): """Describes an object property that holds a single item or list of items that correspond to a related database table. @@ -107,6 +132,10 @@ class RelationshipProperty(StrategizedProperty[_T]): :ref:`relationship_config_toplevel` + .. versionchanged:: 2.0 Renamed :class:`_orm.RelationshipProperty` + to :class:`_orm.Relationship`. The old name + :class:`_orm.RelationshipProperty` remains as an alias. + """ strategy_wildcard_key = strategy_options._RELATIONSHIP_TOKEN @@ -126,7 +155,7 @@ class RelationshipProperty(StrategizedProperty[_T]): def __init__( self, - argument: Union[str, Type[_T], Callable[[], Type[_T]]], + argument: Optional[_RelationshipArgumentType[_T]] = None, secondary=None, primaryjoin=None, secondaryjoin=None, @@ -162,7 +191,7 @@ class RelationshipProperty(StrategizedProperty[_T]): sync_backref=None, _legacy_inactive_history_style=False, ): - super(RelationshipProperty, self).__init__() + super(Relationship, self).__init__() self.uselist = uselist self.argument = argument @@ -221,9 +250,7 @@ class RelationshipProperty(StrategizedProperty[_T]): self.local_remote_pairs = _local_remote_pairs self.bake_queries = bake_queries self.load_on_pending = load_on_pending - self.comparator_factory = ( - comparator_factory or RelationshipProperty.Comparator - ) + self.comparator_factory = comparator_factory or Relationship.Comparator self.comparator = self.comparator_factory(self, None) util.set_creation_order(self) @@ -288,7 +315,7 @@ class RelationshipProperty(StrategizedProperty[_T]): class Comparator(PropComparator[_PT]): """Produce boolean, comparison, and other operators for - :class:`.RelationshipProperty` attributes. + :class:`.Relationship` attributes. See the documentation for :class:`.PropComparator` for a brief overview of ORM level operator definition. @@ -318,7 +345,7 @@ class RelationshipProperty(StrategizedProperty[_T]): of_type=None, extra_criteria=(), ): - """Construction of :class:`.RelationshipProperty.Comparator` + """Construction of :class:`.Relationship.Comparator` is internal to the ORM's attribute mechanics. """ @@ -340,7 +367,7 @@ class RelationshipProperty(StrategizedProperty[_T]): @util.memoized_property def entity(self): """The target entity referred to by this - :class:`.RelationshipProperty.Comparator`. + :class:`.Relationship.Comparator`. This is either a :class:`_orm.Mapper` or :class:`.AliasedInsp` object. @@ -360,7 +387,7 @@ class RelationshipProperty(StrategizedProperty[_T]): @util.memoized_property def mapper(self): """The target :class:`_orm.Mapper` referred to by this - :class:`.RelationshipProperty.Comparator`. + :class:`.Relationship.Comparator`. This is the "target" or "remote" side of the :func:`_orm.relationship`. @@ -411,7 +438,7 @@ class RelationshipProperty(StrategizedProperty[_T]): """ - return RelationshipProperty.Comparator( + return Relationship.Comparator( self.property, self._parententity, adapt_to_entity=self._adapt_to_entity, @@ -427,7 +454,7 @@ class RelationshipProperty(StrategizedProperty[_T]): .. versionadded:: 1.4 """ - return RelationshipProperty.Comparator( + return Relationship.Comparator( self.property, self._parententity, adapt_to_entity=self._adapt_to_entity, @@ -468,7 +495,7 @@ class RelationshipProperty(StrategizedProperty[_T]): many-to-one comparisons: * Comparisons against collections are not supported. - Use :meth:`~.RelationshipProperty.Comparator.contains`. + Use :meth:`~.Relationship.Comparator.contains`. * Compared to a scalar one-to-many, will produce a clause that compares the target columns in the parent to the given target. @@ -479,7 +506,7 @@ class RelationshipProperty(StrategizedProperty[_T]): queries that go beyond simple AND conjunctions of comparisons, such as those which use OR. Use explicit joins, outerjoins, or - :meth:`~.RelationshipProperty.Comparator.has` for + :meth:`~.Relationship.Comparator.has` for more comprehensive non-many-to-one scalar membership tests. * Comparisons against ``None`` given in a one-to-many @@ -613,12 +640,12 @@ class RelationshipProperty(StrategizedProperty[_T]): EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id AND related.x=2) - Because :meth:`~.RelationshipProperty.Comparator.any` uses + Because :meth:`~.Relationship.Comparator.any` uses a correlated subquery, its performance is not nearly as good when compared against large target tables as that of using a join. - :meth:`~.RelationshipProperty.Comparator.any` is particularly + :meth:`~.Relationship.Comparator.any` is particularly useful for testing for empty collections:: session.query(MyClass).filter( @@ -631,10 +658,10 @@ class RelationshipProperty(StrategizedProperty[_T]): NOT (EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id)) - :meth:`~.RelationshipProperty.Comparator.any` is only + :meth:`~.Relationship.Comparator.any` is only valid for collections, i.e. a :func:`_orm.relationship` that has ``uselist=True``. For scalar references, - use :meth:`~.RelationshipProperty.Comparator.has`. + use :meth:`~.Relationship.Comparator.has`. """ if not self.property.uselist: @@ -662,15 +689,15 @@ class RelationshipProperty(StrategizedProperty[_T]): EXISTS (SELECT 1 FROM related WHERE related.id==my_table.related_id AND related.x=2) - Because :meth:`~.RelationshipProperty.Comparator.has` uses + Because :meth:`~.Relationship.Comparator.has` uses a correlated subquery, its performance is not nearly as good when compared against large target tables as that of using a join. - :meth:`~.RelationshipProperty.Comparator.has` is only + :meth:`~.Relationship.Comparator.has` is only valid for scalar references, i.e. a :func:`_orm.relationship` that has ``uselist=False``. For collection references, - use :meth:`~.RelationshipProperty.Comparator.any`. + use :meth:`~.Relationship.Comparator.any`. """ if self.property.uselist: @@ -683,7 +710,7 @@ class RelationshipProperty(StrategizedProperty[_T]): """Return a simple expression that tests a collection for containment of a particular item. - :meth:`~.RelationshipProperty.Comparator.contains` is + :meth:`~.Relationship.Comparator.contains` is only valid for a collection, i.e. a :func:`_orm.relationship` that implements one-to-many or many-to-many with ``uselist=True``. @@ -700,12 +727,12 @@ class RelationshipProperty(StrategizedProperty[_T]): Where ``<some id>`` is the value of the foreign key attribute on ``other`` which refers to the primary key of its parent object. From this it follows that - :meth:`~.RelationshipProperty.Comparator.contains` is + :meth:`~.Relationship.Comparator.contains` is very useful when used with simple one-to-many operations. For many-to-many operations, the behavior of - :meth:`~.RelationshipProperty.Comparator.contains` + :meth:`~.Relationship.Comparator.contains` has more caveats. The association table will be rendered in the statement, producing an "implicit" join, that is, includes multiple tables in the FROM @@ -722,14 +749,14 @@ class RelationshipProperty(StrategizedProperty[_T]): Where ``<some id>`` would be the primary key of ``other``. From the above, it is clear that - :meth:`~.RelationshipProperty.Comparator.contains` + :meth:`~.Relationship.Comparator.contains` will **not** work with many-to-many collections when used in queries that move beyond simple AND conjunctions, such as multiple - :meth:`~.RelationshipProperty.Comparator.contains` + :meth:`~.Relationship.Comparator.contains` expressions joined by OR. In such cases subqueries or explicit "outer joins" will need to be used instead. - See :meth:`~.RelationshipProperty.Comparator.any` for + See :meth:`~.Relationship.Comparator.any` for a less-performant alternative using EXISTS, or refer to :meth:`_query.Query.outerjoin` as well as :ref:`ormtutorial_joins` @@ -818,7 +845,7 @@ class RelationshipProperty(StrategizedProperty[_T]): * Comparisons against collections are not supported. Use - :meth:`~.RelationshipProperty.Comparator.contains` + :meth:`~.Relationship.Comparator.contains` in conjunction with :func:`_expression.not_`. * Compared to a scalar one-to-many, will produce a clause that compares the target columns in the parent to @@ -830,7 +857,7 @@ class RelationshipProperty(StrategizedProperty[_T]): queries that go beyond simple AND conjunctions of comparisons, such as those which use OR. Use explicit joins, outerjoins, or - :meth:`~.RelationshipProperty.Comparator.has` in + :meth:`~.Relationship.Comparator.has` in conjunction with :func:`_expression.not_` for more comprehensive non-many-to-one scalar membership tests. @@ -1249,7 +1276,7 @@ class RelationshipProperty(StrategizedProperty[_T]): def _add_reverse_property(self, key): other = self.mapper.get_property(key, _configure_mappers=False) - if not isinstance(other, RelationshipProperty): + if not isinstance(other, Relationship): raise sa_exc.InvalidRequestError( "back_populates on relationship '%s' refers to attribute '%s' " "that is not a relationship. The back_populates parameter " @@ -1269,6 +1296,8 @@ class RelationshipProperty(StrategizedProperty[_T]): self._reverse_property.add(other) other._reverse_property.add(self) + other._setup_entity() + if not other.mapper.common_parent(self.parent): raise sa_exc.ArgumentError( "reverse_property %r on " @@ -1289,48 +1318,18 @@ class RelationshipProperty(StrategizedProperty[_T]): ) @util.memoized_property - @util.preload_module("sqlalchemy.orm.mapper") - def entity(self): + def entity(self) -> Union["Mapper", "AliasedInsp"]: """Return the target mapped entity, which is an inspect() of the class or aliased class that is referred towards. """ - - mapperlib = util.preloaded.orm_mapper - - if isinstance(self.argument, str): - argument = self._clsregistry_resolve_name(self.argument)() - - elif callable(self.argument) and not isinstance( - self.argument, (type, mapperlib.Mapper) - ): - argument = self.argument() - else: - argument = self.argument - - if isinstance(argument, type): - return mapperlib.class_mapper(argument, configure=False) - - try: - entity = inspect(argument) - except sa_exc.NoInspectionAvailable: - pass - else: - if hasattr(entity, "mapper"): - return entity - - raise sa_exc.ArgumentError( - "relationship '%s' expects " - "a class or a mapper argument (received: %s)" - % (self.key, type(argument)) - ) + self.parent._check_configure() + return self.entity @util.memoized_property - def mapper(self): + def mapper(self) -> "Mapper": """Return the targeted :class:`_orm.Mapper` for this - :class:`.RelationshipProperty`. - - This is a lazy-initializing static attribute. + :class:`.Relationship`. """ return self.entity.mapper @@ -1338,13 +1337,14 @@ class RelationshipProperty(StrategizedProperty[_T]): def do_init(self): self._check_conflicts() self._process_dependent_arguments() + self._setup_entity() self._setup_registry_dependencies() self._setup_join_conditions() self._check_cascade_settings(self._cascade) self._post_init() self._generate_backref() self._join_condition._warn_for_conflicting_sync_targets() - super(RelationshipProperty, self).do_init() + super(Relationship, self).do_init() self._lazy_strategy = self._get_strategy((("lazy", "select"),)) def _setup_registry_dependencies(self): @@ -1432,6 +1432,84 @@ class RelationshipProperty(StrategizedProperty[_T]): for x in util.to_column_set(self.remote_side) ) + def declarative_scan( + self, registry, cls, key, annotation, is_dataclass_field + ): + argument = _extract_mapped_subtype( + annotation, + cls, + key, + Relationship, + self.argument is None, + is_dataclass_field, + ) + if argument is None: + return + + if hasattr(argument, "__origin__"): + + collection_class = argument.__origin__ + if issubclass(collection_class, abc.Collection): + if self.collection_class is None: + self.collection_class = collection_class + else: + self.uselist = False + if argument.__args__: + if issubclass(argument.__origin__, typing.Mapping): + type_arg = argument.__args__[1] + else: + type_arg = argument.__args__[0] + if hasattr(type_arg, "__forward_arg__"): + str_argument = type_arg.__forward_arg__ + argument = str_argument + else: + argument = type_arg + else: + raise sa_exc.ArgumentError( + f"Generic alias {argument} requires an argument" + ) + elif hasattr(argument, "__forward_arg__"): + argument = argument.__forward_arg__ + + self.argument = argument + + @util.preload_module("sqlalchemy.orm.mapper") + def _setup_entity(self, __argument=None): + if "entity" in self.__dict__: + return + + mapperlib = util.preloaded.orm_mapper + + if __argument: + argument = __argument + else: + argument = self.argument + + if isinstance(argument, str): + argument = self._clsregistry_resolve_name(argument)() + elif callable(argument) and not isinstance( + argument, (type, mapperlib.Mapper) + ): + argument = argument() + else: + argument = argument + + if isinstance(argument, type): + entity = mapperlib.class_mapper(argument, configure=False) + else: + try: + entity = inspect(argument) + except sa_exc.NoInspectionAvailable: + entity = None + + if not hasattr(entity, "mapper"): + raise sa_exc.ArgumentError( + "relationship '%s' expects " + "a class or a mapper argument (received: %s)" + % (self.key, type(argument)) + ) + + self.entity = entity # type: ignore self.target = self.entity.persist_selectable def _setup_join_conditions(self): @@ -1502,7 +1580,7 @@ class RelationshipProperty(StrategizedProperty[_T]): @property def cascade(self): """Return the current cascade setting for this - :class:`.RelationshipProperty`. + :class:`.Relationship`. """ return self._cascade @@ -1666,7 +1744,7 @@ class RelationshipProperty(StrategizedProperty[_T]): kwargs.setdefault("passive_updates", self.passive_updates) kwargs.setdefault("sync_backref", self.sync_backref) self.back_populates = backref_key - relationship = RelationshipProperty( + relationship = Relationship( parent, self.secondary, pj, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index cf47ee729..6911ab505 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -9,6 +9,15 @@ import contextlib import itertools import sys +import typing +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import Union import weakref from . import attributes @@ -20,12 +29,15 @@ from . import persistence from . import query from . import state as statelib from .base import _class_to_mapper +from .base import _IdentityKeyType from .base import _none_set from .base import _state_mapper from .base import instance_str from .base import object_mapper from .base import object_state from .base import state_str +from .query import Query +from .state import InstanceState from .state_changes import _StateChange from .state_changes import _StateChangeState from .state_changes import _StateChangeStates @@ -34,14 +46,26 @@ from .. import engine from .. import exc as sa_exc from .. import sql from .. import util +from ..engine import Connection +from ..engine import Engine from ..engine.util import TransactionalContext from ..inspection import inspect from ..sql import coercions from ..sql import dml from ..sql import roles from ..sql import visitors +from ..sql._typing import _ColumnsClauseElement from ..sql.base import CompileState from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .mapper import Mapper + from ..engine import Row + from ..sql._typing import _ExecuteOptions + from ..sql._typing import _ExecuteParams + from ..sql.base import Executable + from ..sql.schema import Table __all__ = [ "Session", @@ -78,23 +102,60 @@ class _SessionClassMethods: "removed in a future release. Please refer to " ":func:`.session.close_all_sessions`.", ) - def close_all(cls): + def close_all(cls) -> None: """Close *all* sessions in memory.""" close_all_sessions() @classmethod + @overload + def identity_key( + cls, + class_: type, + ident: Tuple[Any, ...], + *, + identity_token: Optional[str], + ) -> _IdentityKeyType: + ... + + @classmethod + @overload + def identity_key(cls, *, instance: Any) -> _IdentityKeyType: + ... + + @classmethod + @overload + def identity_key( + cls, class_: type, *, row: "Row", identity_token: Optional[str] + ) -> _IdentityKeyType: + ... + + @classmethod @util.preload_module("sqlalchemy.orm.util") - def identity_key(cls, *args, **kwargs): + def identity_key( + cls, + class_=None, + ident=None, + *, + instance=None, + row=None, + identity_token=None, + ) -> _IdentityKeyType: """Return an identity key. This is an alias of :func:`.util.identity_key`. """ - return util.preloaded.orm_util.identity_key(*args, **kwargs) + return util.preloaded.orm_util.identity_key( + class_, + ident, + instance=instance, + row=row, + identity_token=identity_token, + ) @classmethod - def object_session(cls, instance): + def object_session(cls, instance: Any) -> "Session": """Return the :class:`.Session` to which an object belongs. This is an alias of :func:`.object_session`. @@ -142,15 +203,26 @@ class ORMExecuteState(util.MemoizedSlots): "_update_execution_options", ) + session: "Session" + statement: "Executable" + parameters: "_ExecuteParams" + execution_options: "_ExecuteOptions" + local_execution_options: "_ExecuteOptions" + bind_arguments: Dict[str, Any] + _compile_state_cls: Type[context.ORMCompileState] + _starting_event_idx: Optional[int] + _events_todo: List[Any] + _update_execution_options: Optional["_ExecuteOptions"] + def __init__( self, - session, - statement, - parameters, - execution_options, - bind_arguments, - compile_state_cls, - events_todo, + session: "Session", + statement: "Executable", + parameters: "_ExecuteParams", + execution_options: "_ExecuteOptions", + bind_arguments: Dict[str, Any], + compile_state_cls: Type[context.ORMCompileState], + events_todo: List[Any], ): self.session = session self.statement = statement @@ -834,7 +906,7 @@ class SessionTransaction(_StateChange, TransactionalContext): (SessionTransactionState.ACTIVE, SessionTransactionState.PREPARED), SessionTransactionState.CLOSED, ) - def commit(self, _to_root=False): + def commit(self, _to_root: bool = False) -> None: if self._state is not SessionTransactionState.PREPARED: with self._expect_state(SessionTransactionState.PREPARED): self._prepare_impl() @@ -981,18 +1053,42 @@ class Session(_SessionClassMethods): _is_asyncio = False + identity_map: identity.IdentityMap + _new: Dict["InstanceState", Any] + _deleted: Dict["InstanceState", Any] + bind: Optional[Union[Engine, Connection]] + __binds: Dict[ + Union[type, "Mapper", "Table"], + Union[engine.Engine, engine.Connection], + ] + _flusing: bool + _warn_on_events: bool + _transaction: Optional[SessionTransaction] + _nested_transaction: Optional[SessionTransaction] + hash_key: int + autoflush: bool + expire_on_commit: bool + enable_baked_queries: bool + twophase: bool + _query_cls: Type[Query] + def __init__( self, - bind=None, - autoflush=True, - future=True, - expire_on_commit=True, - twophase=False, - binds=None, - enable_baked_queries=True, - info=None, - query_cls=None, - autocommit=False, + bind: Optional[Union[engine.Engine, engine.Connection]] = None, + autoflush: bool = True, + future: Literal[True] = True, + expire_on_commit: bool = True, + twophase: bool = False, + binds: Optional[ + Dict[ + Union[type, "Mapper", "Table"], + Union[engine.Engine, engine.Connection], + ] + ] = None, + enable_baked_queries: bool = True, + info: Optional[Dict[Any, Any]] = None, + query_cls: Optional[Type[query.Query]] = None, + autocommit: Literal[False] = False, ): r"""Construct a new Session. @@ -1054,7 +1150,8 @@ class Session(_SessionClassMethods): :class:`.sessionmaker` function, and is not sent directly to the constructor for ``Session``. - :param enable_baked_queries: defaults to ``True``. A flag consumed + :param enable_baked_queries: legacy; defaults to ``True``. + A parameter consumed by the :mod:`sqlalchemy.ext.baked` extension to determine if "baked queries" should be cached, as is the normal operation of this extension. When set to ``False``, caching as used by @@ -1331,7 +1428,7 @@ class Session(_SessionClassMethods): else: self._transaction.rollback(_to_root=True) - def commit(self): + def commit(self) -> None: """Flush pending changes and commit the current transaction. If no transaction is in progress, the method will first @@ -1353,7 +1450,7 @@ class Session(_SessionClassMethods): self._transaction.commit(_to_root=True) - def prepare(self): + def prepare(self) -> None: """Prepare the current transaction in progress for two phase commit. If no transaction is in progress, this method raises an @@ -1370,7 +1467,11 @@ class Session(_SessionClassMethods): self._transaction.prepare() - def connection(self, bind_arguments=None, execution_options=None): + def connection( + self, + bind_arguments: Optional[Dict[str, Any]] = None, + execution_options: Optional["_ExecuteOptions"] = None, + ) -> "Connection": r"""Return a :class:`_engine.Connection` object corresponding to this :class:`.Session` object's transactional state. @@ -1425,12 +1526,12 @@ class Session(_SessionClassMethods): def execute( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - _parent_execute_state=None, - _add_event=None, + statement: "Executable", + params: Optional["_ExecuteParams"] = None, + execution_options: "_ExecuteOptions" = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, ): r"""Execute a SQL expression construct. @@ -1936,7 +2037,9 @@ class Session(_SessionClassMethods): % (", ".join(context),), ) - def query(self, *entities, **kwargs): + def query( + self, *entities: "_ColumnsClauseElement", **kwargs: Any + ) -> "Query": """Return a new :class:`_query.Query` object corresponding to this :class:`_orm.Session`. @@ -2391,7 +2494,7 @@ class Session(_SessionClassMethods): if persistent_to_deleted is not None: persistent_to_deleted(self, state) - def add(self, instance, _warn=True): + def add(self, instance: Any, _warn: bool = True) -> None: """Place an object in the ``Session``. Its state will be persisted to the database on the next flush diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 07e71d4c0..316aa7ed7 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -34,7 +34,7 @@ from .interfaces import StrategizedProperty from .session import _state_session from .state import InstanceState from .util import _none_set -from .util import aliased +from .util import AliasedClass from .. import event from .. import exc as sa_exc from .. import inspect @@ -564,7 +564,7 @@ class AbstractRelationshipLoader(LoaderStrategy): @log.class_logger -@relationships.RelationshipProperty.strategy_for(do_nothing=True) +@relationships.Relationship.strategy_for(do_nothing=True) class DoNothingLoader(LoaderStrategy): """Relationship loader that makes no change to the object's state. @@ -576,10 +576,10 @@ class DoNothingLoader(LoaderStrategy): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="noload") -@relationships.RelationshipProperty.strategy_for(lazy=None) +@relationships.Relationship.strategy_for(lazy="noload") +@relationships.Relationship.strategy_for(lazy=None) class NoLoader(AbstractRelationshipLoader): - """Provide loading behavior for a :class:`.RelationshipProperty` + """Provide loading behavior for a :class:`.Relationship` with "lazy=None". """ @@ -617,13 +617,13 @@ class NoLoader(AbstractRelationshipLoader): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy=True) -@relationships.RelationshipProperty.strategy_for(lazy="select") -@relationships.RelationshipProperty.strategy_for(lazy="raise") -@relationships.RelationshipProperty.strategy_for(lazy="raise_on_sql") -@relationships.RelationshipProperty.strategy_for(lazy="baked_select") +@relationships.Relationship.strategy_for(lazy=True) +@relationships.Relationship.strategy_for(lazy="select") +@relationships.Relationship.strategy_for(lazy="raise") +@relationships.Relationship.strategy_for(lazy="raise_on_sql") +@relationships.Relationship.strategy_for(lazy="baked_select") class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): - """Provide loading behavior for a :class:`.RelationshipProperty` + """Provide loading behavior for a :class:`.Relationship` with "lazy=True", that is loads when first accessed. """ @@ -1214,7 +1214,7 @@ class PostLoader(AbstractRelationshipLoader): ) -@relationships.RelationshipProperty.strategy_for(lazy="immediate") +@relationships.Relationship.strategy_for(lazy="immediate") class ImmediateLoader(PostLoader): __slots__ = () @@ -1250,7 +1250,7 @@ class ImmediateLoader(PostLoader): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="subquery") +@relationships.Relationship.strategy_for(lazy="subquery") class SubqueryLoader(PostLoader): __slots__ = ("join_depth",) @@ -1906,10 +1906,10 @@ class SubqueryLoader(PostLoader): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="joined") -@relationships.RelationshipProperty.strategy_for(lazy=False) +@relationships.Relationship.strategy_for(lazy="joined") +@relationships.Relationship.strategy_for(lazy=False) class JoinedLoader(AbstractRelationshipLoader): - """Provide loading behavior for a :class:`.RelationshipProperty` + """Provide loading behavior for a :class:`.Relationship` using joined eager loading. """ @@ -2628,7 +2628,7 @@ class JoinedLoader(AbstractRelationshipLoader): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="selectin") +@relationships.Relationship.strategy_for(lazy="selectin") class SelectInLoader(PostLoader, util.MemoizedSlots): __slots__ = ( "join_depth", @@ -2721,7 +2721,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): ) def _init_for_join(self): - self._parent_alias = aliased(self.parent.class_) + self._parent_alias = AliasedClass(self.parent.class_) pa_insp = inspect(self._parent_alias) pk_cols = [ pa_insp._adapt_element(col) for col in self.parent.primary_key diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 0f993b86c..3f093e543 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -1808,7 +1808,7 @@ class _AttributeStrategyLoad(_LoadElement): assert pwpi if not pwpi.is_aliased_class: pwpi = inspect( - orm_util.with_polymorphic( + orm_util.AliasedInsp._with_polymorphic_factory( pwpi.mapper.base_mapper, pwpi.mapper, aliased=True, diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 75f711007..45c578355 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -5,13 +5,22 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php - import re import types +import typing +from typing import Any +from typing import Generic +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union import weakref from . import attributes # noqa from .base import _class_to_mapper # noqa +from .base import _IdentityKeyType from .base import _never_set # noqa from .base import _none_set # noqa from .base import attribute_str # noqa @@ -45,8 +54,17 @@ from ..sql import util as sql_util from ..sql import visitors from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import ColumnCollection +from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots +from ..util.typing import de_stringify_annotation +from ..util.typing import is_origin_of + +if typing.TYPE_CHECKING: + from .mapper import Mapper + from ..engine import Row + from ..sql.selectable import Alias +_T = TypeVar("_T", bound=Any) all_cascades = frozenset( ( @@ -276,7 +294,28 @@ def polymorphic_union( return sql.union_all(*result).alias(aliasname) -def identity_key(*args, **kwargs): +@overload +def identity_key( + class_: type, ident: Tuple[Any, ...], *, identity_token: Optional[str] +) -> _IdentityKeyType: + ... + + +@overload +def identity_key(*, instance: Any) -> _IdentityKeyType: + ... + + +@overload +def identity_key( + class_: type, *, row: "Row", identity_token: Optional[str] +) -> _IdentityKeyType: + ... + + +def identity_key( + class_=None, ident=None, *, instance=None, row=None, identity_token=None +) -> _IdentityKeyType: r"""Generate "identity key" tuples, as are used as keys in the :attr:`.Session.identity_map` dictionary. @@ -340,29 +379,11 @@ def identity_key(*args, **kwargs): .. versionadded:: 1.2 added identity_token """ - if args: - row = None - largs = len(args) - if largs == 1: - class_ = args[0] - try: - row = kwargs.pop("row") - except KeyError: - ident = kwargs.pop("ident") - elif largs in (2, 3): - class_, ident = args - else: - raise sa_exc.ArgumentError( - "expected up to three positional arguments, " "got %s" % largs - ) - - identity_token = kwargs.pop("identity_token", None) - if kwargs: - raise sa_exc.ArgumentError( - "unknown keyword arguments: %s" % ", ".join(kwargs) - ) + if class_ is not None: mapper = class_mapper(class_) if row is None: + if ident is None: + raise sa_exc.ArgumentError("ident or row is required") return mapper.identity_key_from_primary_key( util.to_list(ident), identity_token=identity_token ) @@ -370,14 +391,11 @@ def identity_key(*args, **kwargs): return mapper.identity_key_from_row( row, identity_token=identity_token ) - else: - instance = kwargs.pop("instance") - if kwargs: - raise sa_exc.ArgumentError( - "unknown keyword arguments: %s" % ", ".join(kwargs.keys) - ) + elif instance is not None: mapper = object_mapper(instance) return mapper.identity_key_from_instance(instance) + else: + raise sa_exc.ArgumentError("class or instance is required") class ORMAdapter(sql_util.ColumnAdapter): @@ -420,7 +438,7 @@ class ORMAdapter(sql_util.ColumnAdapter): return not entity or entity.isa(self.mapper) -class AliasedClass: +class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): r"""Represents an "aliased" form of a mapped class for usage with Query. The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias` @@ -481,7 +499,7 @@ class AliasedClass: def __init__( self, - mapped_class_or_ac, + mapped_class_or_ac: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"], alias=None, name=None, flat=False, @@ -611,6 +629,7 @@ class AliasedInsp( ORMEntityColumnsClauseRole, ORMFromClauseRole, sql_base.HasCacheKey, + roles.HasFromClauseElement, InspectionAttr, MemoizedSlots, ): @@ -747,6 +766,73 @@ class AliasedInsp( self._target = mapped_class_or_ac # self._target = mapper.class_ # mapped_class_or_ac + @classmethod + def _alias_factory( + cls, + element: Union[ + Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]" + ], + alias=None, + name=None, + flat=False, + adapt_on_names=False, + ) -> Union["AliasedClass[_T]", "Alias"]: + + if isinstance(element, FromClause): + if adapt_on_names: + raise sa_exc.ArgumentError( + "adapt_on_names only applies to ORM elements" + ) + if name: + return element.alias(name=name, flat=flat) + else: + return coercions.expect( + roles.AnonymizedFromClauseRole, element, flat=flat + ) + else: + return AliasedClass( + element, + alias=alias, + flat=flat, + name=name, + adapt_on_names=adapt_on_names, + ) + + @classmethod + def _with_polymorphic_factory( + cls, + base, + classes, + selectable=False, + flat=False, + polymorphic_on=None, + aliased=False, + innerjoin=False, + _use_mapper_path=False, + ): + + primary_mapper = _class_to_mapper(base) + + if selectable not in (None, False) and flat: + raise sa_exc.ArgumentError( + "the 'flat' and 'selectable' arguments cannot be passed " + "simultaneously to with_polymorphic()" + ) + + mappers, selectable = primary_mapper._with_polymorphic_args( + classes, selectable, innerjoin=innerjoin + ) + if aliased or flat: + selectable = selectable._anonymous_fromclause(flat=flat) + return AliasedClass( + base, + selectable, + with_polymorphic_mappers=mappers, + with_polymorphic_discriminator=polymorphic_on, + use_mapper_path=_use_mapper_path, + represents_outer_join=not innerjoin, + ) + @property def entity(self): # to eliminate reference cycles, the AliasedClass is held weakly. @@ -1107,215 +1193,6 @@ inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) inspection._inspects(AliasedInsp)(lambda target: target) -def aliased(element, alias=None, name=None, flat=False, adapt_on_names=False): - """Produce an alias of the given element, usually an :class:`.AliasedClass` - instance. - - E.g.:: - - my_alias = aliased(MyClass) - - session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id) - - The :func:`.aliased` function is used to create an ad-hoc mapping of a - mapped class to a new selectable. By default, a selectable is generated - from the normally mapped selectable (typically a :class:`_schema.Table` - ) using the - :meth:`_expression.FromClause.alias` method. However, :func:`.aliased` - can also be - used to link the class to a new :func:`_expression.select` statement. - Also, the :func:`.with_polymorphic` function is a variant of - :func:`.aliased` that is intended to specify a so-called "polymorphic - selectable", that corresponds to the union of several joined-inheritance - subclasses at once. - - For convenience, the :func:`.aliased` function also accepts plain - :class:`_expression.FromClause` constructs, such as a - :class:`_schema.Table` or - :func:`_expression.select` construct. In those cases, the - :meth:`_expression.FromClause.alias` - method is called on the object and the new - :class:`_expression.Alias` object returned. The returned - :class:`_expression.Alias` is not - ORM-mapped in this case. - - .. seealso:: - - :ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial` - - :ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel` - - :ref:`ormtutorial_aliases` - in the legacy :ref:`ormtutorial_toplevel` - - :param element: element to be aliased. Is normally a mapped class, - but for convenience can also be a :class:`_expression.FromClause` - element. - - :param alias: Optional selectable unit to map the element to. This is - usually used to link the object to a subquery, and should be an aliased - select construct as one would produce from the - :meth:`_query.Query.subquery` method or - the :meth:`_expression.Select.subquery` or - :meth:`_expression.Select.alias` methods of the :func:`_expression.select` - construct. - - :param name: optional string name to use for the alias, if not specified - by the ``alias`` parameter. The name, among other things, forms the - attribute name that will be accessible via tuples returned by a - :class:`_query.Query` object. Not supported when creating aliases - of :class:`_sql.Join` objects. - - :param flat: Boolean, will be passed through to the - :meth:`_expression.FromClause.alias` call so that aliases of - :class:`_expression.Join` objects will alias the individual tables - inside the join, rather than creating a subquery. This is generally - supported by all modern databases with regards to right-nested joins - and generally produces more efficient queries. - - :param adapt_on_names: if True, more liberal "matching" will be used when - mapping the mapped columns of the ORM entity to those of the - given selectable - a name-based match will be performed if the - given selectable doesn't otherwise have a column that corresponds - to one on the entity. The use case for this is when associating - an entity with some derived selectable such as one that uses - aggregate functions:: - - class UnitPrice(Base): - __tablename__ = 'unit_price' - ... - unit_id = Column(Integer) - price = Column(Numeric) - - aggregated_unit_price = Session.query( - func.sum(UnitPrice.price).label('price') - ).group_by(UnitPrice.unit_id).subquery() - - aggregated_unit_price = aliased(UnitPrice, - alias=aggregated_unit_price, adapt_on_names=True) - - Above, functions on ``aggregated_unit_price`` which refer to - ``.price`` will return the - ``func.sum(UnitPrice.price).label('price')`` column, as it is - matched on the name "price". Ordinarily, the "price" function - wouldn't have any "column correspondence" to the actual - ``UnitPrice.price`` column as it is not a proxy of the original. - - """ - if isinstance(element, expression.FromClause): - if adapt_on_names: - raise sa_exc.ArgumentError( - "adapt_on_names only applies to ORM elements" - ) - if name: - return element.alias(name=name, flat=flat) - else: - return coercions.expect( - roles.AnonymizedFromClauseRole, element, flat=flat - ) - else: - return AliasedClass( - element, - alias=alias, - flat=flat, - name=name, - adapt_on_names=adapt_on_names, - ) - - -def with_polymorphic( - base, - classes, - selectable=False, - flat=False, - polymorphic_on=None, - aliased=False, - innerjoin=False, - _use_mapper_path=False, -): - """Produce an :class:`.AliasedClass` construct which specifies - columns for descendant mappers of the given base. - - Using this method will ensure that each descendant mapper's - tables are included in the FROM clause, and will allow filter() - criterion to be used against those tables. The resulting - instances will also have those columns already loaded so that - no "post fetch" of those columns will be required. - - .. seealso:: - - :ref:`with_polymorphic` - full discussion of - :func:`_orm.with_polymorphic`. - - :param base: Base class to be aliased. - - :param classes: a single class or mapper, or list of - class/mappers, which inherit from the base class. - Alternatively, it may also be the string ``'*'``, in which case - all descending mapped classes will be added to the FROM clause. - - :param aliased: when True, the selectable will be aliased. For a - JOIN, this means the JOIN will be SELECTed from inside of a subquery - unless the :paramref:`_orm.with_polymorphic.flat` flag is set to - True, which is recommended for simpler use cases. - - :param flat: Boolean, will be passed through to the - :meth:`_expression.FromClause.alias` call so that aliases of - :class:`_expression.Join` objects will alias the individual tables - inside the join, rather than creating a subquery. This is generally - supported by all modern databases with regards to right-nested joins - and generally produces more efficient queries. Setting this flag is - recommended as long as the resulting SQL is functional. - - :param selectable: a table or subquery that will - be used in place of the generated FROM clause. This argument is - required if any of the desired classes use concrete table - inheritance, since SQLAlchemy currently cannot generate UNIONs - among tables automatically. If used, the ``selectable`` argument - must represent the full set of tables and columns mapped by every - mapped class. Otherwise, the unaccounted mapped columns will - result in their table being appended directly to the FROM clause - which will usually lead to incorrect results. - - When left at its default value of ``False``, the polymorphic - selectable assigned to the base mapper is used for selecting rows. - However, it may also be passed as ``None``, which will bypass the - configured polymorphic selectable and instead construct an ad-hoc - selectable for the target classes given; for joined table inheritance - this will be a join that includes all target mappers and their - subclasses. - - :param polymorphic_on: a column to be used as the "discriminator" - column for the given selectable. If not given, the polymorphic_on - attribute of the base classes' mapper will be used, if any. This - is useful for mappings that don't have polymorphic loading - behavior by default. - - :param innerjoin: if True, an INNER JOIN will be used. This should - only be specified if querying for one specific subtype only - """ - primary_mapper = _class_to_mapper(base) - - if selectable not in (None, False) and flat: - raise sa_exc.ArgumentError( - "the 'flat' and 'selectable' arguments cannot be passed " - "simultaneously to with_polymorphic()" - ) - - mappers, selectable = primary_mapper._with_polymorphic_args( - classes, selectable, innerjoin=innerjoin - ) - if aliased or flat: - selectable = selectable._anonymous_fromclause(flat=flat) - return AliasedClass( - base, - selectable, - with_polymorphic_mappers=mappers, - with_polymorphic_discriminator=polymorphic_on, - use_mapper_path=_use_mapper_path, - represents_outer_join=not innerjoin, - ) - - @inspection._self_inspects class Bundle( ORMColumnsClauseRole, @@ -1667,62 +1544,6 @@ class _ORMJoin(expression.Join): return _ORMJoin(self, right, onclause, isouter=True, full=full) -def join( - left, right, onclause=None, isouter=False, full=False, join_to_left=None -): - r"""Produce an inner join between left and right clauses. - - :func:`_orm.join` is an extension to the core join interface - provided by :func:`_expression.join()`, where the - left and right selectables may be not only core selectable - objects such as :class:`_schema.Table`, but also mapped classes or - :class:`.AliasedClass` instances. The "on" clause can - be a SQL expression, or an attribute or string name - referencing a configured :func:`_orm.relationship`. - - :func:`_orm.join` is not commonly needed in modern usage, - as its functionality is encapsulated within that of the - :meth:`_query.Query.join` method, which features a - significant amount of automation beyond :func:`_orm.join` - by itself. Explicit usage of :func:`_orm.join` - with :class:`_query.Query` involves usage of the - :meth:`_query.Query.select_from` method, as in:: - - from sqlalchemy.orm import join - session.query(User).\ - select_from(join(User, Address, User.addresses)).\ - filter(Address.email_address=='foo@bar.com') - - In modern SQLAlchemy the above join can be written more - succinctly as:: - - session.query(User).\ - join(User.addresses).\ - filter(Address.email_address=='foo@bar.com') - - See :meth:`_query.Query.join` for information on modern usage - of ORM level joins. - - .. deprecated:: 0.8 - - the ``join_to_left`` parameter is deprecated, and will be removed - in a future release. The parameter has no effect. - - """ - return _ORMJoin(left, right, onclause, isouter, full) - - -def outerjoin(left, right, onclause=None, full=False, join_to_left=None): - """Produce a left outer join between left and right clauses. - - This is the "outer join" version of the :func:`_orm.join` function, - featuring the same behavior except that an OUTER JOIN is generated. - See that function's documentation for other usage details. - - """ - return _ORMJoin(left, right, onclause, True, full) - - def with_parent(instance, prop, from_entity=None): """Create filtering criterion that relates this query's primary entity to the given related instance, using established @@ -1964,3 +1785,56 @@ def _getitem(iterable_query, item): return list(iterable_query)[-1] else: return list(iterable_query[item : item + 1])[0] + + +def _is_mapped_annotation(raw_annotation: Union[type, str], cls: type): + annotated = de_stringify_annotation(cls, raw_annotation) + return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") + + +def _extract_mapped_subtype( + raw_annotation: Union[type, str], + cls: type, + key: str, + attr_cls: type, + required: bool, + is_dataclass_field: bool, +) -> Optional[Union[type, str]]: + + if raw_annotation is None: + + if required: + raise sa_exc.ArgumentError( + f"Python typing annotation is required for attribute " + f'"{cls.__name__}.{key}" when primary argument(s) for ' + f'"{attr_cls.__name__}" construct are None or not present' + ) + return None + + annotated = de_stringify_annotation(cls, raw_annotation) + + if is_dataclass_field: + return annotated + else: + if ( + not hasattr(annotated, "__origin__") + or not issubclass(annotated.__origin__, attr_cls) + and not issubclass(attr_cls, annotated.__origin__) + ): + our_annotated_str = ( + annotated.__name__ + if not isinstance(annotated, str) + else repr(annotated) + ) + raise sa_exc.ArgumentError( + f'Type annotation for "{cls.__name__}.{key}" should use the ' + f'syntax "Mapped[{our_annotated_str}]" or ' + f'"{attr_cls.__name__}[{our_annotated_str}]".' + ) + + if len(annotated.__args__) != 1: + raise sa_exc.ArgumentError( + "Expected sub-type for Mapped[] annotation" + ) + + return annotated.__args__[0] diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py index 38059856e..bc2f93d57 100644 --- a/lib/sqlalchemy/pool/__init__.py +++ b/lib/sqlalchemy/pool/__init__.py @@ -22,36 +22,17 @@ from .base import _AdhocProxiedConnection from .base import _ConnectionFairy from .base import _ConnectionRecord from .base import _finalize_fairy -from .base import Pool -from .base import PoolProxiedConnection -from .base import reset_commit -from .base import reset_none -from .base import reset_rollback -from .impl import AssertionPool -from .impl import AsyncAdaptedQueuePool -from .impl import FallbackAsyncAdaptedQueuePool -from .impl import NullPool -from .impl import QueuePool -from .impl import SingletonThreadPool -from .impl import StaticPool - - -__all__ = [ - "Pool", - "PoolProxiedConnection", - "reset_commit", - "reset_none", - "reset_rollback", - "clear_managers", - "manage", - "AssertionPool", - "NullPool", - "QueuePool", - "AsyncAdaptedQueuePool", - "FallbackAsyncAdaptedQueuePool", - "SingletonThreadPool", - "StaticPool", -] - -# as these are likely to be used in various test suites, debugging -# setups, keep them in the sqlalchemy.pool namespace +from .base import Pool as Pool +from .base import PoolProxiedConnection as PoolProxiedConnection +from .base import reset_commit as reset_commit +from .base import reset_none as reset_none +from .base import reset_rollback as reset_rollback +from .impl import AssertionPool as AssertionPool +from .impl import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool +from .impl import ( + FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool, +) +from .impl import NullPool as NullPool +from .impl import QueuePool as QueuePool +from .impl import SingletonThreadPool as SingletonThreadPool +from .impl import StaticPool as StaticPool diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index c596dee5a..b2ca1cfef 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -9,50 +9,54 @@ """ -from .sql.base import SchemaVisitor # noqa -from .sql.ddl import _CreateDropBase # noqa -from .sql.ddl import _DDLCompiles # noqa -from .sql.ddl import _DropView # noqa -from .sql.ddl import AddConstraint # noqa -from .sql.ddl import CreateColumn # noqa -from .sql.ddl import CreateIndex # noqa -from .sql.ddl import CreateSchema # noqa -from .sql.ddl import CreateSequence # noqa -from .sql.ddl import CreateTable # noqa -from .sql.ddl import DDL # noqa -from .sql.ddl import DDLBase # noqa -from .sql.ddl import DDLElement # noqa -from .sql.ddl import DropColumnComment # noqa -from .sql.ddl import DropConstraint # noqa -from .sql.ddl import DropIndex # noqa -from .sql.ddl import DropSchema # noqa -from .sql.ddl import DropSequence # noqa -from .sql.ddl import DropTable # noqa -from .sql.ddl import DropTableComment # noqa -from .sql.ddl import SetColumnComment # noqa -from .sql.ddl import SetTableComment # noqa -from .sql.ddl import sort_tables # noqa -from .sql.ddl import sort_tables_and_constraints # noqa -from .sql.naming import conv # noqa -from .sql.schema import _get_table_key # noqa -from .sql.schema import BLANK_SCHEMA # noqa -from .sql.schema import CheckConstraint # noqa -from .sql.schema import Column # noqa -from .sql.schema import ColumnCollectionConstraint # noqa -from .sql.schema import ColumnCollectionMixin # noqa -from .sql.schema import ColumnDefault # noqa -from .sql.schema import Computed # noqa -from .sql.schema import Constraint # noqa -from .sql.schema import DefaultClause # noqa -from .sql.schema import DefaultGenerator # noqa -from .sql.schema import FetchedValue # noqa -from .sql.schema import ForeignKey # noqa -from .sql.schema import ForeignKeyConstraint # noqa -from .sql.schema import Identity # noqa -from .sql.schema import Index # noqa -from .sql.schema import MetaData # noqa -from .sql.schema import PrimaryKeyConstraint # noqa -from .sql.schema import SchemaItem # noqa -from .sql.schema import Sequence # noqa -from .sql.schema import Table # noqa -from .sql.schema import UniqueConstraint # noqa +from .sql.base import SchemaVisitor as SchemaVisitor +from .sql.ddl import _CreateDropBase as _CreateDropBase +from .sql.ddl import _DDLCompiles as _DDLCompiles +from .sql.ddl import _DropView as _DropView +from .sql.ddl import AddConstraint as AddConstraint +from .sql.ddl import CreateColumn as CreateColumn +from .sql.ddl import CreateIndex as CreateIndex +from .sql.ddl import CreateSchema as CreateSchema +from .sql.ddl import CreateSequence as CreateSequence +from .sql.ddl import CreateTable as CreateTable +from .sql.ddl import DDL as DDL +from .sql.ddl import DDLBase as DDLBase +from .sql.ddl import DDLElement as DDLElement +from .sql.ddl import DropColumnComment as DropColumnComment +from .sql.ddl import DropConstraint as DropConstraint +from .sql.ddl import DropIndex as DropIndex +from .sql.ddl import DropSchema as DropSchema +from .sql.ddl import DropSequence as DropSequence +from .sql.ddl import DropTable as DropTable +from .sql.ddl import DropTableComment as DropTableComment +from .sql.ddl import SetColumnComment as SetColumnComment +from .sql.ddl import SetTableComment as SetTableComment +from .sql.ddl import sort_tables as sort_tables +from .sql.ddl import ( + sort_tables_and_constraints as sort_tables_and_constraints, +) +from .sql.naming import conv as conv +from .sql.schema import _get_table_key as _get_table_key +from .sql.schema import BLANK_SCHEMA as BLANK_SCHEMA +from .sql.schema import CheckConstraint as CheckConstraint +from .sql.schema import Column as Column +from .sql.schema import ( + ColumnCollectionConstraint as ColumnCollectionConstraint, +) +from .sql.schema import ColumnCollectionMixin as ColumnCollectionMixin +from .sql.schema import ColumnDefault as ColumnDefault +from .sql.schema import Computed as Computed +from .sql.schema import Constraint as Constraint +from .sql.schema import DefaultClause as DefaultClause +from .sql.schema import DefaultGenerator as DefaultGenerator +from .sql.schema import FetchedValue as FetchedValue +from .sql.schema import ForeignKey as ForeignKey +from .sql.schema import ForeignKeyConstraint as ForeignKeyConstraint +from .sql.schema import Identity as Identity +from .sql.schema import Index as Index +from .sql.schema import MetaData as MetaData +from .sql.schema import PrimaryKeyConstraint as PrimaryKeyConstraint +from .sql.schema import SchemaItem as SchemaItem +from .sql.schema import Sequence as Sequence +from .sql.schema import Table as Table +from .sql.schema import UniqueConstraint as UniqueConstraint diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 2f84370aa..169ddf3db 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -75,6 +75,7 @@ from .expression import quoted_name as quoted_name from .expression import Select as Select from .expression import select as select from .expression import Selectable as Selectable +from .expression import SelectLabelStyle as SelectLabelStyle from .expression import StatementLambdaElement as StatementLambdaElement from .expression import Subquery as Subquery from .expression import table as table diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 4b67c12f0..d3cf207da 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -6,11 +6,11 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php from typing import Any -from typing import Type from typing import Union from . import coercions from . import roles +from ._typing import _ColumnsClauseElement from .elements import ColumnClause from .selectable import Alias from .selectable import CompoundSelect @@ -21,6 +21,8 @@ from .selectable import Select from .selectable import TableClause from .selectable import TableSample from .selectable import Values +from ..util.typing import _LiteralStar +from ..util.typing import Literal def alias(selectable, name=None, flat=False): @@ -279,7 +281,9 @@ def outerjoin(left, right, onclause=None, full=False): return Join(left, right, onclause, isouter=True, full=full) -def select(*entities: Union[roles.ColumnsClauseRole, Type]) -> "Select": +def select( + *entities: Union[_LiteralStar, Literal[1], _ColumnsClauseElement] +) -> "Select": r"""Construct a new :class:`_expression.Select`. diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index b5b0efb21..4d2dd2688 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -1,9 +1,21 @@ from typing import Any from typing import Mapping from typing import Sequence +from typing import Type from typing import Union +from . import roles +from ..inspection import Inspectable +from ..util import immutabledict + _SingleExecuteParams = Mapping[str, Any] _MultiExecuteParams = Sequence[_SingleExecuteParams] _ExecuteParams = Union[_SingleExecuteParams, _MultiExecuteParams] _ExecuteOptions = Mapping[str, Any] +_ImmutableExecuteOptions = immutabledict[str, Any] +_ColumnsClauseElement = Union[ + roles.ColumnsClauseRole, Type, Inspectable[roles.HasClauseElement] +] +_FromClauseElement = Union[ + roles.FromClauseRole, Type, Inspectable[roles.HasFromClauseElement] +] diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index f4fe7afab..5828f9369 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -21,6 +21,7 @@ from typing import TypeVar from . import roles from . import visitors +from ._typing import _ImmutableExecuteOptions from .cache_key import HasCacheKey # noqa from .cache_key import MemoizedHasCacheKey # noqa from .traversals import HasCopyInternals # noqa @@ -832,9 +833,8 @@ class Executable(roles.StatementRole, Generative): """ - supports_execution = True - _execution_options = util.immutabledict() - _bind = None + supports_execution: bool = True + _execution_options: _ImmutableExecuteOptions = util.immutabledict() _with_options = () _with_context_options = () diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 9cf4d8397..bf78b4231 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -889,7 +889,7 @@ class SQLCompiler(Compiled): def _apply_numbered_params(self): poscount = itertools.count(1) self.string = re.sub( - r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string + r"\[_POSITION\]", lambda m: str(next(poscount)), self.string ) @util.memoized_property diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 18931ce67..f622023b0 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -10,6 +10,11 @@ to invoke them for a create/drop call. """ import typing +from typing import Callable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple from . import roles from .base import _generative @@ -21,6 +26,11 @@ from .. import util from ..util import topological +if typing.TYPE_CHECKING: + from .schema import ForeignKeyConstraint + from .schema import Table + + class _DDLCompiles(ClauseElement): _hierarchy_supports_caching = False """disable cache warnings for all _DDLCompiles subclasses. """ @@ -1007,10 +1017,10 @@ class SchemaDropper(DDLBase): def sort_tables( - tables, - skip_fn=None, - extra_dependencies=None, -): + tables: Sequence["Table"], + skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None, + extra_dependencies: Optional[Sequence[Tuple["Table", "Table"]]] = None, +) -> List["Table"]: """Sort a collection of :class:`_schema.Table` objects based on dependency. @@ -1051,7 +1061,7 @@ def sort_tables( :param tables: a sequence of :class:`_schema.Table` objects. :param skip_fn: optional callable which will be passed a - :class:`_schema.ForeignKey` object; if it returns True, this + :class:`_schema.ForeignKeyConstraint` object; if it returns True, this constraint will not be considered as a dependency. Note this is **different** from the same parameter in :func:`.sort_tables_and_constraints`, which is diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 0ed5bd986..22195cd7c 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -136,6 +136,7 @@ from .selectable import ScalarSelect as ScalarSelect from .selectable import Select as Select from .selectable import Selectable as Selectable from .selectable import SelectBase as SelectBase +from .selectable import SelectLabelStyle as SelectLabelStyle from .selectable import Subquery as Subquery from .selectable import TableClause as TableClause from .selectable import TableSample as TableSample diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py index 00a2b1d89..15a1566a6 100644 --- a/lib/sqlalchemy/sql/naming.py +++ b/lib/sqlalchemy/sql/naming.py @@ -14,7 +14,7 @@ import re from . import events # noqa from .elements import _NONE_NAME -from .elements import conv +from .elements import conv as conv from .schema import CheckConstraint from .schema import Column from .schema import Constraint diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 787a1c25e..b41ef7a5d 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -4,10 +4,17 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +import typing +from sqlalchemy.util.langhelpers import TypingOnly from .. import util +if typing.TYPE_CHECKING: + from .elements import ClauseElement + from .selectable import FromClause + + class SQLRole: """Define a "role" within a SQL statement structure. @@ -284,3 +291,25 @@ class DDLReferredColumnRole(DDLConstraintColumnRole): _role_name = ( "String column name or Column object for DDL foreign key constraint" ) + + +class HasClauseElement(TypingOnly): + """indicates a class that has a __clause_element__() method""" + + __slots__ = () + + if typing.TYPE_CHECKING: + + def __clause_element__(self) -> "ClauseElement": + ... + + +class HasFromClauseElement(HasClauseElement, TypingOnly): + """indicates a class that has a __clause_element__() method""" + + __slots__ = () + + if typing.TYPE_CHECKING: + + def __clause_element__(self) -> "FromClause": + ... diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index a04fad05d..9387ae030 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -31,9 +31,12 @@ as components in SQL expressions. import collections import typing from typing import Any +from typing import Dict +from typing import List from typing import MutableMapping from typing import Optional from typing import overload +from typing import Sequence as _typing_Sequence from typing import Type from typing import TypeVar from typing import Union @@ -52,6 +55,7 @@ from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement from .elements import quoted_name +from .elements import SQLCoreOperations from .elements import TextClause from .selectable import TableClause from .type_api import to_instance @@ -64,9 +68,12 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from .type_api import TypeEngine + from ..engine import Connection + from ..engine import Engine _T = TypeVar("_T", bound="Any") _ServerDefaultType = Union["FetchedValue", str, TextClause, ColumnElement] +_TAB = TypeVar("_TAB", bound="Table") RETAIN_SCHEMA = util.symbol("retain_schema") @@ -188,313 +195,6 @@ class Table(DialectKWArgs, SchemaItem, TableClause): :ref:`metadata_describing` - Introduction to database metadata - Constructor arguments are as follows: - - :param name: The name of this table as represented in the database. - - The table name, along with the value of the ``schema`` parameter, - forms a key which uniquely identifies this :class:`_schema.Table` - within - the owning :class:`_schema.MetaData` collection. - Additional calls to :class:`_schema.Table` with the same name, - metadata, - and schema name will return the same :class:`_schema.Table` object. - - Names which contain no upper case characters - will be treated as case insensitive names, and will not be quoted - unless they are a reserved word or contain special characters. - A name with any number of upper case characters is considered - to be case sensitive, and will be sent as quoted. - - To enable unconditional quoting for the table name, specify the flag - ``quote=True`` to the constructor, or use the :class:`.quoted_name` - construct to specify the name. - - :param metadata: a :class:`_schema.MetaData` - object which will contain this - table. The metadata is used as a point of association of this table - with other tables which are referenced via foreign key. It also - may be used to associate this table with a particular - :class:`.Connection` or :class:`.Engine`. - - :param \*args: Additional positional arguments are used primarily - to add the list of :class:`_schema.Column` - objects contained within this - table. Similar to the style of a CREATE TABLE statement, other - :class:`.SchemaItem` constructs may be added here, including - :class:`.PrimaryKeyConstraint`, and - :class:`_schema.ForeignKeyConstraint`. - - :param autoload: Defaults to ``False``, unless - :paramref:`_schema.Table.autoload_with` - is set in which case it defaults to ``True``; - :class:`_schema.Column` objects - for this table should be reflected from the database, possibly - augmenting objects that were explicitly specified. - :class:`_schema.Column` and other objects explicitly set on the - table will replace corresponding reflected objects. - - .. deprecated:: 1.4 - - The autoload parameter is deprecated and will be removed in - version 2.0. Please use the - :paramref:`_schema.Table.autoload_with` parameter, passing an - engine or connection. - - .. seealso:: - - :ref:`metadata_reflection_toplevel` - - :param autoload_replace: Defaults to ``True``; when using - :paramref:`_schema.Table.autoload` - in conjunction with :paramref:`_schema.Table.extend_existing`, - indicates - that :class:`_schema.Column` objects present in the already-existing - :class:`_schema.Table` - object should be replaced with columns of the same - name retrieved from the autoload process. When ``False``, columns - already present under existing names will be omitted from the - reflection process. - - Note that this setting does not impact :class:`_schema.Column` objects - specified programmatically within the call to :class:`_schema.Table` - that - also is autoloading; those :class:`_schema.Column` objects will always - replace existing columns of the same name when - :paramref:`_schema.Table.extend_existing` is ``True``. - - .. seealso:: - - :paramref:`_schema.Table.autoload` - - :paramref:`_schema.Table.extend_existing` - - :param autoload_with: An :class:`_engine.Engine` or - :class:`_engine.Connection` object, - or a :class:`_reflection.Inspector` object as returned by - :func:`_sa.inspect` - against one, with which this :class:`_schema.Table` - object will be reflected. - When set to a non-None value, the autoload process will take place - for this table against the given engine or connection. - - :param extend_existing: When ``True``, indicates that if this - :class:`_schema.Table` is already present in the given - :class:`_schema.MetaData`, - apply further arguments within the constructor to the existing - :class:`_schema.Table`. - - If :paramref:`_schema.Table.extend_existing` or - :paramref:`_schema.Table.keep_existing` are not set, - and the given name - of the new :class:`_schema.Table` refers to a :class:`_schema.Table` - that is - already present in the target :class:`_schema.MetaData` collection, - and - this :class:`_schema.Table` - specifies additional columns or other constructs - or flags that modify the table's state, an - error is raised. The purpose of these two mutually-exclusive flags - is to specify what action should be taken when a - :class:`_schema.Table` - is specified that matches an existing :class:`_schema.Table`, - yet specifies - additional constructs. - - :paramref:`_schema.Table.extend_existing` - will also work in conjunction - with :paramref:`_schema.Table.autoload` to run a new reflection - operation against the database, even if a :class:`_schema.Table` - of the same name is already present in the target - :class:`_schema.MetaData`; newly reflected :class:`_schema.Column` - objects - and other options will be added into the state of the - :class:`_schema.Table`, potentially overwriting existing columns - and options of the same name. - - As is always the case with :paramref:`_schema.Table.autoload`, - :class:`_schema.Column` objects can be specified in the same - :class:`_schema.Table` - constructor, which will take precedence. Below, the existing - table ``mytable`` will be augmented with :class:`_schema.Column` - objects - both reflected from the database, as well as the given - :class:`_schema.Column` - named "y":: - - Table("mytable", metadata, - Column('y', Integer), - extend_existing=True, - autoload_with=engine - ) - - .. seealso:: - - :paramref:`_schema.Table.autoload` - - :paramref:`_schema.Table.autoload_replace` - - :paramref:`_schema.Table.keep_existing` - - - :param implicit_returning: True by default - indicates that - RETURNING can be used by default to fetch newly inserted primary key - values, for backends which support this. Note that - :func:`_sa.create_engine` also provides an ``implicit_returning`` - flag. - - :param include_columns: A list of strings indicating a subset of - columns to be loaded via the ``autoload`` operation; table columns who - aren't present in this list will not be represented on the resulting - ``Table`` object. Defaults to ``None`` which indicates all columns - should be reflected. - - :param resolve_fks: Whether or not to reflect :class:`_schema.Table` - objects - related to this one via :class:`_schema.ForeignKey` objects, when - :paramref:`_schema.Table.autoload` or - :paramref:`_schema.Table.autoload_with` is - specified. Defaults to True. Set to False to disable reflection of - related tables as :class:`_schema.ForeignKey` - objects are encountered; may be - used either to save on SQL calls or to avoid issues with related tables - that can't be accessed. Note that if a related table is already present - in the :class:`_schema.MetaData` collection, or becomes present later, - a - :class:`_schema.ForeignKey` object associated with this - :class:`_schema.Table` will - resolve to that table normally. - - .. versionadded:: 1.3 - - .. seealso:: - - :paramref:`.MetaData.reflect.resolve_fks` - - - :param info: Optional data dictionary which will be populated into the - :attr:`.SchemaItem.info` attribute of this object. - - :param keep_existing: When ``True``, indicates that if this Table - is already present in the given :class:`_schema.MetaData`, ignore - further arguments within the constructor to the existing - :class:`_schema.Table`, and return the :class:`_schema.Table` - object as - originally created. This is to allow a function that wishes - to define a new :class:`_schema.Table` on first call, but on - subsequent calls will return the same :class:`_schema.Table`, - without any of the declarations (particularly constraints) - being applied a second time. - - If :paramref:`_schema.Table.extend_existing` or - :paramref:`_schema.Table.keep_existing` are not set, - and the given name - of the new :class:`_schema.Table` refers to a :class:`_schema.Table` - that is - already present in the target :class:`_schema.MetaData` collection, - and - this :class:`_schema.Table` - specifies additional columns or other constructs - or flags that modify the table's state, an - error is raised. The purpose of these two mutually-exclusive flags - is to specify what action should be taken when a - :class:`_schema.Table` - is specified that matches an existing :class:`_schema.Table`, - yet specifies - additional constructs. - - .. seealso:: - - :paramref:`_schema.Table.extend_existing` - - :param listeners: A list of tuples of the form ``(<eventname>, <fn>)`` - which will be passed to :func:`.event.listen` upon construction. - This alternate hook to :func:`.event.listen` allows the establishment - of a listener function specific to this :class:`_schema.Table` before - the "autoload" process begins. Historically this has been intended - for use with the :meth:`.DDLEvents.column_reflect` event, however - note that this event hook may now be associated with the - :class:`_schema.MetaData` object directly:: - - def listen_for_reflect(table, column_info): - "handle the column reflection event" - # ... - - t = Table( - 'sometable', - autoload_with=engine, - listeners=[ - ('column_reflect', listen_for_reflect) - ]) - - .. seealso:: - - :meth:`_events.DDLEvents.column_reflect` - - :param must_exist: When ``True``, indicates that this Table must already - be present in the given :class:`_schema.MetaData` collection, else - an exception is raised. - - :param prefixes: - A list of strings to insert after CREATE in the CREATE TABLE - statement. They will be separated by spaces. - - :param quote: Force quoting of this table's name on or off, corresponding - to ``True`` or ``False``. When left at its default of ``None``, - the column identifier will be quoted according to whether the name is - case sensitive (identifiers with at least one upper case character are - treated as case sensitive), or if it's a reserved word. This flag - is only needed to force quoting of a reserved word which is not known - by the SQLAlchemy dialect. - - .. note:: setting this flag to ``False`` will not provide - case-insensitive behavior for table reflection; table reflection - will always search for a mixed-case name in a case sensitive - fashion. Case insensitive names are specified in SQLAlchemy only - by stating the name with all lower case characters. - - :param quote_schema: same as 'quote' but applies to the schema identifier. - - :param schema: The schema name for this table, which is required if - the table resides in a schema other than the default selected schema - for the engine's database connection. Defaults to ``None``. - - If the owning :class:`_schema.MetaData` of this :class:`_schema.Table` - specifies its - own :paramref:`_schema.MetaData.schema` parameter, - then that schema name will - be applied to this :class:`_schema.Table` - if the schema parameter here is set - to ``None``. To set a blank schema name on a :class:`_schema.Table` - that - would otherwise use the schema set on the owning - :class:`_schema.MetaData`, - specify the special symbol :attr:`.BLANK_SCHEMA`. - - .. versionadded:: 1.0.14 Added the :attr:`.BLANK_SCHEMA` symbol to - allow a :class:`_schema.Table` - to have a blank schema name even when the - parent :class:`_schema.MetaData` specifies - :paramref:`_schema.MetaData.schema`. - - The quoting rules for the schema name are the same as those for the - ``name`` parameter, in that quoting is applied for reserved words or - case-sensitive names; to enable unconditional quoting for the schema - name, specify the flag ``quote_schema=True`` to the constructor, or use - the :class:`.quoted_name` construct to specify the name. - - :param comment: Optional string that will render an SQL comment on table - creation. - - .. versionadded:: 1.2 Added the :paramref:`_schema.Table.comment` - parameter - to :class:`_schema.Table`. - - :param \**kw: Additional keyword arguments not mentioned above are - dialect specific, and passed in the form ``<dialectname>_<argname>``. - See the documentation regarding an individual dialect at - :ref:`dialect_toplevel` for detail on documented arguments. - """ __visit_name__ = "table" @@ -547,13 +247,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause): else: return (self,) - @util.deprecated_params( - mustexist=( - "1.4", - "Deprecated alias of :paramref:`_schema.Table.must_exist`", - ), - ) - def __new__(cls, *args, **kw): + if not typing.TYPE_CHECKING: + # typing tools seem to be inconsistent in how they handle + # __new__, so suggest this pattern for classes that use + # __new__. apply typing to the __init__ method normally + @util.deprecated_params( + mustexist=( + "1.4", + "Deprecated alias of :paramref:`_schema.Table.must_exist`", + ), + ) + def __new__(cls, *args: Any, **kw: Any) -> Any: + return cls._new(*args, **kw) + + @classmethod + def _new(cls, *args, **kw): if not args and not kw: # python3k pickle seems to call this return object.__new__(cls) @@ -607,14 +315,323 @@ class Table(DialectKWArgs, SchemaItem, TableClause): with util.safe_reraise(): metadata._remove_table(name, schema) - def __init__(self, *args, **kw): - """Constructor for :class:`_schema.Table`. + def __init__( + self, + name: str, + metadata: "MetaData", + *args: SchemaItem, + **kw: Any, + ): + r"""Constructor for :class:`_schema.Table`. - This method is a no-op. See the top-level - documentation for :class:`_schema.Table` - for constructor arguments. - """ + :param name: The name of this table as represented in the database. + + The table name, along with the value of the ``schema`` parameter, + forms a key which uniquely identifies this :class:`_schema.Table` + within + the owning :class:`_schema.MetaData` collection. + Additional calls to :class:`_schema.Table` with the same name, + metadata, + and schema name will return the same :class:`_schema.Table` object. + + Names which contain no upper case characters + will be treated as case insensitive names, and will not be quoted + unless they are a reserved word or contain special characters. + A name with any number of upper case characters is considered + to be case sensitive, and will be sent as quoted. + + To enable unconditional quoting for the table name, specify the flag + ``quote=True`` to the constructor, or use the :class:`.quoted_name` + construct to specify the name. + + :param metadata: a :class:`_schema.MetaData` + object which will contain this + table. The metadata is used as a point of association of this table + with other tables which are referenced via foreign key. It also + may be used to associate this table with a particular + :class:`.Connection` or :class:`.Engine`. + + :param \*args: Additional positional arguments are used primarily + to add the list of :class:`_schema.Column` + objects contained within this + table. Similar to the style of a CREATE TABLE statement, other + :class:`.SchemaItem` constructs may be added here, including + :class:`.PrimaryKeyConstraint`, and + :class:`_schema.ForeignKeyConstraint`. + + :param autoload: Defaults to ``False``, unless + :paramref:`_schema.Table.autoload_with` + is set in which case it defaults to ``True``; + :class:`_schema.Column` objects + for this table should be reflected from the database, possibly + augmenting objects that were explicitly specified. + :class:`_schema.Column` and other objects explicitly set on the + table will replace corresponding reflected objects. + + .. deprecated:: 1.4 + + The autoload parameter is deprecated and will be removed in + version 2.0. Please use the + :paramref:`_schema.Table.autoload_with` parameter, passing an + engine or connection. + + .. seealso:: + + :ref:`metadata_reflection_toplevel` + + :param autoload_replace: Defaults to ``True``; when using + :paramref:`_schema.Table.autoload` + in conjunction with :paramref:`_schema.Table.extend_existing`, + indicates + that :class:`_schema.Column` objects present in the already-existing + :class:`_schema.Table` + object should be replaced with columns of the same + name retrieved from the autoload process. When ``False``, columns + already present under existing names will be omitted from the + reflection process. + + Note that this setting does not impact :class:`_schema.Column` objects + specified programmatically within the call to :class:`_schema.Table` + that + also is autoloading; those :class:`_schema.Column` objects will always + replace existing columns of the same name when + :paramref:`_schema.Table.extend_existing` is ``True``. + + .. seealso:: + + :paramref:`_schema.Table.autoload` + + :paramref:`_schema.Table.extend_existing` + + :param autoload_with: An :class:`_engine.Engine` or + :class:`_engine.Connection` object, + or a :class:`_reflection.Inspector` object as returned by + :func:`_sa.inspect` + against one, with which this :class:`_schema.Table` + object will be reflected. + When set to a non-None value, the autoload process will take place + for this table against the given engine or connection. + + :param extend_existing: When ``True``, indicates that if this + :class:`_schema.Table` is already present in the given + :class:`_schema.MetaData`, + apply further arguments within the constructor to the existing + :class:`_schema.Table`. + + If :paramref:`_schema.Table.extend_existing` or + :paramref:`_schema.Table.keep_existing` are not set, + and the given name + of the new :class:`_schema.Table` refers to a :class:`_schema.Table` + that is + already present in the target :class:`_schema.MetaData` collection, + and + this :class:`_schema.Table` + specifies additional columns or other constructs + or flags that modify the table's state, an + error is raised. The purpose of these two mutually-exclusive flags + is to specify what action should be taken when a + :class:`_schema.Table` + is specified that matches an existing :class:`_schema.Table`, + yet specifies + additional constructs. + + :paramref:`_schema.Table.extend_existing` + will also work in conjunction + with :paramref:`_schema.Table.autoload` to run a new reflection + operation against the database, even if a :class:`_schema.Table` + of the same name is already present in the target + :class:`_schema.MetaData`; newly reflected :class:`_schema.Column` + objects + and other options will be added into the state of the + :class:`_schema.Table`, potentially overwriting existing columns + and options of the same name. + + As is always the case with :paramref:`_schema.Table.autoload`, + :class:`_schema.Column` objects can be specified in the same + :class:`_schema.Table` + constructor, which will take precedence. Below, the existing + table ``mytable`` will be augmented with :class:`_schema.Column` + objects + both reflected from the database, as well as the given + :class:`_schema.Column` + named "y":: + + Table("mytable", metadata, + Column('y', Integer), + extend_existing=True, + autoload_with=engine + ) + + .. seealso:: + + :paramref:`_schema.Table.autoload` + + :paramref:`_schema.Table.autoload_replace` + + :paramref:`_schema.Table.keep_existing` + + + :param implicit_returning: True by default - indicates that + RETURNING can be used by default to fetch newly inserted primary key + values, for backends which support this. Note that + :func:`_sa.create_engine` also provides an ``implicit_returning`` + flag. + + :param include_columns: A list of strings indicating a subset of + columns to be loaded via the ``autoload`` operation; table columns who + aren't present in this list will not be represented on the resulting + ``Table`` object. Defaults to ``None`` which indicates all columns + should be reflected. + + :param resolve_fks: Whether or not to reflect :class:`_schema.Table` + objects + related to this one via :class:`_schema.ForeignKey` objects, when + :paramref:`_schema.Table.autoload` or + :paramref:`_schema.Table.autoload_with` is + specified. Defaults to True. Set to False to disable reflection of + related tables as :class:`_schema.ForeignKey` + objects are encountered; may be + used either to save on SQL calls or to avoid issues with related tables + that can't be accessed. Note that if a related table is already present + in the :class:`_schema.MetaData` collection, or becomes present later, + a + :class:`_schema.ForeignKey` object associated with this + :class:`_schema.Table` will + resolve to that table normally. + + .. versionadded:: 1.3 + + .. seealso:: + + :paramref:`.MetaData.reflect.resolve_fks` + + + :param info: Optional data dictionary which will be populated into the + :attr:`.SchemaItem.info` attribute of this object. + + :param keep_existing: When ``True``, indicates that if this Table + is already present in the given :class:`_schema.MetaData`, ignore + further arguments within the constructor to the existing + :class:`_schema.Table`, and return the :class:`_schema.Table` + object as + originally created. This is to allow a function that wishes + to define a new :class:`_schema.Table` on first call, but on + subsequent calls will return the same :class:`_schema.Table`, + without any of the declarations (particularly constraints) + being applied a second time. + + If :paramref:`_schema.Table.extend_existing` or + :paramref:`_schema.Table.keep_existing` are not set, + and the given name + of the new :class:`_schema.Table` refers to a :class:`_schema.Table` + that is + already present in the target :class:`_schema.MetaData` collection, + and + this :class:`_schema.Table` + specifies additional columns or other constructs + or flags that modify the table's state, an + error is raised. The purpose of these two mutually-exclusive flags + is to specify what action should be taken when a + :class:`_schema.Table` + is specified that matches an existing :class:`_schema.Table`, + yet specifies + additional constructs. + + .. seealso:: + + :paramref:`_schema.Table.extend_existing` + + :param listeners: A list of tuples of the form ``(<eventname>, <fn>)`` + which will be passed to :func:`.event.listen` upon construction. + This alternate hook to :func:`.event.listen` allows the establishment + of a listener function specific to this :class:`_schema.Table` before + the "autoload" process begins. Historically this has been intended + for use with the :meth:`.DDLEvents.column_reflect` event, however + note that this event hook may now be associated with the + :class:`_schema.MetaData` object directly:: + + def listen_for_reflect(table, column_info): + "handle the column reflection event" + # ... + + t = Table( + 'sometable', + autoload_with=engine, + listeners=[ + ('column_reflect', listen_for_reflect) + ]) + + .. seealso:: + + :meth:`_events.DDLEvents.column_reflect` + + :param must_exist: When ``True``, indicates that this Table must already + be present in the given :class:`_schema.MetaData` collection, else + an exception is raised. + + :param prefixes: + A list of strings to insert after CREATE in the CREATE TABLE + statement. They will be separated by spaces. + + :param quote: Force quoting of this table's name on or off, corresponding + to ``True`` or ``False``. When left at its default of ``None``, + the column identifier will be quoted according to whether the name is + case sensitive (identifiers with at least one upper case character are + treated as case sensitive), or if it's a reserved word. This flag + is only needed to force quoting of a reserved word which is not known + by the SQLAlchemy dialect. + + .. note:: setting this flag to ``False`` will not provide + case-insensitive behavior for table reflection; table reflection + will always search for a mixed-case name in a case sensitive + fashion. Case insensitive names are specified in SQLAlchemy only + by stating the name with all lower case characters. + + :param quote_schema: same as 'quote' but applies to the schema identifier. + + :param schema: The schema name for this table, which is required if + the table resides in a schema other than the default selected schema + for the engine's database connection. Defaults to ``None``. + + If the owning :class:`_schema.MetaData` of this :class:`_schema.Table` + specifies its + own :paramref:`_schema.MetaData.schema` parameter, + then that schema name will + be applied to this :class:`_schema.Table` + if the schema parameter here is set + to ``None``. To set a blank schema name on a :class:`_schema.Table` + that + would otherwise use the schema set on the owning + :class:`_schema.MetaData`, + specify the special symbol :attr:`.BLANK_SCHEMA`. + + .. versionadded:: 1.0.14 Added the :attr:`.BLANK_SCHEMA` symbol to + allow a :class:`_schema.Table` + to have a blank schema name even when the + parent :class:`_schema.MetaData` specifies + :paramref:`_schema.MetaData.schema`. + + The quoting rules for the schema name are the same as those for the + ``name`` parameter, in that quoting is applied for reserved words or + case-sensitive names; to enable unconditional quoting for the schema + name, specify the flag ``quote_schema=True`` to the constructor, or use + the :class:`.quoted_name` construct to specify the name. + + :param comment: Optional string that will render an SQL comment on table + creation. + + .. versionadded:: 1.2 Added the :paramref:`_schema.Table.comment` + parameter + to :class:`_schema.Table`. + + :param \**kw: Additional keyword arguments not mentioned above are + dialect specific, and passed in the form ``<dialectname>_<argname>``. + See the documentation regarding an individual dialect at + :ref:`dialect_toplevel` for detail on documented arguments. + + """ # noqa E501 + # __init__ is overridden to prevent __new__ from # calling the superclass constructor. @@ -1203,7 +1220,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): ) -> None: ... - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): r""" Construct a new ``Column`` object. @@ -2179,18 +2196,18 @@ class ForeignKey(DialectKWArgs, SchemaItem): def __init__( self, - column, - _constraint=None, - use_alter=False, - name=None, - onupdate=None, - ondelete=None, - deferrable=None, - initially=None, - link_to_name=False, - match=None, - info=None, - **dialect_kw, + column: Union[str, Column, SQLCoreOperations], + _constraint: Optional["ForeignKeyConstraint"] = None, + use_alter: bool = False, + name: Optional[str] = None, + onupdate: Optional[str] = None, + ondelete: Optional[str] = None, + deferrable: Optional[bool] = None, + initially: Optional[bool] = None, + link_to_name: bool = False, + match: Optional[str] = None, + info: Optional[Dict[Any, Any]] = None, + **dialect_kw: Any, ): r""" Construct a column-level FOREIGN KEY. @@ -2337,7 +2354,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): ) return self._schema_item_copy(fk) - def _get_colspec(self, schema=None, table_name=None): + def _get_colspec(self, schema=None, table_name=None, _is_copy=False): """Return a string based 'column specification' for this :class:`_schema.ForeignKey`. @@ -2357,6 +2374,14 @@ class ForeignKey(DialectKWArgs, SchemaItem): else: return "%s.%s" % (table_name, colname) elif self._table_column is not None: + if self._table_column.table is None: + if _is_copy: + raise exc.InvalidRequestError( + f"Can't copy ForeignKey object which refers to " + f"non-table bound Column {self._table_column!r}" + ) + else: + return self._table_column.key return "%s.%s" % ( self._table_column.table.fullname, self._table_column.key, @@ -3858,6 +3883,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): if target_table is not None and x._table_key() == x.parent.table.key else None, + _is_copy=True, ) for x in self.elements ], @@ -4331,10 +4357,10 @@ class MetaData(SchemaItem): def __init__( self, - schema=None, - quote_schema=None, - naming_convention=None, - info=None, + schema: Optional[str] = None, + quote_schema: Optional[bool] = None, + naming_convention: Optional[Dict[str, str]] = None, + info: Optional[Dict[Any, Any]] = None, ): """Create a new MetaData object. @@ -4465,7 +4491,7 @@ class MetaData(SchemaItem): self._sequences = {} self._fk_memos = collections.defaultdict(list) - tables = None + tables: Dict[str, Table] """A dictionary of :class:`_schema.Table` objects keyed to their name or "table key". @@ -4483,10 +4509,10 @@ class MetaData(SchemaItem): """ - def __repr__(self): + def __repr__(self) -> str: return "MetaData()" - def __contains__(self, table_or_key): + def __contains__(self, table_or_key: Union[str, Table]) -> bool: if not isinstance(table_or_key, str): table_or_key = table_or_key.key return table_or_key in self.tables @@ -4530,20 +4556,20 @@ class MetaData(SchemaItem): self._schemas = state["schemas"] self._fk_memos = state["fk_memos"] - def clear(self): + def clear(self) -> None: """Clear all Table objects from this MetaData.""" dict.clear(self.tables) self._schemas.clear() self._fk_memos.clear() - def remove(self, table): + def remove(self, table: Table) -> None: """Remove the given Table object from this MetaData.""" self._remove_table(table.name, table.schema) @property - def sorted_tables(self): + def sorted_tables(self) -> List[Table]: """Returns a list of :class:`_schema.Table` objects sorted in order of foreign key dependency. @@ -4599,14 +4625,14 @@ class MetaData(SchemaItem): def reflect( self, - bind, - schema=None, - views=False, - only=None, - extend_existing=False, - autoload_replace=True, - resolve_fks=True, - **dialect_kwargs, + bind: Union["Engine", "Connection"], + schema: Optional[str] = None, + views: bool = False, + only: Optional[_typing_Sequence[str]] = None, + extend_existing: bool = False, + autoload_replace: bool = True, + resolve_fks: bool = True, + **dialect_kwargs: Any, ): r"""Load all available table definitions from the database. @@ -4754,7 +4780,12 @@ class MetaData(SchemaItem): except exc.UnreflectableTableError as uerr: util.warn("Skipping table %s: %s" % (name, uerr)) - def create_all(self, bind, tables=None, checkfirst=True): + def create_all( + self, + bind: Union["Engine", "Connection"], + tables: Optional[_typing_Sequence[Table]] = None, + checkfirst: bool = True, + ): """Create all tables stored in this metadata. Conditional by default, will not attempt to recreate tables already @@ -4777,7 +4808,12 @@ class MetaData(SchemaItem): ddl.SchemaGenerator, self, checkfirst=checkfirst, tables=tables ) - def drop_all(self, bind, tables=None, checkfirst=True): + def drop_all( + self, + bind: Union["Engine", "Connection"], + tables: Optional[_typing_Sequence[Table]] = None, + checkfirst: bool = True, + ): """Drop all tables stored in this metadata. Conditional by default, will not attempt to drop tables not present in diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index e1bbcffec..b0985f75d 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -12,14 +12,13 @@ SQL tables and derived rowsets. """ import collections +from enum import Enum import itertools from operator import attrgetter import typing from typing import Any as TODO_Any from typing import Optional from typing import Tuple -from typing import Type -from typing import Union from . import cache_key from . import coercions @@ -28,6 +27,7 @@ from . import roles from . import traversals from . import type_api from . import visitors +from ._typing import _ColumnsClauseElement from .annotation import Annotated from .annotation import SupportsCloneAnnotations from .base import _clone @@ -847,8 +847,11 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.alias(name=name) -LABEL_STYLE_NONE = util.symbol( - "LABEL_STYLE_NONE", +class SelectLabelStyle(Enum): + """Label style constants that may be passed to + :meth:`_sql.Select.set_label_style`.""" + + LABEL_STYLE_NONE = 0 """Label style indicating no automatic labeling should be applied to the columns clause of a SELECT statement. @@ -867,11 +870,9 @@ LABEL_STYLE_NONE = util.symbol( .. versionadded:: 1.4 -""", # noqa E501 -) + """ # noqa E501 -LABEL_STYLE_TABLENAME_PLUS_COL = util.symbol( - "LABEL_STYLE_TABLENAME_PLUS_COL", + LABEL_STYLE_TABLENAME_PLUS_COL = 1 """Label style indicating all columns should be labeled as ``<tablename>_<columnname>`` when generating the columns clause of a SELECT statement, to disambiguate same-named columns referenced from different @@ -897,12 +898,9 @@ LABEL_STYLE_TABLENAME_PLUS_COL = util.symbol( .. versionadded:: 1.4 -""", # noqa E501 -) + """ # noqa: E501 - -LABEL_STYLE_DISAMBIGUATE_ONLY = util.symbol( - "LABEL_STYLE_DISAMBIGUATE_ONLY", + LABEL_STYLE_DISAMBIGUATE_ONLY = 2 """Label style indicating that columns with a name that conflicts with an existing name should be labeled with a semi-anonymizing label when generating the columns clause of a SELECT statement. @@ -924,17 +922,24 @@ LABEL_STYLE_DISAMBIGUATE_ONLY = util.symbol( .. versionadded:: 1.4 -""", # noqa: E501, -) + """ # noqa: E501 + LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY + """The default label style, refers to + :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`. -LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY -"""The default label style, refers to -:data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`. + .. versionadded:: 1.4 -.. versionadded:: 1.4 + """ -""" + +( + LABEL_STYLE_NONE, + LABEL_STYLE_TABLENAME_PLUS_COL, + LABEL_STYLE_DISAMBIGUATE_ONLY, +) = list(SelectLabelStyle) + +LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY class Join(roles.DMLTableRole, FromClause): @@ -2870,10 +2875,12 @@ class SelectStatementGrouping(GroupedElement, SelectBase): else: return self - def get_label_style(self): + def get_label_style(self) -> SelectLabelStyle: return self._label_style - def set_label_style(self, label_style): + def set_label_style( + self, label_style: SelectLabelStyle + ) -> "SelectStatementGrouping": return SelectStatementGrouping( self.element.set_label_style(label_style) ) @@ -3018,7 +3025,7 @@ class GenerativeSelect(SelectBase): ) return self - def get_label_style(self): + def get_label_style(self) -> SelectLabelStyle: """ Retrieve the current label style. @@ -3027,14 +3034,16 @@ class GenerativeSelect(SelectBase): """ return self._label_style - def set_label_style(self, style): + def set_label_style( + self: SelfGenerativeSelect, style: SelectLabelStyle + ) -> SelfGenerativeSelect: """Return a new selectable with the specified label style. There are three "label styles" available, - :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`, - :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL`, and - :data:`_sql.LABEL_STYLE_NONE`. The default style is - :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL`. + :attr:`_sql.SelectLabelStyle.LABEL_STYLE_DISAMBIGUATE_ONLY`, + :attr:`_sql.SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL`, and + :attr:`_sql.SelectLabelStyle.LABEL_STYLE_NONE`. The default style is + :attr:`_sql.SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL`. In modern SQLAlchemy, there is not generally a need to change the labeling style, as per-expression labels are more effectively used by @@ -4131,7 +4140,7 @@ class Select( stmt.__dict__.update(kw) return stmt - def __init__(self, *entities: Union[roles.ColumnsClauseRole, Type]): + def __init__(self, *entities: _ColumnsClauseElement): r"""Construct a new :class:`_expression.Select`. The public constructor for :class:`_expression.Select` is the diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index dd29b2c3a..6b878dc70 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -13,6 +13,7 @@ import typing from typing import Any from typing import Callable from typing import Generic +from typing import Optional from typing import Tuple from typing import Type from typing import TypeVar @@ -21,7 +22,7 @@ from typing import Union from .base import SchemaEventTarget from .cache_key import NO_CACHE from .operators import ColumnOperators -from .visitors import Traversible +from .visitors import Visitable from .. import exc from .. import util @@ -52,7 +53,7 @@ _CT = TypeVar("_CT", bound=Any) SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine") -class TypeEngine(Traversible, Generic[_T]): +class TypeEngine(Visitable, Generic[_T]): """The ultimate base class for all SQL datatypes. Common subclasses of :class:`.TypeEngine` include @@ -573,7 +574,7 @@ class TypeEngine(Traversible, Generic[_T]): raise NotImplementedError() def with_variant( - self: SelfTypeEngine, type_: "TypeEngine", dialect_name: str + self: SelfTypeEngine, type_: "TypeEngine", *dialect_names: str ) -> SelfTypeEngine: r"""Produce a copy of this type object that will utilize the given type when applied to the dialect of the given name. @@ -586,7 +587,7 @@ class TypeEngine(Traversible, Generic[_T]): string_type = String() string_type = string_type.with_variant( - mysql.VARCHAR(collation='foo'), 'mysql' + mysql.VARCHAR(collation='foo'), 'mysql', 'mariadb' ) The variant mapping indicates that when this type is @@ -602,16 +603,20 @@ class TypeEngine(Traversible, Generic[_T]): :param type\_: a :class:`.TypeEngine` that will be selected as a variant from the originating type, when a dialect of the given name is in use. - :param dialect_name: base name of the dialect which uses - this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.) + :param \*dialect_names: one or more base names of the dialect which + uses this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.) + + .. versionchanged:: 2.0 multiple dialect names can be specified + for one variant. """ - if dialect_name in self._variant_mapping: - raise exc.ArgumentError( - "Dialect '%s' is already present in " - "the mapping for this %r" % (dialect_name, self) - ) + for dialect_name in dialect_names: + if dialect_name in self._variant_mapping: + raise exc.ArgumentError( + "Dialect '%s' is already present in " + "the mapping for this %r" % (dialect_name, self) + ) new_type = self.copy() if isinstance(type_, type): type_ = type_() @@ -620,8 +625,9 @@ class TypeEngine(Traversible, Generic[_T]): "can't pass a type that already has variants as a " "dialect-level type to with_variant()" ) + new_type._variant_mapping = self._variant_mapping.union( - {dialect_name: type_} + {dialect_name: type_ for dialect_name in dialect_names} ) return new_type @@ -919,7 +925,7 @@ class ExternalType: """ - cache_ok = None + cache_ok: Optional[bool] = None """Indicate if statements using this :class:`.ExternalType` are "safe to cache". @@ -1357,6 +1363,8 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): _is_type_decorator = True + impl: Union[TypeEngine[Any], Type[TypeEngine[Any]]] + def __init__(self, *args, **kwargs): """Construct a :class:`.TypeDecorator`. diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 268a56421..c1ca670da 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -6,6 +6,11 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php import collections +import typing +from typing import Any +from typing import Iterable +from typing import Tuple +from typing import Union from .. import util @@ -20,10 +25,15 @@ any_async = False _current = None ident = "main" -_fixture_functions = None # installed by plugin_base +if typing.TYPE_CHECKING: + from .plugin.plugin_base import FixtureFunctions + _fixture_functions: FixtureFunctions +else: + _fixture_functions = None # installed by plugin_base -def combinations(*comb, **kw): + +def combinations(*comb: Union[Any, Tuple[Any, ...]], **kw: str): r"""Deliver multiple versions of a test based on positional combinations. This is a facade over pytest.mark.parametrize. @@ -89,25 +99,32 @@ def combinations(*comb, **kw): return _fixture_functions.combinations(*comb, **kw) -def combinations_list(arg_iterable, **kw): +def combinations_list( + arg_iterable: Iterable[ + Tuple[ + Any, + ] + ], + **kw, +): "As combination, but takes a single iterable" return combinations(*arg_iterable, **kw) -def fixture(*arg, **kw): +def fixture(*arg: Any, **kw: Any) -> Any: return _fixture_functions.fixture(*arg, **kw) -def get_current_test_name(): +def get_current_test_name() -> str: return _fixture_functions.get_current_test_name() -def mark_base_test_class(): +def mark_base_test_class() -> Any: return _fixture_functions.mark_base_test_class() class _AddToMarker: - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: return getattr(_fixture_functions.add_to_marker, attr) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index ecc20f163..7228e5afe 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -20,6 +20,7 @@ from .util import drop_all_tables_from_metadata from .. import event from .. import util from ..orm import declarative_base +from ..orm import DeclarativeBase from ..orm import registry from ..schema import sort_tables_and_constraints @@ -82,6 +83,21 @@ class TestBase: yield reg reg.dispose() + @config.fixture + def decl_base(self, metadata): + _md = metadata + + class Base(DeclarativeBase): + metadata = _md + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb" + ) + } + + yield Base + Base.registry.dispose() + @config.fixture() def future_connection(self, future_engine, connection): # integrate the future_engine and connection fixtures so diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 0b4451b3c..52e42bb97 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -19,6 +19,7 @@ import logging import os import re import sys +from typing import Any from sqlalchemy.testing import asyncio @@ -738,7 +739,7 @@ class FixtureFunctions(abc.ABC): raise NotImplementedError() @abc.abstractmethod - def mark_base_test_class(self): + def mark_base_test_class(self) -> Any: raise NotImplementedError() @abc.abstractproperty diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 410ab26ed..41e5d6772 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1326,6 +1326,18 @@ class SuiteRequirements(Requirements): return exclusions.only_if(check) @property + def no_sqlalchemy2_stubs(self): + def check(config): + try: + __import__("sqlalchemy-stubs.ext.mypy") + except ImportError: + return False + else: + return True + + return exclusions.skip_if(check) + + @property def python38(self): return exclusions.only_if( lambda: util.py38, "Python 3.8 or above required" diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 91d15aae0..85bbca20f 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -6,131 +6,135 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php -from collections import defaultdict -from functools import partial -from functools import update_wrapper +from collections import defaultdict as defaultdict +from functools import partial as partial +from functools import update_wrapper as update_wrapper -from ._collections import coerce_generator_arg -from ._collections import coerce_to_immutabledict -from ._collections import column_dict -from ._collections import column_set -from ._collections import EMPTY_DICT -from ._collections import EMPTY_SET -from ._collections import FacadeDict -from ._collections import flatten_iterator -from ._collections import has_dupes -from ._collections import has_intersection -from ._collections import IdentitySet -from ._collections import ImmutableContainer -from ._collections import immutabledict -from ._collections import ImmutableProperties -from ._collections import LRUCache -from ._collections import ordered_column_set -from ._collections import OrderedDict -from ._collections import OrderedIdentitySet -from ._collections import OrderedProperties -from ._collections import OrderedSet -from ._collections import PopulateDict -from ._collections import Properties -from ._collections import ScopedRegistry -from ._collections import sort_dictionary -from ._collections import ThreadLocalRegistry -from ._collections import to_column_set -from ._collections import to_list -from ._collections import to_set -from ._collections import unique_list -from ._collections import UniqueAppender -from ._collections import update_copy -from ._collections import WeakPopulateDict -from ._collections import WeakSequence -from ._preloaded import preload_module -from ._preloaded import preloaded -from .compat import arm -from .compat import b -from .compat import b64decode -from .compat import b64encode -from .compat import cmp -from .compat import cpython -from .compat import dataclass_fields -from .compat import decode_backslashreplace -from .compat import dottedgetter -from .compat import has_refcount_gc -from .compat import inspect_getfullargspec -from .compat import local_dataclass_fields -from .compat import next -from .compat import osx -from .compat import py38 -from .compat import py39 -from .compat import pypy -from .compat import win32 -from .concurrency import asyncio -from .concurrency import await_fallback -from .concurrency import await_only -from .concurrency import greenlet_spawn -from .concurrency import is_exit_exception -from .deprecations import became_legacy_20 -from .deprecations import deprecated -from .deprecations import deprecated_cls -from .deprecations import deprecated_params -from .deprecations import deprecated_property -from .deprecations import inject_docstring_text -from .deprecations import moved_20 -from .deprecations import warn_deprecated -from .langhelpers import add_parameter_text -from .langhelpers import as_interface -from .langhelpers import asbool -from .langhelpers import asint -from .langhelpers import assert_arg_type -from .langhelpers import attrsetter -from .langhelpers import bool_or_str -from .langhelpers import chop_traceback -from .langhelpers import class_hierarchy -from .langhelpers import classproperty -from .langhelpers import clsname_as_plain_name -from .langhelpers import coerce_kw_type -from .langhelpers import constructor_copy -from .langhelpers import constructor_key -from .langhelpers import counter -from .langhelpers import create_proxy_methods -from .langhelpers import decode_slice -from .langhelpers import decorator -from .langhelpers import dictlike_iteritems -from .langhelpers import duck_type_collection -from .langhelpers import ellipses_string -from .langhelpers import EnsureKWArg -from .langhelpers import format_argspec_init -from .langhelpers import format_argspec_plus -from .langhelpers import generic_repr -from .langhelpers import get_callable_argspec -from .langhelpers import get_cls_kwargs -from .langhelpers import get_func_kwargs -from .langhelpers import getargspec_init -from .langhelpers import has_compiled_ext -from .langhelpers import HasMemoized -from .langhelpers import hybridmethod -from .langhelpers import hybridproperty -from .langhelpers import iterate_attributes -from .langhelpers import map_bits -from .langhelpers import md5_hex -from .langhelpers import memoized_instancemethod -from .langhelpers import memoized_property -from .langhelpers import MemoizedSlots -from .langhelpers import method_is_overridden -from .langhelpers import methods_equivalent -from .langhelpers import monkeypatch_proxied_specials -from .langhelpers import NoneType -from .langhelpers import only_once -from .langhelpers import PluginLoader -from .langhelpers import portable_instancemethod -from .langhelpers import quoted_token_parser -from .langhelpers import safe_reraise -from .langhelpers import set_creation_order -from .langhelpers import string_or_unprintable -from .langhelpers import symbol -from .langhelpers import TypingOnly -from .langhelpers import unbound_method_to_callable -from .langhelpers import walk_subclasses -from .langhelpers import warn -from .langhelpers import warn_exception -from .langhelpers import warn_limited -from .langhelpers import wrap_callable +from ._collections import coerce_generator_arg as coerce_generator_arg +from ._collections import coerce_to_immutabledict as coerce_to_immutabledict +from ._collections import column_dict as column_dict +from ._collections import column_set as column_set +from ._collections import EMPTY_DICT as EMPTY_DICT +from ._collections import EMPTY_SET as EMPTY_SET +from ._collections import FacadeDict as FacadeDict +from ._collections import flatten_iterator as flatten_iterator +from ._collections import has_dupes as has_dupes +from ._collections import has_intersection as has_intersection +from ._collections import IdentitySet as IdentitySet +from ._collections import ImmutableContainer as ImmutableContainer +from ._collections import immutabledict as immutabledict +from ._collections import ImmutableProperties as ImmutableProperties +from ._collections import LRUCache as LRUCache +from ._collections import merge_lists_w_ordering as merge_lists_w_ordering +from ._collections import ordered_column_set as ordered_column_set +from ._collections import OrderedDict as OrderedDict +from ._collections import OrderedIdentitySet as OrderedIdentitySet +from ._collections import OrderedProperties as OrderedProperties +from ._collections import OrderedSet as OrderedSet +from ._collections import PopulateDict as PopulateDict +from ._collections import Properties as Properties +from ._collections import ScopedRegistry as ScopedRegistry +from ._collections import sort_dictionary as sort_dictionary +from ._collections import ThreadLocalRegistry as ThreadLocalRegistry +from ._collections import to_column_set as to_column_set +from ._collections import to_list as to_list +from ._collections import to_set as to_set +from ._collections import unique_list as unique_list +from ._collections import UniqueAppender as UniqueAppender +from ._collections import update_copy as update_copy +from ._collections import WeakPopulateDict as WeakPopulateDict +from ._collections import WeakSequence as WeakSequence +from ._preloaded import preload_module as preload_module +from ._preloaded import preloaded as preloaded +from .compat import arm as arm +from .compat import b as b +from .compat import b64decode as b64decode +from .compat import b64encode as b64encode +from .compat import cmp as cmp +from .compat import cpython as cpython +from .compat import dataclass_fields as dataclass_fields +from .compat import decode_backslashreplace as decode_backslashreplace +from .compat import dottedgetter as dottedgetter +from .compat import has_refcount_gc as has_refcount_gc +from .compat import inspect_getfullargspec as inspect_getfullargspec +from .compat import local_dataclass_fields as local_dataclass_fields +from .compat import osx as osx +from .compat import py38 as py38 +from .compat import py39 as py39 +from .compat import pypy as pypy +from .compat import win32 as win32 +from .concurrency import await_fallback as await_fallback +from .concurrency import await_only as await_only +from .concurrency import greenlet_spawn as greenlet_spawn +from .concurrency import is_exit_exception as is_exit_exception +from .deprecations import became_legacy_20 as became_legacy_20 +from .deprecations import deprecated as deprecated +from .deprecations import deprecated_cls as deprecated_cls +from .deprecations import deprecated_params as deprecated_params +from .deprecations import deprecated_property as deprecated_property +from .deprecations import moved_20 as moved_20 +from .deprecations import warn_deprecated as warn_deprecated +from .langhelpers import add_parameter_text as add_parameter_text +from .langhelpers import as_interface as as_interface +from .langhelpers import asbool as asbool +from .langhelpers import asint as asint +from .langhelpers import assert_arg_type as assert_arg_type +from .langhelpers import attrsetter as attrsetter +from .langhelpers import bool_or_str as bool_or_str +from .langhelpers import chop_traceback as chop_traceback +from .langhelpers import class_hierarchy as class_hierarchy +from .langhelpers import classproperty as classproperty +from .langhelpers import clsname_as_plain_name as clsname_as_plain_name +from .langhelpers import coerce_kw_type as coerce_kw_type +from .langhelpers import constructor_copy as constructor_copy +from .langhelpers import constructor_key as constructor_key +from .langhelpers import counter as counter +from .langhelpers import create_proxy_methods as create_proxy_methods +from .langhelpers import decode_slice as decode_slice +from .langhelpers import decorator as decorator +from .langhelpers import dictlike_iteritems as dictlike_iteritems +from .langhelpers import duck_type_collection as duck_type_collection +from .langhelpers import ellipses_string as ellipses_string +from .langhelpers import EnsureKWArg as EnsureKWArg +from .langhelpers import format_argspec_init as format_argspec_init +from .langhelpers import format_argspec_plus as format_argspec_plus +from .langhelpers import generic_repr as generic_repr +from .langhelpers import get_annotations as get_annotations +from .langhelpers import get_callable_argspec as get_callable_argspec +from .langhelpers import get_cls_kwargs as get_cls_kwargs +from .langhelpers import get_func_kwargs as get_func_kwargs +from .langhelpers import getargspec_init as getargspec_init +from .langhelpers import has_compiled_ext as has_compiled_ext +from .langhelpers import HasMemoized as HasMemoized +from .langhelpers import hybridmethod as hybridmethod +from .langhelpers import hybridproperty as hybridproperty +from .langhelpers import inject_docstring_text as inject_docstring_text +from .langhelpers import iterate_attributes as iterate_attributes +from .langhelpers import map_bits as map_bits +from .langhelpers import md5_hex as md5_hex +from .langhelpers import memoized_instancemethod as memoized_instancemethod +from .langhelpers import memoized_property as memoized_property +from .langhelpers import MemoizedSlots as MemoizedSlots +from .langhelpers import method_is_overridden as method_is_overridden +from .langhelpers import methods_equivalent as methods_equivalent +from .langhelpers import ( + monkeypatch_proxied_specials as monkeypatch_proxied_specials, +) +from .langhelpers import NoneType as NoneType +from .langhelpers import only_once as only_once +from .langhelpers import PluginLoader as PluginLoader +from .langhelpers import portable_instancemethod as portable_instancemethod +from .langhelpers import quoted_token_parser as quoted_token_parser +from .langhelpers import safe_reraise as safe_reraise +from .langhelpers import set_creation_order as set_creation_order +from .langhelpers import string_or_unprintable as string_or_unprintable +from .langhelpers import symbol as symbol +from .langhelpers import TypingOnly as TypingOnly +from .langhelpers import ( + unbound_method_to_callable as unbound_method_to_callable, +) +from .langhelpers import walk_subclasses as walk_subclasses +from .langhelpers import warn as warn +from .langhelpers import warn_exception as warn_exception +from .langhelpers import warn_limited as warn_limited +from .langhelpers import wrap_callable as wrap_callable diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 3e4ef1310..850986802 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -34,19 +34,27 @@ from ._has_cy import HAS_CYEXTENSION from .typing import Literal if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_collections import immutabledict - from ._py_collections import IdentitySet - from ._py_collections import ImmutableContainer - from ._py_collections import ImmutableDictBase - from ._py_collections import OrderedSet - from ._py_collections import unique_list # noqa + from ._py_collections import immutabledict as immutabledict + from ._py_collections import IdentitySet as IdentitySet + from ._py_collections import ImmutableContainer as ImmutableContainer + from ._py_collections import ImmutableDictBase as ImmutableDictBase + from ._py_collections import OrderedSet as OrderedSet + from ._py_collections import unique_list as unique_list else: - from sqlalchemy.cyextension.immutabledict import ImmutableContainer - from sqlalchemy.cyextension.immutabledict import ImmutableDictBase - from sqlalchemy.cyextension.immutabledict import immutabledict - from sqlalchemy.cyextension.collections import IdentitySet - from sqlalchemy.cyextension.collections import OrderedSet - from sqlalchemy.cyextension.collections import unique_list # noqa + from sqlalchemy.cyextension.immutabledict import ( + ImmutableContainer as ImmutableContainer, + ) + from sqlalchemy.cyextension.immutabledict import ( + ImmutableDictBase as ImmutableDictBase, + ) + from sqlalchemy.cyextension.immutabledict import ( + immutabledict as immutabledict, + ) + from sqlalchemy.cyextension.collections import IdentitySet as IdentitySet + from sqlalchemy.cyextension.collections import OrderedSet as OrderedSet + from sqlalchemy.cyextension.collections import ( # noqa + unique_list as unique_list, + ) _T = TypeVar("_T", bound=Any) @@ -57,6 +65,62 @@ _VT = TypeVar("_VT", bound=Any) EMPTY_SET: FrozenSet[Any] = frozenset() +def merge_lists_w_ordering(a, b): + """merge two lists, maintaining ordering as much as possible. + + this is to reconcile vars(cls) with cls.__annotations__. + + Example:: + + >>> a = ['__tablename__', 'id', 'x', 'created_at'] + >>> b = ['id', 'name', 'data', 'y', 'created_at'] + >>> merge_lists_w_ordering(a, b) + ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at'] + + This is not necessarily the ordering that things had on the class, + in this case the class is:: + + class User(Base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + data: Mapped[Optional[str]] + x = Column(Integer) + y: Mapped[int] + created_at: Mapped[datetime.datetime] = mapped_column() + + But things are *mostly* ordered. + + The algorithm could also be done by creating a partial ordering for + all items in both lists and then using topological_sort(), but that + is too much overhead. + + Background on how I came up with this is at: + https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae + + """ + overlap = set(a).intersection(b) + + result = [] + + current, other = iter(a), iter(b) + + while True: + for element in current: + if element in overlap: + overlap.discard(element) + other, current = current, other + break + + result.append(element) + else: + result.extend(other) + break + + return result + + def coerce_to_immutabledict(d): if not d: return EMPTY_DICT diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 0f4befbb1..62cffa556 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -39,7 +39,6 @@ arm = "aarch" in platform.machine().lower() has_refcount_gc = bool(cpython) dottedgetter = operator.attrgetter -next = next # noqa class FullArgSpec(typing.NamedTuple): diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py index 57ef23006..6b94a2294 100644 --- a/lib/sqlalchemy/util/concurrency.py +++ b/lib/sqlalchemy/util/concurrency.py @@ -16,15 +16,17 @@ except ImportError as e: pass else: have_greenlet = True - from ._concurrency_py3k import await_only - from ._concurrency_py3k import await_fallback - from ._concurrency_py3k import greenlet_spawn - from ._concurrency_py3k import is_exit_exception - from ._concurrency_py3k import AsyncAdaptedLock - from ._concurrency_py3k import _util_async_run # noqa F401 + from ._concurrency_py3k import await_only as await_only + from ._concurrency_py3k import await_fallback as await_fallback + from ._concurrency_py3k import greenlet_spawn as greenlet_spawn + from ._concurrency_py3k import is_exit_exception as is_exit_exception + from ._concurrency_py3k import AsyncAdaptedLock as AsyncAdaptedLock from ._concurrency_py3k import ( - _util_async_run_coroutine_function, - ) # noqa F401, E501 + _util_async_run as _util_async_run, + ) # noqa F401 + from ._concurrency_py3k import ( + _util_async_run_coroutine_function as _util_async_run_coroutine_function, # noqa F401, E501 + ) if not have_greenlet: diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index 565cbafe2..7c2586166 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -13,6 +13,7 @@ from typing import Any from typing import Callable from typing import cast from typing import Optional +from typing import Tuple from typing import TypeVar from . import compat @@ -209,7 +210,10 @@ def became_legacy_20(api_name, alternative=None, **kw): return deprecated("2.0", message=message, warning=warning_cls, **kw) -def deprecated_params(**specs): +_C = TypeVar("_C", bound=Callable[..., Any]) + + +def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]: """Decorates a function to warn on use of certain parameters. e.g. :: diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 9401c249f..ed879894d 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -30,6 +30,7 @@ from typing import FrozenSet from typing import Generic from typing import Iterator from typing import List +from typing import Mapping from typing import Optional from typing import overload from typing import Sequence @@ -54,6 +55,30 @@ _HP = TypeVar("_HP", bound="hybridproperty") _HM = TypeVar("_HM", bound="hybridmethod") +if compat.py310: + + def get_annotations(obj: Any) -> Mapping[str, Any]: + return inspect.get_annotations(obj) + +else: + + def get_annotations(obj: Any) -> Mapping[str, Any]: + # it's been observed that cls.__annotations__ can be non present. + # it's not clear what causes this, running under tox py37/38 it + # happens, running straight pytest it doesnt + + # https://docs.python.org/3/howto/annotations.html#annotations-howto + if isinstance(obj, type): + ann = obj.__dict__.get("__annotations__", None) + else: + ann = getattr(obj, "__annotations__", None) + + if ann is None: + return _collections.EMPTY_DICT + else: + return cast("Mapping[str, Any]", ann) + + def md5_hex(x: Any) -> str: x = x.encode("utf-8") m = hashlib.md5() diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 62a9f6c8a..56ea4d0e0 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -1,6 +1,10 @@ +import sys import typing from typing import Any from typing import Callable # noqa +from typing import cast +from typing import Dict +from typing import ForwardRef from typing import Generic from typing import overload from typing import Type @@ -13,21 +17,36 @@ from . import compat _T = TypeVar("_T", bound=Any) -if typing.TYPE_CHECKING or not compat.py38: - from typing_extensions import Literal # noqa F401 - from typing_extensions import Protocol # noqa F401 - from typing_extensions import TypedDict # noqa F401 +if compat.py310: + # why they took until py310 to put this in stdlib is beyond me, + # I've been wanting it since py27 + from types import NoneType else: - from typing import Literal # noqa F401 - from typing import Protocol # noqa F401 - from typing import TypedDict # noqa F401 + NoneType = type(None) # type: ignore + +if typing.TYPE_CHECKING or compat.py310: + from typing import Annotated as Annotated +else: + from typing_extensions import Annotated as Annotated # noqa F401 + +if typing.TYPE_CHECKING or compat.py38: + from typing import Literal as Literal + from typing import Protocol as Protocol + from typing import TypedDict as TypedDict +else: + from typing_extensions import Literal as Literal # noqa F401 + from typing_extensions import Protocol as Protocol # noqa F401 + from typing_extensions import TypedDict as TypedDict # noqa F401 + +# work around https://github.com/microsoft/pyright/issues/3025 +_LiteralStar = Literal["*"] if typing.TYPE_CHECKING or not compat.py310: - from typing_extensions import Concatenate # noqa F401 - from typing_extensions import ParamSpec # noqa F401 + from typing_extensions import Concatenate as Concatenate + from typing_extensions import ParamSpec as ParamSpec else: - from typing import Concatenate # noqa F401 - from typing import ParamSpec # noqa F401 + from typing import Concatenate as Concatenate # noqa F401 + from typing import ParamSpec as ParamSpec # noqa F401 class _TypeToInstance(Generic[_T]): @@ -76,3 +95,121 @@ class ReadOnlyInstanceDescriptor(Protocol[_T]): self, instance: object, owner: Any ) -> Union["ReadOnlyInstanceDescriptor[_T]", _T]: ... + + +def de_stringify_annotation( + cls: Type[Any], annotation: Union[str, Type[Any]] +) -> Union[str, Type[Any]]: + """Resolve annotations that may be string based into real objects. + + This is particularly important if a module defines "from __future__ import + annotations", as everything inside of __annotations__ is a string. We want + to at least have generic containers like ``Mapped``, ``Union``, ``List``, + etc. + + """ + + # looked at typing.get_type_hints(), looked at pydantic. We need much + # less here, and we here try to not use any private typing internals + # or construct ForwardRef objects which is documented as something + # that should be avoided. + + if ( + is_fwd_ref(annotation) + and not cast(ForwardRef, annotation).__forward_evaluated__ + ): + annotation = cast(ForwardRef, annotation).__forward_arg__ + + if isinstance(annotation, str): + base_globals: "Dict[str, Any]" = getattr( + sys.modules.get(cls.__module__, None), "__dict__", {} + ) + try: + annotation = eval(annotation, base_globals, None) + except NameError: + pass + return annotation + + +def is_fwd_ref(type_): + return isinstance(type_, ForwardRef) + + +def de_optionalize_union_types(type_): + """Given a type, filter out ``Union`` types that include ``NoneType`` + to not include the ``NoneType``. + + """ + if is_optional(type_): + typ = set(type_.__args__) + + typ.discard(NoneType) + + return make_union_type(*typ) + + else: + return type_ + + +def make_union_type(*types): + """Make a Union type. + + This is needed by :func:`.de_optionalize_union_types` which removes + ``NoneType`` from a ``Union``. + + """ + return cast(Any, Union).__getitem__(types) + + +def expand_unions(type_, include_union=False, discard_none=False): + """Return a type as as a tuple of individual types, expanding for + ``Union`` types.""" + + if is_union(type_): + typ = set(type_.__args__) + + if discard_none: + typ.discard(NoneType) + + if include_union: + return (type_,) + tuple(typ) + else: + return tuple(typ) + else: + return (type_,) + + +def is_optional(type_): + return is_origin_of( + type_, + "Optional", + "Union", + ) + + +def is_union(type_): + return is_origin_of(type_, "Union") + + +def is_origin_of(type_, *names, module=None): + """return True if the given type has an __origin__ with the given name + and optional module.""" + + origin = getattr(type_, "__origin__", None) + if origin is None: + return False + + return _get_type_name(origin) in names and ( + module is None or origin.__module__.startswith(module) + ) + + +def _get_type_name(type_): + if compat.py310: + return type_.__name__ + else: + typ_name = getattr(type_, "__name__", None) + if typ_name is None: + typ_name = getattr(type_, "_name", None) + + return typ_name diff --git a/mypy_plugin.ini b/mypy_plugin.ini new file mode 100644 index 000000000..34ddc371c --- /dev/null +++ b/mypy_plugin.ini @@ -0,0 +1,9 @@ +[mypy] +plugins = sqlalchemy.ext.mypy.plugin +show_error_codes = True +mypy_path=./lib/ +strict = True +raise_exceptions=True + +[mypy-sqlalchemy.*] +ignore_errors = True diff --git a/pyproject.toml b/pyproject.toml index 3af6ea089..be5dd1596 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,10 +65,9 @@ warn_unused_ignores = false strict = true -# https://github.com/python/mypy/issues/8754 -# we are a pep-561 package, so implicit-rexport should be -# enabled -implicit_reexport = true +# some debate at +# https://github.com/python/mypy/issues/8754. +# implicit_reexport = true # individual packages or even modules should be listed here # with strictness-specificity set up. there's no way we are going to get @@ -79,7 +78,6 @@ implicit_reexport = true [[tool.mypy.overrides]] module = [ "sqlalchemy.events", - "sqlalchemy.events", "sqlalchemy.exc", "sqlalchemy.inspection", "sqlalchemy.schema", @@ -109,6 +109,8 @@ import-order-style = google application-import-names = sqlalchemy,test per-file-ignores = **/__init__.py:F401 + test/ext/mypy/plain_files/*:F821,E501 + test/ext/mypy/plugin_files/*:F821,E501 lib/sqlalchemy/events.py:F401 lib/sqlalchemy/schema.py:F401 lib/sqlalchemy/types.py:F401 diff --git a/test/base/test_concurrency_py3k.py b/test/base/test_concurrency_py3k.py index 3c89108ee..79601019e 100644 --- a/test/base/test_concurrency_py3k.py +++ b/test/base/test_concurrency_py3k.py @@ -1,3 +1,4 @@ +import asyncio import threading from sqlalchemy import exc @@ -7,7 +8,6 @@ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_true -from sqlalchemy.util import asyncio from sqlalchemy.util import await_fallback from sqlalchemy.util import await_only from sqlalchemy.util import greenlet_spawn diff --git a/test/base/test_utils.py b/test/base/test_utils.py index dc02c37cb..67fcc8870 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -31,6 +31,7 @@ from sqlalchemy.util import compat from sqlalchemy.util import get_callable_argspec from sqlalchemy.util import langhelpers from sqlalchemy.util import WeakSequence +from sqlalchemy.util._collections import merge_lists_w_ordering class WeakSequenceTest(fixtures.TestBase): @@ -66,6 +67,49 @@ class WeakSequenceTest(fixtures.TestBase): eq_(len(w._storage), 2) +class MergeListsWOrderingTest(fixtures.TestBase): + @testing.combinations( + ( + ["__tablename__", "id", "x", "created_at"], + ["id", "name", "data", "y", "created_at"], + ["__tablename__", "id", "name", "data", "y", "x", "created_at"], + ), + (["a", "b", "c", "d", "e", "f"], [], ["a", "b", "c", "d", "e", "f"]), + ([], ["a", "b", "c", "d", "e", "f"], ["a", "b", "c", "d", "e", "f"]), + ([], [], []), + (["a", "b", "c"], ["a", "b", "c"], ["a", "b", "c"]), + ( + ["a", "b", "c"], + ["a", "b", "c", "d", "e"], + ["a", "b", "c", "d", "e"], + ), + (["a", "b", "c", "d"], ["c", "d", "e"], ["a", "b", "c", "d", "e"]), + ( + ["a", "c", "e", "g"], + ["b", "d", "f", "g"], + ["a", "c", "e", "b", "d", "f", "g"], # no overlaps until "g" + ), + ( + ["a", "b", "e", "f", "g"], + ["b", "c", "d", "e"], + ["a", "b", "c", "d", "e", "f", "g"], + ), + ( + ["a", "b", "c", "e", "f"], + ["c", "d", "f", "g"], + ["a", "b", "c", "d", "e", "f", "g"], + ), + ( + ["c", "d", "f", "g"], + ["a", "b", "c", "e", "f"], + ["a", "b", "c", "e", "d", "f", "g"], + ), + argnames="a,b,expected", + ) + def test_merge_lists(self, a, b, expected): + eq_(merge_lists_w_ordering(a, b), expected) + + class OrderedDictTest(fixtures.TestBase): def test_odict(self): o = util.OrderedDict() diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 5f33fa46d..613fc80a5 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -29,7 +29,6 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import TypeDecorator -from sqlalchemy import util from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.dialects.postgresql import HSTORE from sqlalchemy.dialects.postgresql import JSONB @@ -615,7 +614,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): "id", Integer, primary_key=True, - default=lambda: util.next(counter), + default=lambda: next(counter), ), Column("data", Integer), ) diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py index 64c32c76b..a695aadba 100644 --- a/test/ext/declarative/test_inheritance.py +++ b/test/ext/declarative/test_inheritance.py @@ -547,9 +547,9 @@ class ConcreteInhTest( configure_mappers() self.assert_compile( select(Employee), - "SELECT pjoin.name, pjoin.employee_id, pjoin.type, pjoin._type " - "FROM (SELECT manager.name AS name, manager.employee_id AS " - "employee_id, manager.type AS type, 'manager' AS _type " + "SELECT pjoin.employee_id, pjoin.type, pjoin.name, pjoin._type " + "FROM (SELECT manager.employee_id AS employee_id, " + "manager.type AS type, manager.name AS name, 'manager' AS _type " "FROM manager) AS pjoin", ) @@ -859,13 +859,13 @@ class ConcreteExtensionConfigTest( session = Session() self.assert_compile( session.query(Document), - "SELECT pjoin.doctype AS pjoin_doctype, " - "pjoin.send_method AS pjoin_send_method, " - "pjoin.id AS pjoin_id, pjoin.type AS pjoin_type " - "FROM (SELECT actual_documents.doctype AS doctype, " + "SELECT pjoin.id AS pjoin_id, pjoin.send_method AS " + "pjoin_send_method, pjoin.doctype AS pjoin_doctype, " + "pjoin.type AS pjoin_type FROM " + "(SELECT actual_documents.id AS id, " "actual_documents.send_method AS send_method, " - "actual_documents.id AS id, 'actual' AS type " - "FROM actual_documents) AS pjoin", + "actual_documents.doctype AS doctype, " + "'actual' AS type FROM actual_documents) AS pjoin", ) def test_column_attr_names(self): @@ -886,14 +886,14 @@ class ConcreteExtensionConfigTest( session.query(Document), "SELECT pjoin.documenttype AS pjoin_documenttype, " "pjoin.id AS pjoin_id, pjoin.type AS pjoin_type FROM " - "(SELECT offers.documenttype AS documenttype, offers.id AS id, " + "(SELECT offers.id AS id, offers.documenttype AS documenttype, " "'offer' AS type FROM offers) AS pjoin", ) self.assert_compile( session.query(Document.documentType), "SELECT pjoin.documenttype AS pjoin_documenttype FROM " - "(SELECT offers.documenttype AS documenttype, offers.id AS id, " + "(SELECT offers.id AS id, offers.documenttype AS documenttype, " "'offer' AS type FROM offers) AS pjoin", ) diff --git a/test/ext/mypy/files/inspect.py b/test/ext/mypy/inspection_inspect.py index c67b515f4..c67b515f4 100644 --- a/test/ext/mypy/files/inspect.py +++ b/test/ext/mypy/inspection_inspect.py diff --git a/test/ext/mypy/plain_files/engine_inspection.py b/test/ext/mypy/plain_files/engine_inspection.py new file mode 100644 index 000000000..1a1649e4e --- /dev/null +++ b/test/ext/mypy/plain_files/engine_inspection.py @@ -0,0 +1,24 @@ +import typing + +from sqlalchemy import create_engine +from sqlalchemy import inspect + + +e = create_engine("sqlite://") + +insp = inspect(e) + +cols = insp.get_columns("some_table") + +c1 = cols[0] + +if typing.TYPE_CHECKING: + + # EXPECTED_TYPE: sqlalchemy.engine.base.Engine + reveal_type(e) + + # EXPECTED_TYPE: sqlalchemy.engine.reflection.Inspector.* + reveal_type(insp) + + # EXPECTED_TYPE: .*list.*TypedDict.*ReflectedColumn.* + reveal_type(cols) diff --git a/test/ext/mypy/plain_files/experimental_relationship.py b/test/ext/mypy/plain_files/experimental_relationship.py new file mode 100644 index 000000000..e97a9598b --- /dev/null +++ b/test/ext/mypy/plain_files/experimental_relationship.py @@ -0,0 +1,69 @@ +"""this suite experiments with other kinds of relationship syntaxes. + +""" +import typing +from typing import List +from typing import Optional +from typing import Set + +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + # this currently doesnt generate an error. not sure how to get the + # overloads to hit this one, nor am i sure i really want to do that + # anyway + name_this_works_atm: Mapped[str] = mapped_column(nullable=True) + + extra: Mapped[Optional[str]] = mapped_column() + extra_name: Mapped[Optional[str]] = mapped_column("extra_name") + + addresses_style_one: Mapped[List["Address"]] = relationship() + addresses_style_two: Mapped[Set["Address"]] = relationship() + + +class Address(Base): + __tablename__ = "address" + + id = mapped_column(Integer, primary_key=True) + user_id = mapped_column(ForeignKey("user.id")) + email = mapped_column(String, nullable=False) + email_name = mapped_column("email_name", String, nullable=False) + + user_style_one: Mapped[User] = relationship() + user_style_two: Mapped["User"] = relationship() + + +if typing.TYPE_CHECKING: + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\] + reveal_type(User.extra) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\] + reveal_type(User.extra_name) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*\] + reveal_type(Address.email) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*\] + reveal_type(Address.email_name) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[experimental_relationship.Address\]\] + reveal_type(User.addresses_style_one) + + # EXPECTED_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*\[experimental_relationship.Address\]\] + reveal_type(User.addresses_style_two) diff --git a/test/ext/mypy/plain_files/mapped_column.py b/test/ext/mypy/plain_files/mapped_column.py new file mode 100644 index 000000000..b20beeb3a --- /dev/null +++ b/test/ext/mypy/plain_files/mapped_column.py @@ -0,0 +1,92 @@ +from typing import Optional + +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class X(Base): + __tablename__ = "x" + + id: Mapped[int] = mapped_column(primary_key=True) + int_id: Mapped[int] = mapped_column(Integer, primary_key=True) + + # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + err_int_id: Mapped[Optional[int]] = mapped_column( + Integer, primary_key=True + ) + + id_name: Mapped[int] = mapped_column("id_name", primary_key=True) + int_id_name: Mapped[int] = mapped_column( + "int_id_name", Integer, primary_key=True + ) + + # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + err_int_id_name: Mapped[Optional[int]] = mapped_column( + "err_int_id_name", Integer, primary_key=True + ) + + # note we arent getting into primary_key=True / nullable=True here. + # leaving that as undefined for now + + a: Mapped[str] = mapped_column() + b: Mapped[Optional[str]] = mapped_column() + + # can't detect error because no SQL type is present + c: Mapped[str] = mapped_column(nullable=True) + d: Mapped[str] = mapped_column(nullable=False) + + e: Mapped[Optional[str]] = mapped_column(nullable=True) + + # can't detect error because no SQL type is present + f: Mapped[Optional[str]] = mapped_column(nullable=False) + + g: Mapped[str] = mapped_column(String) + h: Mapped[Optional[str]] = mapped_column(String) + + # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + i: Mapped[str] = mapped_column(String, nullable=True) + + j: Mapped[str] = mapped_column(String, nullable=False) + + k: Mapped[Optional[str]] = mapped_column(String, nullable=True) + + # EXPECTED_MYPY_RE: Argument \d to "mapped_column" has incompatible type + l: Mapped[Optional[str]] = mapped_column(String, nullable=False) + + a_name: Mapped[str] = mapped_column("a_name") + b_name: Mapped[Optional[str]] = mapped_column("b_name") + + # can't detect error because no SQL type is present + c_name: Mapped[str] = mapped_column("c_name", nullable=True) + d_name: Mapped[str] = mapped_column("d_name", nullable=False) + + e_name: Mapped[Optional[str]] = mapped_column("e_name", nullable=True) + + # can't detect error because no SQL type is present + f_name: Mapped[Optional[str]] = mapped_column("f_name", nullable=False) + + g_name: Mapped[str] = mapped_column("g_name", String) + h_name: Mapped[Optional[str]] = mapped_column("h_name", String) + + # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + i_name: Mapped[str] = mapped_column("i_name", String, nullable=True) + + j_name: Mapped[str] = mapped_column("j_name", String, nullable=False) + + k_name: Mapped[Optional[str]] = mapped_column( + "k_name", String, nullable=True + ) + + l_name: Mapped[Optional[str]] = mapped_column( + "l_name", + # EXPECTED_MYPY_RE: Argument \d to "mapped_column" has incompatible type + String, + nullable=False, + ) diff --git a/test/ext/mypy/plain_files/trad_relationship_uselist.py b/test/ext/mypy/plain_files/trad_relationship_uselist.py new file mode 100644 index 000000000..a372fe2d1 --- /dev/null +++ b/test/ext/mypy/plain_files/trad_relationship_uselist.py @@ -0,0 +1,133 @@ +"""traditional relationship patterns with explicit uselist. + + +""" +import typing +from typing import cast +from typing import Dict +from typing import List +from typing import Set +from typing import Type + +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship +from sqlalchemy.orm.collections import attribute_mapped_collection + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "user" + + id = mapped_column(Integer, primary_key=True) + name = mapped_column(String, nullable=False) + + addresses_style_one: Mapped[List["Address"]] = relationship( + "Address", uselist=True + ) + + addresses_style_two: Mapped[Set["Address"]] = relationship( + "Address", collection_class=set + ) + + addresses_style_three = relationship("Address", collection_class=set) + + addresses_style_three_cast = relationship( + cast(Type["Address"], "Address"), collection_class=set + ) + + addresses_style_four = relationship("Address", collection_class=list) + + +class Address(Base): + __tablename__ = "address" + + id = mapped_column(Integer, primary_key=True) + user_id = mapped_column(ForeignKey("user.id")) + email = mapped_column(String, nullable=False) + + user_style_one = relationship(User, uselist=False) + + user_style_one_typed: Mapped[User] = relationship(User, uselist=False) + + user_style_two = relationship("User", uselist=False) + + user_style_two_typed: Mapped["User"] = relationship("User", uselist=False) + + # these is obviously not correct relationally but want to see the typing + # work out with a real class passed as the argument + user_style_three: Mapped[List[User]] = relationship(User, uselist=True) + + user_style_four: Mapped[List[User]] = relationship("User", uselist=True) + + user_style_five: Mapped[List[User]] = relationship(User, uselist=True) + + user_style_six: Mapped[Set[User]] = relationship( + User, uselist=True, collection_class=set + ) + + user_style_seven = relationship(User, uselist=True, collection_class=set) + + user_style_eight = relationship(User, uselist=True, collection_class=list) + + user_style_nine = relationship(User, uselist=True) + + user_style_ten = relationship( + User, collection_class=attribute_mapped_collection("name") + ) + + user_style_ten_typed: Mapped[Dict[str, User]] = relationship( + User, collection_class=attribute_mapped_collection("name") + ) + + +if typing.TYPE_CHECKING: + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.Address\]\] + reveal_type(User.addresses_style_one) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[trad_relationship_uselist.Address\]\] + reveal_type(User.addresses_style_two) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[Any\]\] + reveal_type(User.addresses_style_three) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[trad_relationship_uselist.Address\]\] + reveal_type(User.addresses_style_three_cast) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[Any\]\] + reveal_type(User.addresses_style_four) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\] + reveal_type(Address.user_style_one) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\] + reveal_type(Address.user_style_one_typed) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + reveal_type(Address.user_style_two) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\] + reveal_type(Address.user_style_two_typed) + + # reveal_type(Address.user_style_six) + + # reveal_type(Address.user_style_seven) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.User\]\] + reveal_type(Address.user_style_eight) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.User\]\] + reveal_type(Address.user_style_nine) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + reveal_type(Address.user_style_ten) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*\[builtins.str, trad_relationship_uselist.User\]\] + reveal_type(Address.user_style_ten_typed) diff --git a/test/ext/mypy/plain_files/traditional_relationship.py b/test/ext/mypy/plain_files/traditional_relationship.py new file mode 100644 index 000000000..473ccb282 --- /dev/null +++ b/test/ext/mypy/plain_files/traditional_relationship.py @@ -0,0 +1,88 @@ +"""Here we illustrate 'traditional' relationship that looks as much like +1.x SQLAlchemy as possible. We want to illustrate that users can apply +Mapped[...] on the left hand side and that this will work in all cases. +This requires that the return type of relationship is based on Any, +if no uselists are present. + +""" +import typing +from typing import List +from typing import Set + +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "user" + + id = mapped_column(Integer, primary_key=True) + name = mapped_column(String, nullable=False) + + addresses_style_one: Mapped[List["Address"]] = relationship("Address") + + addresses_style_two: Mapped[Set["Address"]] = relationship( + "Address", collection_class=set + ) + + +class Address(Base): + __tablename__ = "address" + + id = mapped_column(Integer, primary_key=True) + user_id = mapped_column(ForeignKey("user.id")) + email = mapped_column(String, nullable=False) + + user_style_one = relationship(User) + + user_style_one_typed: Mapped[User] = relationship(User) + + user_style_two = relationship("User") + + user_style_two_typed: Mapped["User"] = relationship("User") + + # this is obviously not correct relationally but want to see the typing + # work out + user_style_three: Mapped[List[User]] = relationship(User) + + user_style_four: Mapped[List[User]] = relationship("User") + + user_style_five = relationship(User, collection_class=set) + + +if typing.TYPE_CHECKING: + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.Address\]\] + reveal_type(User.addresses_style_one) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[traditional_relationship.Address\]\] + reveal_type(User.addresses_style_two) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\] + reveal_type(Address.user_style_one) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\] + reveal_type(Address.user_style_one_typed) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + reveal_type(Address.user_style_two) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\] + reveal_type(Address.user_style_two_typed) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.User\]\] + reveal_type(Address.user_style_three) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.User\]\] + reveal_type(Address.user_style_four) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[traditional_relationship.User\]\] + reveal_type(Address.user_style_five) diff --git a/test/ext/mypy/files/abstract_one.py b/test/ext/mypy/plugin_files/abstract_one.py index d11631d75..d11631d75 100644 --- a/test/ext/mypy/files/abstract_one.py +++ b/test/ext/mypy/plugin_files/abstract_one.py diff --git a/test/ext/mypy/files/as_declarative.py b/test/ext/mypy/plugin_files/as_declarative.py index 08f08f913..08f08f913 100644 --- a/test/ext/mypy/files/as_declarative.py +++ b/test/ext/mypy/plugin_files/as_declarative.py diff --git a/test/ext/mypy/files/as_declarative_base.py b/test/ext/mypy/plugin_files/as_declarative_base.py index ba62e7276..ba62e7276 100644 --- a/test/ext/mypy/files/as_declarative_base.py +++ b/test/ext/mypy/plugin_files/as_declarative_base.py diff --git a/test/ext/mypy/files/boolean_col.py b/test/ext/mypy/plugin_files/boolean_col.py index 3e361ad10..3e361ad10 100644 --- a/test/ext/mypy/files/boolean_col.py +++ b/test/ext/mypy/plugin_files/boolean_col.py diff --git a/test/ext/mypy/files/cols_noninferred_plain_nonopt.py b/test/ext/mypy/plugin_files/cols_noninferred_plain_nonopt.py index a2825e003..a2825e003 100644 --- a/test/ext/mypy/files/cols_noninferred_plain_nonopt.py +++ b/test/ext/mypy/plugin_files/cols_noninferred_plain_nonopt.py diff --git a/test/ext/mypy/files/cols_notype_on_fk_col.py b/test/ext/mypy/plugin_files/cols_notype_on_fk_col.py index 3195714ae..3195714ae 100644 --- a/test/ext/mypy/files/cols_notype_on_fk_col.py +++ b/test/ext/mypy/plugin_files/cols_notype_on_fk_col.py diff --git a/test/ext/mypy/files/complete_orm_no_plugin.py b/test/ext/mypy/plugin_files/complete_orm_no_plugin.py index 53291501a..53291501a 100644 --- a/test/ext/mypy/files/complete_orm_no_plugin.py +++ b/test/ext/mypy/plugin_files/complete_orm_no_plugin.py diff --git a/test/ext/mypy/files/composite_props.py b/test/ext/mypy/plugin_files/composite_props.py index f92b93c57..d717ca048 100644 --- a/test/ext/mypy/files/composite_props.py +++ b/test/ext/mypy/plugin_files/composite_props.py @@ -52,7 +52,7 @@ v1 = Vertex(start=Point(3, 4), end=Point(5, 6)) # I'm not even sure composites support this but it should work from a # typing perspective -stmt = select(v1).where(Vertex.start.in_([Point(3, 4)])) +stmt = select(Vertex).where(Vertex.start.in_([Point(3, 4)])) p1: Point = v1.start p2: Point = v1.end diff --git a/test/ext/mypy/files/constr_cols_only.py b/test/ext/mypy/plugin_files/constr_cols_only.py index cd4da5586..cd4da5586 100644 --- a/test/ext/mypy/files/constr_cols_only.py +++ b/test/ext/mypy/plugin_files/constr_cols_only.py diff --git a/test/ext/mypy/files/dataclasses_workaround.py b/test/ext/mypy/plugin_files/dataclasses_workaround.py index 56c61b333..9928b5a33 100644 --- a/test/ext/mypy/files/dataclasses_workaround.py +++ b/test/ext/mypy/plugin_files/dataclasses_workaround.py @@ -4,6 +4,8 @@ from __future__ import annotations from dataclasses import dataclass from dataclasses import field +from typing import Any +from typing import Dict from typing import List from typing import Optional from typing import TYPE_CHECKING @@ -40,7 +42,7 @@ class User: if TYPE_CHECKING: _mypy_mapped_attrs = [id, name, fullname, nickname, addresses] - __mapper_args__ = { # type: ignore + __mapper_args__: Dict[str, Any] = { "properties": {"addresses": relationship("Address")} } diff --git a/test/ext/mypy/files/decl_attrs_one.py b/test/ext/mypy/plugin_files/decl_attrs_one.py index 1f2261cfc..1f2261cfc 100644 --- a/test/ext/mypy/files/decl_attrs_one.py +++ b/test/ext/mypy/plugin_files/decl_attrs_one.py diff --git a/test/ext/mypy/files/decl_attrs_two.py b/test/ext/mypy/plugin_files/decl_attrs_two.py index a20af490d..a20af490d 100644 --- a/test/ext/mypy/files/decl_attrs_two.py +++ b/test/ext/mypy/plugin_files/decl_attrs_two.py diff --git a/test/ext/mypy/files/decl_base_subclass_one.py b/test/ext/mypy/plugin_files/decl_base_subclass_one.py index abe28a495..abe28a495 100644 --- a/test/ext/mypy/files/decl_base_subclass_one.py +++ b/test/ext/mypy/plugin_files/decl_base_subclass_one.py diff --git a/test/ext/mypy/files/decl_base_subclass_two.py b/test/ext/mypy/plugin_files/decl_base_subclass_two.py index 78b7a9b63..78b7a9b63 100644 --- a/test/ext/mypy/files/decl_base_subclass_two.py +++ b/test/ext/mypy/plugin_files/decl_base_subclass_two.py diff --git a/test/ext/mypy/files/declarative_base_dynamic.py b/test/ext/mypy/plugin_files/declarative_base_dynamic.py index eee9b3110..eee9b3110 100644 --- a/test/ext/mypy/files/declarative_base_dynamic.py +++ b/test/ext/mypy/plugin_files/declarative_base_dynamic.py diff --git a/test/ext/mypy/files/declarative_base_explicit.py b/test/ext/mypy/plugin_files/declarative_base_explicit.py index b1b02bfb8..b1b02bfb8 100644 --- a/test/ext/mypy/files/declarative_base_explicit.py +++ b/test/ext/mypy/plugin_files/declarative_base_explicit.py diff --git a/test/ext/mypy/files/ensure_descriptor_type_fully_inferred.py b/test/ext/mypy/plugin_files/ensure_descriptor_type_fully_inferred.py index 1a8904147..1a8904147 100644 --- a/test/ext/mypy/files/ensure_descriptor_type_fully_inferred.py +++ b/test/ext/mypy/plugin_files/ensure_descriptor_type_fully_inferred.py diff --git a/test/ext/mypy/files/ensure_descriptor_type_noninferred.py b/test/ext/mypy/plugin_files/ensure_descriptor_type_noninferred.py index b1dabe8dc..b1dabe8dc 100644 --- a/test/ext/mypy/files/ensure_descriptor_type_noninferred.py +++ b/test/ext/mypy/plugin_files/ensure_descriptor_type_noninferred.py diff --git a/test/ext/mypy/files/ensure_descriptor_type_semiinferred.py b/test/ext/mypy/plugin_files/ensure_descriptor_type_semiinferred.py index 2154ff074..2154ff074 100644 --- a/test/ext/mypy/files/ensure_descriptor_type_semiinferred.py +++ b/test/ext/mypy/plugin_files/ensure_descriptor_type_semiinferred.py diff --git a/test/ext/mypy/files/enum_col.py b/test/ext/mypy/plugin_files/enum_col.py index cfea38803..cfea38803 100644 --- a/test/ext/mypy/files/enum_col.py +++ b/test/ext/mypy/plugin_files/enum_col.py diff --git a/test/ext/mypy/files/imperative_table.py b/test/ext/mypy/plugin_files/imperative_table.py index 0548a7926..0548a7926 100644 --- a/test/ext/mypy/files/imperative_table.py +++ b/test/ext/mypy/plugin_files/imperative_table.py diff --git a/test/ext/mypy/files/invalid_noninferred_lh_type.py b/test/ext/mypy/plugin_files/invalid_noninferred_lh_type.py index 5084de722..5084de722 100644 --- a/test/ext/mypy/files/invalid_noninferred_lh_type.py +++ b/test/ext/mypy/plugin_files/invalid_noninferred_lh_type.py diff --git a/test/ext/mypy/files/issue_7321.py b/test/ext/mypy/plugin_files/issue_7321.py index d4cd7f2c4..d4cd7f2c4 100644 --- a/test/ext/mypy/files/issue_7321.py +++ b/test/ext/mypy/plugin_files/issue_7321.py diff --git a/test/ext/mypy/files/issue_7321_part2.py b/test/ext/mypy/plugin_files/issue_7321_part2.py index 4227f2797..4227f2797 100644 --- a/test/ext/mypy/files/issue_7321_part2.py +++ b/test/ext/mypy/plugin_files/issue_7321_part2.py diff --git a/test/ext/mypy/files/mapped_attr_assign.py b/test/ext/mypy/plugin_files/mapped_attr_assign.py index 06bc24d9e..06bc24d9e 100644 --- a/test/ext/mypy/files/mapped_attr_assign.py +++ b/test/ext/mypy/plugin_files/mapped_attr_assign.py diff --git a/test/ext/mypy/files/mixin_not_mapped.py b/test/ext/mypy/plugin_files/mixin_not_mapped.py index 9a4865eb6..9a4865eb6 100644 --- a/test/ext/mypy/files/mixin_not_mapped.py +++ b/test/ext/mypy/plugin_files/mixin_not_mapped.py diff --git a/test/ext/mypy/files/mixin_one.py b/test/ext/mypy/plugin_files/mixin_one.py index a471edf6c..a471edf6c 100644 --- a/test/ext/mypy/files/mixin_one.py +++ b/test/ext/mypy/plugin_files/mixin_one.py diff --git a/test/ext/mypy/files/mixin_three.py b/test/ext/mypy/plugin_files/mixin_three.py index cb8e30df8..cb8e30df8 100644 --- a/test/ext/mypy/files/mixin_three.py +++ b/test/ext/mypy/plugin_files/mixin_three.py diff --git a/test/ext/mypy/files/mixin_two.py b/test/ext/mypy/plugin_files/mixin_two.py index c4dc61097..897ce8249 100644 --- a/test/ext/mypy/files/mixin_two.py +++ b/test/ext/mypy/plugin_files/mixin_two.py @@ -6,6 +6,7 @@ from sqlalchemy import String from sqlalchemy.orm import deferred from sqlalchemy.orm import Mapped from sqlalchemy.orm import registry +from sqlalchemy.orm import Relationship from sqlalchemy.orm import relationship from sqlalchemy.orm.decl_api import declared_attr from sqlalchemy.orm.interfaces import MapperProperty @@ -36,11 +37,11 @@ class HasAMixin: return relationship("A", back_populates="bs") @declared_attr - def a3(cls) -> relationship["A"]: + def a3(cls) -> Relationship["A"]: return relationship("A", back_populates="bs") @declared_attr - def c1(cls) -> relationship[C]: + def c1(cls) -> Relationship[C]: return relationship(C, back_populates="bs") @declared_attr diff --git a/test/ext/mypy/files/mixin_w_tablename.py b/test/ext/mypy/plugin_files/mixin_w_tablename.py index cfbe83d35..cfbe83d35 100644 --- a/test/ext/mypy/files/mixin_w_tablename.py +++ b/test/ext/mypy/plugin_files/mixin_w_tablename.py diff --git a/test/ext/mypy/files/orderinglist1.py b/test/ext/mypy/plugin_files/orderinglist1.py index 661d55a7b..661d55a7b 100644 --- a/test/ext/mypy/files/orderinglist1.py +++ b/test/ext/mypy/plugin_files/orderinglist1.py diff --git a/test/ext/mypy/files/orderinglist2.py b/test/ext/mypy/plugin_files/orderinglist2.py index eb50c5391..eb50c5391 100644 --- a/test/ext/mypy/files/orderinglist2.py +++ b/test/ext/mypy/plugin_files/orderinglist2.py diff --git a/test/ext/mypy/files/other_mapper_props.py b/test/ext/mypy/plugin_files/other_mapper_props.py index d87165fea..d87165fea 100644 --- a/test/ext/mypy/files/other_mapper_props.py +++ b/test/ext/mypy/plugin_files/other_mapper_props.py diff --git a/test/ext/mypy/files/plugin_doesnt_break_one.py b/test/ext/mypy/plugin_files/plugin_doesnt_break_one.py index 19cb2bfb4..19cb2bfb4 100644 --- a/test/ext/mypy/files/plugin_doesnt_break_one.py +++ b/test/ext/mypy/plugin_files/plugin_doesnt_break_one.py diff --git a/test/ext/mypy/files/relationship_6255_one.py b/test/ext/mypy/plugin_files/relationship_6255_one.py index e5a180b47..0c8e3c4f6 100644 --- a/test/ext/mypy/files/relationship_6255_one.py +++ b/test/ext/mypy/plugin_files/relationship_6255_one.py @@ -1,13 +1,13 @@ from typing import List from typing import Optional -from sqlalchemy import Column from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import String from sqlalchemy.orm import declarative_base from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship Base = declarative_base() @@ -16,8 +16,8 @@ Base = declarative_base() class User(Base): __tablename__ = "user" - id = Column(Integer, primary_key=True) - name = Column(String) + id = mapped_column(Integer, primary_key=True) + name = mapped_column(String, nullable=True) addresses: Mapped[List["Address"]] = relationship( "Address", back_populates="user" @@ -31,10 +31,10 @@ class User(Base): class Address(Base): __tablename__ = "address" - id = Column(Integer, primary_key=True) - user_id: int = Column(ForeignKey("user.id")) + id = mapped_column(Integer, primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) - user: "User" = relationship("User", back_populates="addresses") + user: Mapped["User"] = relationship("User", back_populates="addresses") @property def some_other_property(self) -> Optional[str]: diff --git a/test/ext/mypy/files/relationship_6255_three.py b/test/ext/mypy/plugin_files/relationship_6255_three.py index 121d8de40..121d8de40 100644 --- a/test/ext/mypy/files/relationship_6255_three.py +++ b/test/ext/mypy/plugin_files/relationship_6255_three.py diff --git a/test/ext/mypy/files/relationship_6255_two.py b/test/ext/mypy/plugin_files/relationship_6255_two.py index 121d8de40..121d8de40 100644 --- a/test/ext/mypy/files/relationship_6255_two.py +++ b/test/ext/mypy/plugin_files/relationship_6255_two.py diff --git a/test/ext/mypy/files/relationship_direct_cls.py b/test/ext/mypy/plugin_files/relationship_direct_cls.py index 1c4efdee2..1c4efdee2 100644 --- a/test/ext/mypy/files/relationship_direct_cls.py +++ b/test/ext/mypy/plugin_files/relationship_direct_cls.py diff --git a/test/ext/mypy/files/relationship_err1.py b/test/ext/mypy/plugin_files/relationship_err1.py index 46e7067d3..ba3783f05 100644 --- a/test/ext/mypy/files/relationship_err1.py +++ b/test/ext/mypy/plugin_files/relationship_err1.py @@ -27,4 +27,5 @@ class A(Base): b_id: int = Column(ForeignKey("b.id")) # EXPECTED: Sending uselist=False and collection_class at the same time does not make sense # noqa + # EXPECTED_MYPY_RE: No overload variant of "relationship" matches argument types b: B = relationship(B, uselist=False, collection_class=set) diff --git a/test/ext/mypy/files/relationship_err2.py b/test/ext/mypy/plugin_files/relationship_err2.py index 4057baeb3..4057baeb3 100644 --- a/test/ext/mypy/files/relationship_err2.py +++ b/test/ext/mypy/plugin_files/relationship_err2.py diff --git a/test/ext/mypy/files/relationship_err3.py b/test/ext/mypy/plugin_files/relationship_err3.py index aa76ae1f0..aa76ae1f0 100644 --- a/test/ext/mypy/files/relationship_err3.py +++ b/test/ext/mypy/plugin_files/relationship_err3.py diff --git a/test/ext/mypy/files/sa_module_prefix.py b/test/ext/mypy/plugin_files/sa_module_prefix.py index a37ae6b06..a37ae6b06 100644 --- a/test/ext/mypy/files/sa_module_prefix.py +++ b/test/ext/mypy/plugin_files/sa_module_prefix.py diff --git a/test/ext/mypy/files/t_6950.py b/test/ext/mypy/plugin_files/t_6950.py index 3ebbf6638..3ebbf6638 100644 --- a/test/ext/mypy/files/t_6950.py +++ b/test/ext/mypy/plugin_files/t_6950.py diff --git a/test/ext/mypy/files/type_decorator.py b/test/ext/mypy/plugin_files/type_decorator.py index 07a13caee..07a13caee 100644 --- a/test/ext/mypy/files/type_decorator.py +++ b/test/ext/mypy/plugin_files/type_decorator.py diff --git a/test/ext/mypy/files/typeless_fk_col_cant_infer.py b/test/ext/mypy/plugin_files/typeless_fk_col_cant_infer.py index beb4a7a5d..beb4a7a5d 100644 --- a/test/ext/mypy/files/typeless_fk_col_cant_infer.py +++ b/test/ext/mypy/plugin_files/typeless_fk_col_cant_infer.py diff --git a/test/ext/mypy/files/typing_err1.py b/test/ext/mypy/plugin_files/typing_err1.py index f262cd55b..f262cd55b 100644 --- a/test/ext/mypy/files/typing_err1.py +++ b/test/ext/mypy/plugin_files/typing_err1.py diff --git a/test/ext/mypy/files/typing_err2.py b/test/ext/mypy/plugin_files/typing_err2.py index adc50f989..ec5635875 100644 --- a/test/ext/mypy/files/typing_err2.py +++ b/test/ext/mypy/plugin_files/typing_err2.py @@ -3,6 +3,7 @@ from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy.orm import declared_attr from sqlalchemy.orm import registry +from sqlalchemy.orm import Relationship from sqlalchemy.orm import relationship reg: registry = registry() @@ -29,8 +30,8 @@ class Foo: # EXPECTED: Can't infer type from @declared_attr on function 'some_relationship' # noqa @declared_attr - # EXPECTED_MYPY: Missing type parameters for generic type "relationship" - def some_relationship(cls) -> relationship: + # EXPECTED_MYPY: Missing type parameters for generic type "Relationship" + def some_relationship(cls) -> Relationship: return relationship("Bar") diff --git a/test/ext/mypy/files/typing_err3.py b/test/ext/mypy/plugin_files/typing_err3.py index 5383f8956..466e636a7 100644 --- a/test/ext/mypy/files/typing_err3.py +++ b/test/ext/mypy/plugin_files/typing_err3.py @@ -49,6 +49,5 @@ class Address(Base): @declared_attr # EXPECTED_MYPY: Invalid type comment or annotation def thisisweird(cls) -> Column(String): - # with the bad annotation mypy seems to not go into the - # function body + # EXPECTED_MYPY: No overload variant of "Column" matches argument type "bool" # noqa return Column(False) diff --git a/test/ext/mypy/test_mypy_plugin_py3k.py b/test/ext/mypy/test_mypy_plugin_py3k.py index cc8d8955f..6df21e46c 100644 --- a/test/ext/mypy/test_mypy_plugin_py3k.py +++ b/test/ext/mypy/test_mypy_plugin_py3k.py @@ -3,16 +3,54 @@ import re import shutil import sys import tempfile +from typing import Any +from typing import cast +from typing import List +from typing import Tuple +import sqlalchemy from sqlalchemy import testing from sqlalchemy.testing import config from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +def _file_combinations(dirname): + path = os.path.join(os.path.dirname(__file__), dirname) + files = [] + for f in os.listdir(path): + if f.endswith(".py"): + files.append(os.path.join(os.path.dirname(__file__), dirname, f)) + + for extra_dir in testing.config.options.mypy_extra_test_paths: + if extra_dir and os.path.isdir(extra_dir): + for f in os.listdir(os.path.join(extra_dir, dirname)): + if f.endswith(".py"): + files.append(os.path.join(extra_dir, dirname, f)) + return files + + +def _incremental_dirs(): + path = os.path.join(os.path.dirname(__file__), "incremental") + files = [] + for d in os.listdir(path): + if os.path.isdir(os.path.join(path, d)): + files.append( + os.path.join(os.path.dirname(__file__), "incremental", d) + ) + + for extra_dir in testing.config.options.mypy_extra_test_paths: + if extra_dir and os.path.isdir(extra_dir): + for d in os.listdir(os.path.join(extra_dir, "incremental")): + if os.path.isdir(os.path.join(path, d)): + files.append(os.path.join(extra_dir, "incremental", d)) + return files + + @testing.add_to_marker.mypy class MypyPluginTest(fixtures.TestBase): - __requires__ = ("sqlalchemy2_stubs",) + __tags__ = ("mypy",) + __requires__ = ("no_sqlalchemy2_stubs",) @testing.fixture(scope="function") def per_func_cachedir(self): @@ -25,22 +63,50 @@ class MypyPluginTest(fixtures.TestBase): yield item def _cachedir(self): + sqlalchemy_path = os.path.dirname(os.path.dirname(sqlalchemy.__file__)) + + # for a pytest from my local ./lib/ , i need mypy_path. + # for a tox run where sqlalchemy is in site_packages, mypy complains + # "../python3.10/site-packages is in the MYPYPATH. Please remove it." + # previously when we used sqlalchemy2-stubs, it would just be + # installed as a dependency, which is why mypy_path wasn't needed + # then, but I like to be able to run the test suite from the local + # ./lib/ as well. + + if "site-packages" not in sqlalchemy_path: + mypy_path = f"mypy_path={sqlalchemy_path}" + else: + mypy_path = "" + with tempfile.TemporaryDirectory() as cachedir: with open( os.path.join(cachedir, "sqla_mypy_config.cfg"), "w" ) as config_file: config_file.write( - """ + f""" [mypy]\n plugins = sqlalchemy.ext.mypy.plugin\n + show_error_codes = True\n + {mypy_path} + disable_error_code = no-untyped-call + + [mypy-sqlalchemy.*] + ignore_errors = True + """ ) with open( os.path.join(cachedir, "plain_mypy_config.cfg"), "w" ) as config_file: config_file.write( - """ + f""" [mypy]\n + show_error_codes = True\n + {mypy_path} + disable_error_code = var-annotated,no-untyped-call + [mypy-sqlalchemy.*] + ignore_errors = True + """ ) yield cachedir @@ -70,24 +136,12 @@ class MypyPluginTest(fixtures.TestBase): return run - def _incremental_dirs(): - path = os.path.join(os.path.dirname(__file__), "incremental") - files = [] - for d in os.listdir(path): - if os.path.isdir(os.path.join(path, d)): - files.append( - os.path.join(os.path.dirname(__file__), "incremental", d) - ) - - for extra_dir in testing.config.options.mypy_extra_test_paths: - if extra_dir and os.path.isdir(extra_dir): - for d in os.listdir(os.path.join(extra_dir, "incremental")): - if os.path.isdir(os.path.join(path, d)): - files.append(os.path.join(extra_dir, "incremental", d)) - return files - @testing.combinations( - *[(pathname,) for pathname in _incremental_dirs()], argnames="pathname" + *[ + (pathname, testing.exclusions.closed()) + for pathname in _incremental_dirs() + ], + argnames="pathname", ) @testing.requires.patch_library def test_incremental(self, mypy_runner, per_func_cachedir, pathname): @@ -131,33 +185,33 @@ class MypyPluginTest(fixtures.TestBase): % (patchfile, result[0]), ) - def _file_combinations(): - path = os.path.join(os.path.dirname(__file__), "files") - files = [] - for f in os.listdir(path): - if f.endswith(".py"): - files.append( - os.path.join(os.path.dirname(__file__), "files", f) - ) - - for extra_dir in testing.config.options.mypy_extra_test_paths: - if extra_dir and os.path.isdir(extra_dir): - for f in os.listdir(os.path.join(extra_dir, "files")): - if f.endswith(".py"): - files.append(os.path.join(extra_dir, "files", f)) - return files - @testing.combinations( - *[(filename,) for filename in _file_combinations()], argnames="path" + *( + cast( + List[Tuple[Any, ...]], + [ + ("w_plugin", os.path.basename(path), path, True) + for path in _file_combinations("plugin_files") + ], + ) + + cast( + List[Tuple[Any, ...]], + [ + ("plain", os.path.basename(path), path, False) + for path in _file_combinations("plain_files") + ], + ) + ), + argnames="filename,path,use_plugin", + id_="isaa", ) - def test_mypy(self, mypy_runner, path): - filename = os.path.basename(path) - use_plugin = True + def test_files(self, mypy_runner, filename, path, use_plugin): - expected_errors = [] - expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?: (.+)") + expected_messages = [] + expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)") py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)") with open(path) as file_: + current_assert_messages = [] for num, line in enumerate(file_, 1): m = py_ver_re.match(line) if m: @@ -174,38 +228,79 @@ class MypyPluginTest(fixtures.TestBase): m = expected_re.match(line) if m: is_mypy = bool(m.group(1)) - expected_msg = m.group(2) - expected_msg = re.sub(r"# noqa ?.*", "", m.group(2)) - expected_errors.append( - (num, is_mypy, expected_msg.strip()) + is_re = bool(m.group(2)) + is_type = bool(m.group(3)) + + expected_msg = re.sub(r"# noqa ?.*", "", m.group(4)) + if is_type: + is_mypy = is_re = True + expected_msg = f'Revealed type is "{expected_msg}"' + current_assert_messages.append( + (is_mypy, is_re, expected_msg.strip()) + ) + elif current_assert_messages: + expected_messages.extend( + (num, is_mypy, is_re, expected_msg) + for ( + is_mypy, + is_re, + expected_msg, + ) in current_assert_messages ) + current_assert_messages[:] = [] result = mypy_runner(path, use_plugin=use_plugin) - if expected_errors: + if expected_messages: eq_(result[2], 1, msg=result) - print(result[0]) + output = [] - errors = [] - for e in result[0].split("\n"): + raw_lines = result[0].split("\n") + while raw_lines: + e = raw_lines.pop(0) if re.match(r".+\.py:\d+: error: .*", e): - errors.append(e) - - for num, is_mypy, msg in expected_errors: + output.append(("error", e)) + elif re.match( + r".+\.py:\d+: note: +(?:Possible overload|def ).*", e + ): + while raw_lines: + ol = raw_lines.pop(0) + if not re.match(r".+\.py:\d+: note: +def \[.*", ol): + break + elif re.match( + r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I + ): + pass + elif re.match(r".+\.py:\d+: note: .*", e): + output.append(("note", e)) + + for num, is_mypy, is_re, msg in expected_messages: msg = msg.replace("'", '"') prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else "" - for idx, errmsg in enumerate(errors): - if ( - f"{filename}:{num + 1}: error: {prefix}{msg}" + for idx, (typ, errmsg) in enumerate(output): + if is_re: + if re.match( + fr".*{filename}\:{num}\: {typ}\: {prefix}{msg}", # noqa E501 + errmsg, + ): + break + elif ( + f"{filename}:{num}: {typ}: {prefix}{msg}" in errmsg.replace("'", '"') ): break else: continue - del errors[idx] + del output[idx] - assert not errors, "errors remain: %s" % "\n".join(errors) + if output: + print("messages from mypy that were not consumed:") + print("\n".join(msg for _, msg in output)) + assert False, "errors and/or notes remain, see stdout" else: + if result[2] != 0: + print(result[0]) + eq_(result[2], 0, msg=result) diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index 9a05e1fae..76fd90fa8 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -231,7 +231,7 @@ class SerializeTest(AssertsCompiledSQL, fixtures.MappedTest): serializer.loads(pickled_failing, users.metadata, None) def test_orm_join(self): - from sqlalchemy.orm.util import join + from sqlalchemy.orm import join j = join(User, Address, User.addresses) diff --git a/test/orm/declarative/__init__.py b/test/orm/declarative/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/test/orm/declarative/__init__.py diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 9651f6dbf..9f9f8e601 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -563,26 +563,6 @@ class DeclarativeMultiBaseTest( eq_(a1, Address(email="two")) eq_(a1.user, User(name="u1")) - def test_mapped_column_construct(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = "users" - - id = mapped_column("id", Integer, primary_key=True) - name = mapped_column(String(50)) - - Base.metadata.create_all(testing.db) - - u1 = User(id=1, name="u1") - sess = fixture_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - - eq_( - sess.query(User).all(), - [User(name="u1", id=1)], - ) - def test_back_populates_setup(self): class User(Base): __tablename__ = "users" @@ -1534,28 +1514,25 @@ class DeclarativeMultiBaseTest( yield go + @testing.combinations(Column, mapped_column, argnames="_column") def test_add_prop_auto( - self, require_metaclass, assert_user_address_mapping + self, require_metaclass, assert_user_address_mapping, _column ): class User(Base, fixtures.ComparableEntity): __tablename__ = "users" - id = Column( - "id", Integer, primary_key=True, test_needs_autoincrement=True - ) + id = Column("id", Integer, primary_key=True) - User.name = Column("name", String(50)) + User.name = _column("name", String(50)) User.addresses = relationship("Address", backref="user") class Address(Base, fixtures.ComparableEntity): __tablename__ = "addresses" - id = Column( - Integer, primary_key=True, test_needs_autoincrement=True - ) + id = _column(Integer, primary_key=True) - Address.email = Column(String(50), key="_email") - Address.user_id = Column( + Address.email = _column(String(50), key="_email") + Address.user_id = _column( "user_id", Integer, ForeignKey("users.id"), key="_user_id" ) @@ -1565,15 +1542,14 @@ class DeclarativeMultiBaseTest( assert_user_address_mapping(User, Address) - def test_add_prop_manual(self, assert_user_address_mapping): + @testing.combinations(Column, mapped_column, argnames="_column") + def test_add_prop_manual(self, assert_user_address_mapping, _column): class User(Base, fixtures.ComparableEntity): __tablename__ = "users" - id = Column( - "id", Integer, primary_key=True, test_needs_autoincrement=True - ) + id = _column("id", Integer, primary_key=True) - add_mapped_attribute(User, "name", Column("name", String(50))) + add_mapped_attribute(User, "name", _column("name", String(50))) add_mapped_attribute( User, "addresses", relationship("Address", backref="user") ) @@ -1581,17 +1557,17 @@ class DeclarativeMultiBaseTest( class Address(Base, fixtures.ComparableEntity): __tablename__ = "addresses" - id = Column( - Integer, primary_key=True, test_needs_autoincrement=True - ) + id = _column(Integer, primary_key=True) add_mapped_attribute( - Address, "email", Column(String(50), key="_email") + Address, "email", _column(String(50), key="_email") ) add_mapped_attribute( Address, "user_id", - Column("user_id", Integer, ForeignKey("users.id"), key="_user_id"), + _column( + "user_id", Integer, ForeignKey("users.id"), key="_user_id" + ), ) eq_(Address.__table__.c["id"].name, "id") @@ -1612,7 +1588,7 @@ class DeclarativeMultiBaseTest( assert ASub.brap.property is A.data.property assert isinstance( - ASub.brap.original_property, descriptor_props.SynonymProperty + ASub.brap.original_property, descriptor_props.Synonym ) def test_alt_name_attr_subclass_relationship_inline(self): @@ -1634,7 +1610,7 @@ class DeclarativeMultiBaseTest( assert ASub.brap.property is A.b.property assert isinstance( - ASub.brap.original_property, descriptor_props.SynonymProperty + ASub.brap.original_property, descriptor_props.Synonym ) ASub(brap=B()) @@ -1647,9 +1623,7 @@ class DeclarativeMultiBaseTest( A.brap = A.data assert A.brap.property is A.data.property - assert isinstance( - A.brap.original_property, descriptor_props.SynonymProperty - ) + assert isinstance(A.brap.original_property, descriptor_props.Synonym) def test_alt_name_attr_subclass_relationship_attrset( self, require_metaclass @@ -1668,9 +1642,7 @@ class DeclarativeMultiBaseTest( id = Column("id", Integer, primary_key=True) assert A.brap.property is A.b.property - assert isinstance( - A.brap.original_property, descriptor_props.SynonymProperty - ) + assert isinstance(A.brap.original_property, descriptor_props.Synonym) A(brap=B()) def test_eager_order_by(self): diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py index 97f0d560e..5be8237e2 100644 --- a/test/orm/declarative/test_mixin.py +++ b/test/orm/declarative/test_mixin.py @@ -1,3 +1,5 @@ +from operator import is_not + import sqlalchemy as sa from sqlalchemy import ForeignKey from sqlalchemy import func @@ -18,15 +20,18 @@ from sqlalchemy.orm import declared_attr from sqlalchemy.orm import deferred from sqlalchemy.orm import events as orm_events from sqlalchemy.orm import has_inherited_table +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import registry from sqlalchemy.orm import relationship from sqlalchemy.orm import synonym from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column @@ -375,16 +380,88 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_(m1.tables["user"].c.keys(), ["id", "name", "surname"]) eq_(m2.tables["user"].c.keys(), ["id", "username"]) - def test_not_allowed(self): + @testing.combinations(Column, mapped_column, argnames="_column") + @testing.combinations("strname", "colref", "objref", argnames="fk_type") + def test_fk_mixin(self, decl_base, fk_type, _column): + class Bar(decl_base): + __tablename__ = "bar" + + id = _column(Integer, primary_key=True) + + if fk_type == "strname": + fk = ForeignKey("bar.id") + elif fk_type == "colref": + fk = ForeignKey(Bar.__table__.c.id) + elif fk_type == "objref": + fk = ForeignKey(Bar.id) + else: + assert False + class MyMixin: - foo = Column(Integer, ForeignKey("bar.id")) + foo = _column(Integer, fk) - def go(): - class MyModel(Base, MyMixin): - __tablename__ = "foo" + class A(MyMixin, decl_base): + __tablename__ = "a" - assert_raises(sa.exc.InvalidRequestError, go) + id = _column(Integer, primary_key=True) + + class B(MyMixin, decl_base): + __tablename__ = "b" + + id = _column(Integer, primary_key=True) + + is_true(A.__table__.c.foo.references(Bar.__table__.c.id)) + is_true(B.__table__.c.foo.references(Bar.__table__.c.id)) + + fka = list(A.__table__.c.foo.foreign_keys)[0] + fkb = list(A.__table__.c.foo.foreign_keys)[0] + is_not(fka, fkb) + + @testing.combinations(Column, mapped_column, argnames="_column") + def test_fk_mixin_self_referential_error(self, decl_base, _column): + class MyMixin: + id = _column(Integer, primary_key=True) + foo = _column(Integer, ForeignKey(id)) + with expect_raises_message( + sa.exc.InvalidRequestError, + "Columns with foreign keys to non-table-bound columns " + "must be declared as @declared_attr", + ): + + class A(MyMixin, decl_base): + __tablename__ = "a" + + @testing.combinations(Column, mapped_column, argnames="_column") + def test_fk_mixin_self_referential_declared_attr(self, decl_base, _column): + class MyMixin: + id = _column(Integer, primary_key=True) + + @declared_attr + def foo(cls): + return _column(Integer, ForeignKey(cls.id)) + + class A(MyMixin, decl_base): + __tablename__ = "a" + + class B(MyMixin, decl_base): + __tablename__ = "b" + + is_true(A.__table__.c.foo.references(A.__table__.c.id)) + is_true(B.__table__.c.foo.references(B.__table__.c.id)) + + fka = list(A.__table__.c.foo.foreign_keys)[0] + fkb = list(A.__table__.c.foo.foreign_keys)[0] + is_not(fka, fkb) + + is_true(A.__table__.c.foo.references(A.__table__.c.id)) + is_true(B.__table__.c.foo.references(B.__table__.c.id)) + + fka = list(A.__table__.c.foo.foreign_keys)[0] + fkb = list(A.__table__.c.foo.foreign_keys)[0] + is_not(fka, fkb) + + def test_not_allowed(self): class MyRelMixin: foo = relationship("Bar") @@ -1013,7 +1090,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): __mapper_args__ = dict(polymorphic_identity="specific") assert Specific.__table__ is Generic.__table__ - eq_(list(Generic.__table__.c.keys()), ["id", "type", "value"]) + eq_(list(Generic.__table__.c.keys()), ["type", "value", "id"]) assert ( class_mapper(Specific).polymorphic_on is Generic.__table__.c.type ) @@ -1043,7 +1120,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_(Specific.__table__.name, "specific") eq_( list(Generic.__table__.c.keys()), - ["timestamp", "id", "python_type"], + ["python_type", "timestamp", "id"], ) eq_(list(Specific.__table__.c.keys()), ["id"]) eq_(Generic.__table__.kwargs, {"mysql_engine": "InnoDB"}) @@ -1078,7 +1155,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_(BaseType.__table__.name, "basetype") eq_( list(BaseType.__table__.c.keys()), - ["timestamp", "type", "id", "value"], + ["type", "id", "value", "timestamp"], ) eq_(BaseType.__table__.kwargs, {"mysql_engine": "InnoDB"}) assert Single.__table__ is BaseType.__table__ @@ -1326,7 +1403,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_( list(Model.__table__.c.keys()), - ["col1", "col3", "col2", "col4", "id"], + ["id", "col1", "col3", "col2", "col4"], ) def test_honor_class_mro_one(self): @@ -1813,11 +1890,11 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL): s = fixture_session() self.assert_compile( s.query(A), - "SELECT a.x AS a_x, a.x + :x_1 AS anon_1, a.id AS a_id FROM a", + "SELECT a.x + :x_1 AS anon_1, a.x AS a_x, a.id AS a_id FROM a", ) self.assert_compile( s.query(B), - "SELECT b.x AS b_x, b.x + :x_1 AS anon_1, b.id AS b_id FROM b", + "SELECT b.x + :x_1 AS anon_1, b.x AS b_x, b.id AS b_id FROM b", ) @testing.requires.predictable_gc @@ -2161,7 +2238,7 @@ class AbstractTest(DeclarativeTestBase): class C(B): c_value = Column(String) - eq_(sa.inspect(C).attrs.keys(), ["id", "name", "data", "c_value"]) + eq_(sa.inspect(C).attrs.keys(), ["id", "name", "c_value", "data"]) def test_implicit_abstract_viadecorator(self): @mapper_registry.mapped @@ -2178,7 +2255,7 @@ class AbstractTest(DeclarativeTestBase): class C(B): c_value = Column(String) - eq_(sa.inspect(C).attrs.keys(), ["id", "name", "data", "c_value"]) + eq_(sa.inspect(C).attrs.keys(), ["id", "name", "c_value", "data"]) def test_middle_abstract_inherits(self): # test for [ticket:3240] diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py new file mode 100644 index 000000000..c7022dc31 --- /dev/null +++ b/test/orm/declarative/test_tm_future_annotations.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from .test_typed_mapping import MappedColumnTest # noqa +from .test_typed_mapping import RelationshipLHSTest # noqa + +"""runs the annotation-sensitive tests from test_typed_mappings while +having ``from __future__ import annotations`` in effect. + +""" diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py new file mode 100644 index 000000000..71eb7ce42 --- /dev/null +++ b/test/orm/declarative/test_typed_mapping.py @@ -0,0 +1,1048 @@ +import dataclasses +import datetime +from decimal import Decimal +from typing import Dict +from typing import Generic +from typing import List +from typing import Optional +from typing import Set +from typing import Type +from typing import TypeVar +from typing import Union + +from sqlalchemy import BIGINT +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import exc as sa_exc +from sqlalchemy import ForeignKey +from sqlalchemy import inspect +from sqlalchemy import Integer +from sqlalchemy import Numeric +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy import VARCHAR +from sqlalchemy.exc import ArgumentError +from sqlalchemy.orm import as_declarative +from sqlalchemy.orm import composite +from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import deferred +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship +from sqlalchemy.orm import undefer +from sqlalchemy.orm.collections import attribute_mapped_collection +from sqlalchemy.orm.collections import MappedCollection +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises +from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_not +from sqlalchemy.testing import is_true +from sqlalchemy.testing.fixtures import fixture_session +from sqlalchemy.util.typing import Annotated + + +class DeclarativeBaseTest(fixtures.TestBase): + def test_class_getitem_as_declarative(self): + T = TypeVar("T", bound="CommonBase") # noqa + + class CommonBase(Generic[T]): + @classmethod + def boring(cls: Type[T]) -> Type[T]: + return cls + + @classmethod + def more_boring(cls: Type[T]) -> int: + return 27 + + @as_declarative() + class Base(CommonBase[T]): + foo = 1 + + class Tab(Base["Tab"]): + __tablename__ = "foo" + a = Column(Integer, primary_key=True) + + eq_(Tab.foo, 1) + is_(Tab.__table__, inspect(Tab).local_table) + eq_(Tab.boring(), Tab) + eq_(Tab.more_boring(), 27) + + with expect_raises(AttributeError): + Tab.non_existent + + +class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + def test_legacy_declarative_base(self): + typ = VARCHAR(50) + Base = declarative_base(type_annotation_map={str: typ}) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + x: Mapped[int] + + is_(MyClass.__table__.c.data.type, typ) + is_true(MyClass.__table__.c.id.primary_key) + + def test_required_no_arg(self, decl_base): + with expect_raises_message( + sa_exc.ArgumentError, + r"Python typing annotation is required for attribute " + r'"A.data" when primary ' + r'argument\(s\) for "MappedColumn" construct are None or ' + r"not present", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data = mapped_column() + + def test_construct_rhs(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id = mapped_column("id", Integer, primary_key=True) + name = mapped_column(String(50)) + + self.assert_compile( + select(User), "SELECT users.id, users.name FROM users" + ) + eq_(User.__mapper__.primary_key, (User.__table__.c.id,)) + + def test_construct_lhs(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + data: Mapped[Optional[str]] = mapped_column() + + self.assert_compile( + select(User), "SELECT users.id, users.name, users.data FROM users" + ) + eq_(User.__mapper__.primary_key, (User.__table__.c.id,)) + is_false(User.__table__.c.id.nullable) + is_false(User.__table__.c.name.nullable) + is_true(User.__table__.c.data.nullable) + + def test_construct_lhs_omit_mapped_column(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + data: Mapped[Optional[str]] + x: Mapped[int] + y: Mapped[int] + created_at: Mapped[datetime.datetime] + + self.assert_compile( + select(User), + "SELECT users.id, users.name, users.data, users.x, " + "users.y, users.created_at FROM users", + ) + eq_(User.__mapper__.primary_key, (User.__table__.c.id,)) + is_false(User.__table__.c.id.nullable) + is_false(User.__table__.c.name.nullable) + is_true(User.__table__.c.data.nullable) + assert isinstance(User.__table__.c.created_at.type, DateTime) + + def test_construct_lhs_type_missing(self, decl_base): + class MyClass: + pass + + with expect_raises_message( + sa_exc.ArgumentError, + "Could not locate SQLAlchemy Core type for Python type: .*MyClass", + ): + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[MyClass] = mapped_column() + + def test_construct_rhs_type_override_lhs(self, decl_base): + class Element(decl_base): + __tablename__ = "element" + + id: Mapped[int] = mapped_column(BIGINT, primary_key=True) + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(BIGINT, primary_key=True) + other_id: Mapped[int] = mapped_column(ForeignKey("element.id")) + data: Mapped[int] = mapped_column() + + # exact class test + is_(User.__table__.c.id.type.__class__, BIGINT) + is_(User.__table__.c.other_id.type.__class__, BIGINT) + is_(User.__table__.c.data.type.__class__, Integer) + + @testing.combinations(True, False, argnames="include_rhs_type") + def test_construct_nullability_overrides( + self, decl_base, include_rhs_type + ): + + if include_rhs_type: + args = (String,) + else: + args = () + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + + lnnl_rndf: Mapped[str] = mapped_column(*args) + lnnl_rnnl: Mapped[str] = mapped_column(*args, nullable=False) + lnnl_rnl: Mapped[str] = mapped_column(*args, nullable=True) + lnl_rndf: Mapped[Optional[str]] = mapped_column(*args) + lnl_rnnl: Mapped[Optional[str]] = mapped_column( + *args, nullable=False + ) + lnl_rnl: Mapped[Optional[str]] = mapped_column( + *args, nullable=True + ) + + is_false(User.__table__.c.lnnl_rndf.nullable) + is_false(User.__table__.c.lnnl_rnnl.nullable) + is_true(User.__table__.c.lnnl_rnl.nullable) + + is_true(User.__table__.c.lnl_rndf.nullable) + is_false(User.__table__.c.lnl_rnnl.nullable) + is_true(User.__table__.c.lnl_rnl.nullable) + + def test_fwd_refs(self, decl_base: Type[DeclarativeBase]): + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped["int"] = mapped_column(primary_key=True) + data_one: Mapped["str"] + + def test_annotated_types_as_keys(self, decl_base: Type[DeclarativeBase]): + """neat!!!""" + + str50 = Annotated[str, 50] + str30 = Annotated[str, 30] + opt_str50 = Optional[str50] + opt_str30 = Optional[str30] + + decl_base.registry.update_type_annotation_map( + {str50: String(50), str30: String(30)} + ) + + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped[str50] = mapped_column(primary_key=True) + data_one: Mapped[str30] + data_two: Mapped[opt_str30] + data_three: Mapped[str50] + data_four: Mapped[opt_str50] + data_five: Mapped[str] + data_six: Mapped[Optional[str]] + + eq_(MyClass.__table__.c.data_one.type.length, 30) + is_false(MyClass.__table__.c.data_one.nullable) + eq_(MyClass.__table__.c.data_two.type.length, 30) + is_true(MyClass.__table__.c.data_two.nullable) + eq_(MyClass.__table__.c.data_three.type.length, 50) + + def test_unions(self): + our_type = Numeric(10, 2) + + class Base(DeclarativeBase): + type_annotation_map = {Union[float, Decimal]: our_type} + + class User(Base): + __tablename__ = "users" + __table__: Table + + id: Mapped[int] = mapped_column(primary_key=True) + + data: Mapped[Union[float, Decimal]] = mapped_column() + reverse_data: Mapped[Union[Decimal, float]] = mapped_column() + optional_data: Mapped[ + Optional[Union[float, Decimal]] + ] = mapped_column() + + # use Optional directly + reverse_optional_data: Mapped[ + Optional[Union[Decimal, float]] + ] = mapped_column() + + # use Union with None, same as Optional but presents differently + # (Optional object with __origin__ Union vs. Union) + reverse_u_optional_data: Mapped[ + Union[Decimal, float, None] + ] = mapped_column() + float_data: Mapped[float] = mapped_column() + decimal_data: Mapped[Decimal] = mapped_column() + + is_(User.__table__.c.data.type, our_type) + is_false(User.__table__.c.data.nullable) + is_(User.__table__.c.reverse_data.type, our_type) + is_(User.__table__.c.optional_data.type, our_type) + is_true(User.__table__.c.optional_data.nullable) + + is_(User.__table__.c.reverse_optional_data.type, our_type) + is_(User.__table__.c.reverse_u_optional_data.type, our_type) + is_true(User.__table__.c.reverse_optional_data.nullable) + is_true(User.__table__.c.reverse_u_optional_data.nullable) + + is_(User.__table__.c.float_data.type, our_type) + is_(User.__table__.c.decimal_data.type, our_type) + + def test_missing_mapped_lhs(self, decl_base): + with expect_raises_message( + ArgumentError, + r'Type annotation for "User.name" should use the ' + r'syntax "Mapped\[str\]" or "MappedColumn\[str\]"', + ): + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + name: str = mapped_column() # type: ignore + + def test_construct_lhs_separate_name(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + data: Mapped[Optional[str]] = mapped_column("the_data") + + self.assert_compile( + select(User.data), "SELECT users.the_data FROM users" + ) + is_true(User.__table__.c.the_data.nullable) + + def test_construct_works_in_expr(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + + class Address(decl_base): + __tablename__ = "addresses" + + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey("users.id")) + + user = relationship(User, primaryjoin=user_id == User.id) + + self.assert_compile( + select(Address.user_id, User.id).join(Address.user), + "SELECT addresses.user_id, users.id FROM addresses " + "JOIN users ON addresses.user_id = users.id", + ) + + def test_construct_works_as_polymorphic_on(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] = mapped_column() + + __mapper_args__ = {"polymorphic_on": type} + + decl_base.registry.configure() + is_(User.__table__.c.type, User.__mapper__.polymorphic_on) + + def test_construct_works_as_version_id_col(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + version_id: Mapped[int] = mapped_column() + + __mapper_args__ = {"version_id_col": version_id} + + decl_base.registry.configure() + is_(User.__table__.c.version_id, User.__mapper__.version_id_col) + + def test_construct_works_in_deferred(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = deferred(mapped_column()) + + self.assert_compile(select(User), "SELECT users.id FROM users") + self.assert_compile( + select(User).options(undefer(User.data)), + "SELECT users.data, users.id FROM users", + ) + + def test_deferred_kw(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column(deferred=True) + + self.assert_compile(select(User), "SELECT users.id FROM users") + self.assert_compile( + select(User).options(undefer(User.data)), + "SELECT users.data, users.id FROM users", + ) + + +class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + def test_mapped_column_omit_fn(self, decl_base): + class MixinOne: + name: Mapped[str] + x: Mapped[int] + y: Mapped[int] = mapped_column() + + class A(MixinOne, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + # ordering of cols is TODO + eq_(A.__table__.c.keys(), ["id", "y", "name", "x"]) + + def test_mc_duplication_plain(self, decl_base): + class MixinOne: + name: Mapped[str] = mapped_column() + + class A(MixinOne, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + class B(MixinOne, decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + is_not(A.__table__.c.name, B.__table__.c.name) + + def test_mc_duplication_declared_attr(self, decl_base): + class MixinOne: + @declared_attr + def name(cls) -> Mapped[str]: + return mapped_column() + + class A(MixinOne, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + class B(MixinOne, decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + is_not(A.__table__.c.name, B.__table__.c.name) + + def test_relationship_requires_declared_attr(self, decl_base): + class Related(decl_base): + __tablename__ = "related" + + id: Mapped[int] = mapped_column(primary_key=True) + + class HasRelated: + related_id: Mapped[int] = mapped_column(ForeignKey(Related.id)) + + related: Mapped[Related] = relationship() + + with expect_raises_message( + sa_exc.InvalidRequestError, + r"Mapper properties \(i.e. deferred,column_property\(\), " + r"relationship\(\), etc.\) must be declared", + ): + + class A(HasRelated, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + def test_relationship_duplication_declared_attr(self, decl_base): + class Related(decl_base): + __tablename__ = "related" + + id: Mapped[int] = mapped_column(primary_key=True) + + class HasRelated: + related_id: Mapped[int] = mapped_column(ForeignKey(Related.id)) + + @declared_attr + def related(cls) -> Mapped[Related]: + return relationship() + + class A(HasRelated, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + class B(HasRelated, decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + self.assert_compile( + select(A).join(A.related), + "SELECT a.id, a.related_id FROM a " + "JOIN related ON related.id = a.related_id", + ) + self.assert_compile( + select(B).join(B.related), + "SELECT b.id, b.related_id FROM b " + "JOIN related ON related.id = b.related_id", + ) + + +class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + @testing.fixture + def decl_base(self): + class Base(DeclarativeBase): + pass + + yield Base + Base.registry.dispose() + + def test_no_typing_in_rhs(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + bs = relationship("List['B']") + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + + with expect_raises_message( + sa_exc.InvalidRequestError, + r"When initializing mapper Mapper\[A\(a\)\], expression " + r'"relationship\(\"List\[\'B\'\]\"\)\" seems to be using a ' + r"generic class as the argument to relationship\(\); please " + r"state the generic argument using an annotation, e.g. " + r'"bs: Mapped\[List\[\'B\'\]\] = relationship\(\)"', + ): + + decl_base.registry.configure() + + def test_required_no_arg(self, decl_base): + with expect_raises_message( + sa_exc.ArgumentError, + r"Python typing annotation is required for attribute " + r'"A.bs" when primary ' + r'argument\(s\) for "Relationship" construct are None or ' + r"not present", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + bs = relationship() + + def test_rudimentary_dataclasses_support(self, registry): + @registry.mapped + @dataclasses.dataclass + class A: + __tablename__ = "a" + __sa_dataclass_metadata_key__ = "sa" + + id: Mapped[int] = mapped_column(primary_key=True) + bs: List["B"] = dataclasses.field( # noqa: F821 + default_factory=list, metadata={"sa": relationship()} + ) + + @registry.mapped + @dataclasses.dataclass + class B: + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + a_id = mapped_column(ForeignKey("a.id")) + + self.assert_compile( + select(A).join(A.bs), "SELECT a.id FROM a JOIN b ON a.id = b.a_id" + ) + + def test_basic_bidirectional(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + bs: Mapped[List["B"]] = relationship( # noqa F821 + back_populates="a" + ) + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + + a: Mapped["A"] = relationship( + back_populates="bs", primaryjoin=a_id == A.id + ) + + a1 = A(data="data") + b1 = B() + a1.bs.append(b1) + is_(a1, b1.a) + + def test_wrong_annotation_type_one(self, decl_base): + + with expect_raises_message( + sa_exc.ArgumentError, + r"Type annotation for \"A.data\" should use the " + r"syntax \"Mapped\['B'\]\" or \"Relationship\['B'\]\"", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: "B" = relationship() # type: ignore # noqa + + def test_wrong_annotation_type_two(self, decl_base): + + with expect_raises_message( + sa_exc.ArgumentError, + r"Type annotation for \"A.data\" should use the " + r"syntax \"Mapped\[B\]\" or \"Relationship\[B\]\"", + ): + + class B(decl_base): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True) + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: B = relationship() # type: ignore # noqa + + def test_wrong_annotation_type_three(self, decl_base): + + with expect_raises_message( + sa_exc.ArgumentError, + r"Type annotation for \"A.data\" should use the " + r"syntax \"Mapped\['List\[B\]'\]\" or " + r"\"Relationship\['List\[B\]'\]\"", + ): + + class B(decl_base): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True) + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: "List[B]" = relationship() # type: ignore # noqa + + def test_collection_class_uselist(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + bs_list: Mapped[List["B"]] = relationship( # noqa F821 + viewonly=True + ) + bs_set: Mapped[Set["B"]] = relationship(viewonly=True) # noqa F821 + bs_list_warg: Mapped[List["B"]] = relationship( # noqa F821 + "B", viewonly=True + ) + bs_set_warg: Mapped[Set["B"]] = relationship( # noqa F821 + "B", viewonly=True + ) + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + + a: Mapped["A"] = relationship(viewonly=True) + a_warg: Mapped["A"] = relationship("A", viewonly=True) + + is_(A.__mapper__.attrs["bs_list"].collection_class, list) + is_(A.__mapper__.attrs["bs_set"].collection_class, set) + is_(A.__mapper__.attrs["bs_list_warg"].collection_class, list) + is_(A.__mapper__.attrs["bs_set_warg"].collection_class, set) + is_true(A.__mapper__.attrs["bs_list"].uselist) + is_true(A.__mapper__.attrs["bs_set"].uselist) + is_true(A.__mapper__.attrs["bs_list_warg"].uselist) + is_true(A.__mapper__.attrs["bs_set_warg"].uselist) + + is_false(B.__mapper__.attrs["a"].uselist) + is_false(B.__mapper__.attrs["a_warg"].uselist) + + def test_collection_class_dict_no_collection(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + bs: Mapped[Dict[str, "B"]] = relationship() # noqa F821 + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + name: Mapped[str] = mapped_column() + + # this is the old collections message. it's not great, but at the + # moment I like that this is what's raised + with expect_raises_message( + sa_exc.ArgumentError, + "Type InstrumentedDict must elect an appender", + ): + decl_base.registry.configure() + + def test_collection_class_dict_attr_mapped_collection(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + + bs: Mapped[MappedCollection[str, "B"]] = relationship( # noqa F821 + collection_class=attribute_mapped_collection("name") + ) + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + name: Mapped[str] = mapped_column() + + decl_base.registry.configure() + + a1 = A() + b1 = B(name="foo") + + # collection appender on MappedCollection + a1.bs.set(b1) + + is_(a1.bs["foo"], b1) + + +class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + @testing.fixture + def dataclass_point_fixture(self, decl_base): + @dataclasses.dataclass + class Point: + x: int + y: int + + class Edge(decl_base): + __tablename__ = "edge" + id: Mapped[int] = mapped_column(primary_key=True) + graph_id: Mapped[int] = mapped_column(ForeignKey("graph.id")) + + start: Mapped[Point] = composite( + Point, mapped_column("x1"), mapped_column("y1") + ) + + end: Mapped[Point] = composite( + Point, mapped_column("x2"), mapped_column("y2") + ) + + class Graph(decl_base): + __tablename__ = "graph" + id: Mapped[int] = mapped_column(primary_key=True) + + edges: Mapped[List[Edge]] = relationship() + + decl_base.metadata.create_all(testing.db) + return Point, Graph, Edge + + def test_composite_setup(self, dataclass_point_fixture): + Point, Graph, Edge = dataclass_point_fixture + + with fixture_session() as sess: + sess.add( + Graph( + edges=[ + Edge(start=Point(1, 2), end=Point(3, 4)), + Edge(start=Point(7, 8), end=Point(5, 6)), + ] + ) + ) + sess.commit() + + self.assert_compile( + select(Edge), + "SELECT edge.id, edge.graph_id, edge.x1, edge.y1, " + "edge.x2, edge.y2 FROM edge", + ) + + with fixture_session() as sess: + g1 = sess.scalar(select(Graph)) + + # round trip! + eq_(g1.edges[0].end, Point(3, 4)) + + def test_named_setup(self, decl_base: Type[DeclarativeBase]): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped[Address] = composite( + Address, mapped_column(), mapped_column(), mapped_column("zip") + ) + + decl_base.metadata.create_all(testing.db) + + with fixture_session() as sess: + sess.add( + User( + name="user 1", + address=Address("123 anywhere street", "NY", "12345"), + ) + ) + sess.commit() + + with fixture_session() as sess: + u1 = sess.scalar(select(User)) + + # round trip! + eq_(u1.address, Address("123 anywhere street", "NY", "12345")) + + def test_no_fwd_ref_annotated_setup(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + with expect_raises_message( + ArgumentError, + r"Can't use forward ref ForwardRef\('Address'\) " + r"for composite class argument", + ): + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped["Address"] = composite( + mapped_column(), mapped_column(), mapped_column("zip") + ) + + def test_fwd_ref_plus_no_mapped(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + with expect_raises_message( + ArgumentError, + r"Type annotation for \"User.address\" should use the syntax " + r"\"Mapped\['Address'\]\" or \"MappedColumn\['Address'\]\"", + ): + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: "Address" = composite( # type: ignore + mapped_column(), mapped_column(), mapped_column("zip") + ) + + def test_fwd_ref_ok_explicit_cls(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped["Address"] = composite( + Address, mapped_column(), mapped_column(), mapped_column("zip") + ) + + self.assert_compile( + select(User), + 'SELECT "user".id, "user".name, "user".street, ' + '"user".state, "user".zip FROM "user"', + ) + + def test_cls_annotated_setup(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped[Address] = composite( + mapped_column(), mapped_column(), mapped_column("zip") + ) + + decl_base.metadata.create_all(testing.db) + + with fixture_session() as sess: + sess.add( + User( + name="user 1", + address=Address("123 anywhere street", "NY", "12345"), + ) + ) + sess.commit() + + with fixture_session() as sess: + u1 = sess.scalar(select(User)) + + # round trip! + eq_(u1.address, Address("123 anywhere street", "NY", "12345")) + + def test_one_col_setup(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped[Address] = composite(Address, mapped_column()) + + decl_base.metadata.create_all(testing.db) + + with fixture_session() as sess: + sess.add( + User( + name="user 1", + address=Address("123 anywhere street"), + ) + ) + sess.commit() + + with fixture_session() as sess: + u1 = sess.scalar(select(User)) + + # round trip! + eq_(u1.address, Address("123 anywhere street")) + + +class AllYourFavoriteHitsTest(fixtures.TestBase, testing.AssertsCompiledSQL): + """try a bunch of common mappings using the new style""" + + __dialect__ = "default" + + def test_employee_joined_inh(self, decl_base: Type[DeclarativeBase]): + + str50 = Annotated[str, 50] + str30 = Annotated[str, 30] + opt_str50 = Optional[str50] + + decl_base.registry.update_type_annotation_map( + {str50: String(50), str30: String(30)} + ) + + class Company(decl_base): + __tablename__ = "company" + + company_id: Mapped[int] = mapped_column(Integer, primary_key=True) + + name: Mapped[str50] + + employees: Mapped[Set["Person"]] = relationship() # noqa F821 + + class Person(decl_base): + __tablename__ = "person" + person_id: Mapped[int] = mapped_column(primary_key=True) + company_id: Mapped[int] = mapped_column( + ForeignKey("company.company_id") + ) + name: Mapped[str50] + type: Mapped[str30] = mapped_column() + + __mapper_args__ = {"polymorphic_on": type} + + class Engineer(Person): + __tablename__ = "engineer" + + person_id: Mapped[int] = mapped_column( + ForeignKey("person.person_id"), primary_key=True + ) + + status: Mapped[str] = mapped_column(String(30)) + engineer_name: Mapped[opt_str50] + primary_language: Mapped[opt_str50] + + class Manager(Person): + __tablename__ = "manager" + + person_id: Mapped[int] = mapped_column( + ForeignKey("person.person_id"), primary_key=True + ) + status: Mapped[str] = mapped_column(String(30)) + manager_name: Mapped[str50] + + is_(Person.__mapper__.polymorphic_on, Person.__table__.c.type) + + # the SELECT statements here confirm the columns present and their + # ordering + self.assert_compile( + select(Person), + "SELECT person.person_id, person.company_id, person.name, " + "person.type FROM person", + ) + + self.assert_compile( + select(Manager), + "SELECT manager.person_id, person.person_id AS person_id_1, " + "person.company_id, person.name, person.type, manager.status, " + "manager.manager_name FROM person " + "JOIN manager ON person.person_id = manager.person_id", + ) + + self.assert_compile( + select(Company).join(Company.employees.of_type(Engineer)), + "SELECT company.company_id, company.name FROM company JOIN " + "(person JOIN engineer ON person.person_id = engineer.person_id) " + "ON company.company_id = person.company_id", + ) diff --git a/test/orm/declarative/test_typing_py3k.py b/test/orm/declarative/test_typing_py3k.py deleted file mode 100644 index 0be91a509..000000000 --- a/test/orm/declarative/test_typing_py3k.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Generic -from typing import Type -from typing import TypeVar - -from sqlalchemy import Column -from sqlalchemy import inspect -from sqlalchemy import Integer -from sqlalchemy.orm import as_declarative -from sqlalchemy.testing import eq_ -from sqlalchemy.testing import fixtures -from sqlalchemy.testing import is_ -from sqlalchemy.testing.assertions import expect_raises - - -class DeclarativeBaseTest(fixtures.TestBase): - def test_class_getitem(self): - T = TypeVar("T", bound="CommonBase") # noqa - - class CommonBase(Generic[T]): - @classmethod - def boring(cls: Type[T]) -> Type[T]: - return cls - - @classmethod - def more_boring(cls: Type[T]) -> int: - return 27 - - @as_declarative() - class Base(CommonBase[T]): - foo = 1 - - class Tab(Base["Tab"]): - __tablename__ = "foo" - a = Column(Integer, primary_key=True) - - eq_(Tab.foo, 1) - is_(Tab.__table__, inspect(Tab).local_table) - eq_(Tab.boring(), Tab) - eq_(Tab.more_boring(), 27) - - with expect_raises(AttributeError): - Tab.non_existent diff --git a/test/orm/inheritance/test_concrete.py b/test/orm/inheritance/test_concrete.py index fae146755..c5031ed59 100644 --- a/test/orm/inheritance/test_concrete.py +++ b/test/orm/inheritance/test_concrete.py @@ -14,7 +14,7 @@ from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import joinedload from sqlalchemy.orm import polymorphic_union from sqlalchemy.orm import relationship -from sqlalchemy.orm.util import with_polymorphic +from sqlalchemy.orm import with_polymorphic from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ diff --git a/test/orm/inheritance/test_poly_loading.py b/test/orm/inheritance/test_poly_loading.py index 5f8ff5639..d9d4a9a22 100644 --- a/test/orm/inheritance/test_poly_loading.py +++ b/test/orm/inheritance/test_poly_loading.py @@ -339,13 +339,13 @@ class TestGeometries(GeometryFixtureBase): testing.db, q.all, CompiledSQL( - "SELECT a.type AS a_type, a.id AS a_id, " + "SELECT a.id AS a_id, a.type AS a_type, " "a.a_data AS a_a_data FROM a", {}, ), Or( CompiledSQL( - "SELECT a.type AS a_type, c.id AS c_id, a.id AS a_id, " + "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, " "c.c_data AS c_c_data, c.e_data AS c_e_data, " "c.d_data AS c_d_data " "FROM a JOIN c ON a.id = c.id " @@ -354,7 +354,7 @@ class TestGeometries(GeometryFixtureBase): [{"primary_keys": [1, 2]}], ), CompiledSQL( - "SELECT a.type AS a_type, c.id AS c_id, a.id AS a_id, " + "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, " "c.c_data AS c_c_data, " "c.d_data AS c_d_data, c.e_data AS c_e_data " "FROM a JOIN c ON a.id = c.id " @@ -396,13 +396,13 @@ class TestGeometries(GeometryFixtureBase): testing.db, q.all, CompiledSQL( - "SELECT a.type AS a_type, a.id AS a_id, " + "SELECT a.id AS a_id, a.type AS a_type, " "a.a_data AS a_a_data FROM a", {}, ), Or( CompiledSQL( - "SELECT a.type AS a_type, c.id AS c_id, a.id AS a_id, " + "SELECT a.id AS a_id, a.type AS a_type, c.id AS c_id, " "c.c_data AS c_c_data, c.e_data AS c_e_data, " "c.d_data AS c_d_data " "FROM a JOIN c ON a.id = c.id " @@ -411,7 +411,7 @@ class TestGeometries(GeometryFixtureBase): [{"primary_keys": [1, 2]}], ), CompiledSQL( - "SELECT a.type AS a_type, c.id AS c_id, a.id AS a_id, " + "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, " "c.c_data AS c_c_data, c.d_data AS c_d_data, " "c.e_data AS c_e_data " "FROM a JOIN c ON a.id = c.id " @@ -465,15 +465,15 @@ class TestGeometries(GeometryFixtureBase): testing.db, q.all, CompiledSQL( - "SELECT a.type AS a_type, a.id AS a_id, " + "SELECT a.id AS a_id, a.type AS a_type, " "a.a_data AS a_a_data FROM a ORDER BY a.id", {}, ), Or( # here, the test is that the adaptation of "a" takes place CompiledSQL( - "SELECT poly.a_type AS poly_a_type, " - "poly.c_id AS poly_c_id, " + "SELECT poly.c_id AS poly_c_id, " + "poly.a_type AS poly_a_type, " "poly.a_id AS poly_a_id, poly.c_c_data AS poly_c_c_data, " "poly.e_id AS poly_e_id, poly.e_e_data AS poly_e_e_data, " "poly.d_id AS poly_d_id, poly.d_d_data AS poly_d_d_data " @@ -489,9 +489,9 @@ class TestGeometries(GeometryFixtureBase): [{"primary_keys": [1, 2]}], ), CompiledSQL( - "SELECT poly.a_type AS poly_a_type, " - "poly.c_id AS poly_c_id, " - "poly.a_id AS poly_a_id, poly.c_c_data AS poly_c_c_data, " + "SELECT poly.c_id AS poly_c_id, " + "poly.a_id AS poly_a_id, poly.a_type AS poly_a_type, " + "poly.c_c_data AS poly_c_c_data, " "poly.d_id AS poly_d_id, poly.d_d_data AS poly_d_d_data, " "poly.e_id AS poly_e_id, poly.e_e_data AS poly_e_e_data " "FROM (SELECT a.id AS a_id, a.type AS a_type, " diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index 19e090e0e..f41947b6c 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -6,8 +6,8 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import update from sqlalchemy.orm import aliased +from sqlalchemy.orm import Composite from sqlalchemy.orm import composite -from sqlalchemy.orm import CompositeProperty from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import relationship from sqlalchemy.orm import Session @@ -1105,7 +1105,7 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): if custom: - class CustomComparator(sa.orm.CompositeProperty.Comparator): + class CustomComparator(sa.orm.Composite.Comparator): def near(self, other, d): clauses = self.__clause_element__().clauses diff_x = clauses[0] - other.x @@ -1163,7 +1163,7 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): Edge = self.classes.Edge start_prop = Edge.start.property - assert start_prop.comparator_factory is CompositeProperty.Comparator + assert start_prop.comparator_factory is Composite.Comparator def test_custom_comparator_factory(self): self._fixture(True) diff --git a/test/orm/test_dataclasses_py3k.py b/test/orm/test_dataclasses_py3k.py index 706024eb5..13cd6dec5 100644 --- a/test/orm/test_dataclasses_py3k.py +++ b/test/orm/test_dataclasses_py3k.py @@ -271,7 +271,9 @@ class PlainDeclarativeDataclassesTest(DataclassesTest): widgets: List[Widget] = dataclasses.field(default_factory=list) widget_count: int = dataclasses.field(init=False) - widgets = relationship("Widget") + __mapper_args__ = dict( + properties=dict(widgets=relationship("Widget")) + ) def __post_init__(self): self.widget_count = len(self.widgets) @@ -912,7 +914,7 @@ class PropagationFromMixinTest(fixtures.TestBase): eq_(BaseType.__table__.name, "basetype") eq_( list(BaseType.__table__.c.keys()), - ["timestamp", "type", "id", "value"], + ["type", "id", "value", "timestamp"], ) eq_(BaseType.__table__.kwargs, {"mysql_engine": "InnoDB"}) assert Single.__table__ is BaseType.__table__ diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index de211cf63..1fad974b9 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -25,6 +25,7 @@ from sqlalchemy.orm import Load from sqlalchemy.orm import load_only from sqlalchemy.orm import reconstructor from sqlalchemy.orm import registry +from sqlalchemy.orm import Relationship from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm import synonym @@ -2896,12 +2897,10 @@ class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): self.classes.User, ) - from sqlalchemy.orm.relationships import RelationshipProperty - # NOTE: this API changed in 0.8, previously __clause_element__() # gave the parent selecatable, now it gives the # primaryjoin/secondaryjoin - class MyFactory(RelationshipProperty.Comparator): + class MyFactory(Relationship.Comparator): __hash__ = None def __eq__(self, other): @@ -2909,7 +2908,7 @@ class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): self._source_selectable().c.user_id ) == func.foobar(other.id) - class MyFactory2(RelationshipProperty.Comparator): + class MyFactory2(Relationship.Comparator): __hash__ = None def __eq__(self, other): diff --git a/test/orm/test_options.py b/test/orm/test_options.py index e74ffeced..96759e388 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -930,7 +930,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): [Item], lambda: (load_only(Item.keywords),), 'Can\'t apply "column loader" strategy to property ' - '"Item.keywords", which is a "relationship property"; this ' + '"Item.keywords", which is a "relationship"; this ' 'loader strategy is intended to be used with a "column property".', ) @@ -942,7 +942,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): lambda: (joinedload(Keyword.id).joinedload(Item.keywords),), 'Can\'t apply "joined loader" strategy to property "Keyword.id", ' 'which is a "column property"; this loader strategy is intended ' - 'to be used with a "relationship property".', + 'to be used with a "relationship".', ) def test_option_against_wrong_multi_entity_type_attr_two(self): @@ -953,7 +953,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): lambda: (joinedload(Keyword.keywords).joinedload(Item.keywords),), 'Can\'t apply "joined loader" strategy to property ' '"Keyword.keywords", which is a "column property"; this loader ' - 'strategy is intended to be used with a "relationship property".', + 'strategy is intended to be used with a "relationship".', ) def test_option_against_wrong_multi_entity_type_attr_three(self): diff --git a/test/orm/test_query.py b/test/orm/test_query.py index e7fdf661a..d0c8f4108 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -50,6 +50,7 @@ from sqlalchemy.orm import column_property from sqlalchemy.orm import contains_eager from sqlalchemy.orm import defer from sqlalchemy.orm import deferred +from sqlalchemy.orm import join from sqlalchemy.orm import joinedload from sqlalchemy.orm import lazyload from sqlalchemy.orm import Query @@ -59,9 +60,8 @@ from sqlalchemy.orm import Session from sqlalchemy.orm import subqueryload from sqlalchemy.orm import synonym from sqlalchemy.orm import undefer +from sqlalchemy.orm import with_parent from sqlalchemy.orm.context import QueryContext -from sqlalchemy.orm.util import join -from sqlalchemy.orm.util import with_parent from sqlalchemy.sql import expression from sqlalchemy.sql import operators from sqlalchemy.testing import AssertsCompiledSQL diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index e02f0e2ed..092494165 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -2528,7 +2528,6 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): assert t2.c.x.references(t1.c.x) def test_create_drop_schema(self): - self.assert_compile( schema.CreateSchema("sa_schema"), "CREATE SCHEMA sa_schema" ) @@ -146,7 +146,6 @@ deps= importlib_metadata; python_version < '3.8' mypy patch==1.* - git+https://github.com/sqlalchemy/sqlalchemy2-stubs commands = pytest -m mypy {posargs} |