diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-07 12:37:23 -0400 |
|---|---|---|
| committer | mike bayer <mike_mp@zzzcomputing.com> | 2022-04-12 02:09:50 +0000 |
| commit | aa9cd878e8249a4a758c7f968e929e92fede42a5 (patch) | |
| tree | 1be1c9dc24dd247a150be55d65bfc56ebaf111bc /test | |
| parent | 98eae4e181cb2d1bbc67ec834bfad29dcba7f461 (diff) | |
| download | sqlalchemy-aa9cd878e8249a4a758c7f968e929e92fede42a5.tar.gz | |
pep-484: session, instancestate, etc
Also adds some fixes to annotation-based mapping
that have come up, as well as starts to add more
pep-484 test cases
Change-Id: Ia722bbbc7967a11b23b66c8084eb61df9d233fee
Diffstat (limited to 'test')
| -rw-r--r-- | test/base/test_utils.py | 47 | ||||
| -rw-r--r-- | test/ext/mypy/plain_files/session.py | 50 | ||||
| -rw-r--r-- | test/orm/declarative/test_tm_future_annotations.py | 49 | ||||
| -rw-r--r-- | test/orm/test_core_compilation.py | 6 | ||||
| -rw-r--r-- | test/orm/test_events.py | 81 | ||||
| -rw-r--r-- | test/orm/test_pickled.py | 68 | ||||
| -rw-r--r-- | test/orm/test_scoping.py | 3 |
7 files changed, 222 insertions, 82 deletions
diff --git a/test/base/test_utils.py b/test/base/test_utils.py index e22340da6..c5a47ddf9 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -27,6 +27,7 @@ from sqlalchemy.testing.util import gc_collect from sqlalchemy.testing.util import picklers from sqlalchemy.util import classproperty from sqlalchemy.util import compat +from sqlalchemy.util import FastIntFlag from sqlalchemy.util import get_callable_argspec from sqlalchemy.util import langhelpers from sqlalchemy.util import preloaded @@ -2300,6 +2301,20 @@ class SymbolTest(fixtures.TestBase): assert sym1 is not sym3 assert sym1 != sym3 + def test_fast_int_flag(self): + class Enum(FastIntFlag): + sym1 = 1 + sym2 = 2 + + sym3 = 3 + + assert Enum.sym1 is not Enum.sym3 + assert Enum.sym1 != Enum.sym3 + + assert Enum.sym1.name == "sym1" + + eq_(list(Enum), [Enum.sym1, Enum.sym2, Enum.sym3]) + def test_pickle(self): sym1 = util.symbol("foo") sym2 = util.symbol("foo") @@ -2338,17 +2353,19 @@ class SymbolTest(fixtures.TestBase): assert (sym1 | sym2) & (sym2 | sym4) def test_parser(self): - sym1 = util.symbol("sym1", canonical=1) - sym2 = util.symbol("sym2", canonical=2) - sym3 = util.symbol("sym3", canonical=4) - sym4 = util.symbol("sym4", canonical=8) + class MyEnum(FastIntFlag): + sym1 = 1 + sym2 = 2 + sym3 = 4 + sym4 = 8 + sym1, sym2, sym3, sym4 = tuple(MyEnum) lookup_one = {sym1: [], sym2: [True], sym3: [False], sym4: [None]} lookup_two = {sym1: [], sym2: [True], sym3: [False]} lookup_three = {sym1: [], sym2: ["symbol2"], sym3: []} is_( - util.symbol.parse_user_argument( + langhelpers.parse_user_argument_for_enum( "sym2", lookup_one, "some_name", resolve_symbol_names=True ), sym2, @@ -2357,35 +2374,41 @@ class SymbolTest(fixtures.TestBase): assert_raises_message( exc.ArgumentError, "Invalid value for 'some_name': 'sym2'", - util.symbol.parse_user_argument, + langhelpers.parse_user_argument_for_enum, "sym2", lookup_one, "some_name", ) is_( - util.symbol.parse_user_argument( + langhelpers.parse_user_argument_for_enum( True, lookup_one, "some_name", resolve_symbol_names=False ), sym2, ) is_( - util.symbol.parse_user_argument(sym2, lookup_one, "some_name"), + langhelpers.parse_user_argument_for_enum( + sym2, lookup_one, "some_name" + ), sym2, ) is_( - util.symbol.parse_user_argument(None, lookup_one, "some_name"), + langhelpers.parse_user_argument_for_enum( + None, lookup_one, "some_name" + ), sym4, ) is_( - util.symbol.parse_user_argument(None, lookup_two, "some_name"), + langhelpers.parse_user_argument_for_enum( + None, lookup_two, "some_name" + ), None, ) is_( - util.symbol.parse_user_argument( + langhelpers.parse_user_argument_for_enum( "symbol2", lookup_three, "some_name" ), sym2, @@ -2394,7 +2417,7 @@ class SymbolTest(fixtures.TestBase): assert_raises_message( exc.ArgumentError, "Invalid value for 'some_name': 'foo'", - util.symbol.parse_user_argument, + langhelpers.parse_user_argument_for_enum, "foo", lookup_three, "some_name", diff --git a/test/ext/mypy/plain_files/session.py b/test/ext/mypy/plain_files/session.py new file mode 100644 index 000000000..24c685e84 --- /dev/null +++ b/test/ext/mypy/plain_files/session.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from typing import List +from typing import Sequence + +from sqlalchemy import create_engine +from sqlalchemy import ForeignKey +from sqlalchemy import select +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + addresses: Mapped[List[Address]] = relationship(back_populates="user") + + +class Address(Base): + __tablename__ = "address" + + id: Mapped[int] = mapped_column(primary_key=True) + user_id = mapped_column(ForeignKey("user.id")) + email: Mapped[str] + + user: Mapped[User] = relationship(back_populates="addresses") + + +e = create_engine("sqlite://") +Base.metadata.create_all(e) + +with Session(e) as sess: + u1 = User(name="u1") + sess.add(u1) + sess.add_all([Address(user=u1, email="e1"), Address(user=u1, email="e2")]) + sess.commit() + +with Session(e) as sess: + users: Sequence[User] = sess.scalars( + select(User), execution_options={"stream_results": False} + ).all() diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py index c7022dc31..f8abd686a 100644 --- a/test/orm/declarative/test_tm_future_annotations.py +++ b/test/orm/declarative/test_tm_future_annotations.py @@ -1,9 +1,56 @@ from __future__ import annotations +from typing import List + +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship +from sqlalchemy.testing import is_ from .test_typed_mapping import MappedColumnTest # noqa -from .test_typed_mapping import RelationshipLHSTest # noqa +from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest """runs the annotation-sensitive tests from test_typed_mappings while having ``from __future__ import annotations`` in effect. """ + + +class RelationshipLHSTest(_RelationshipLHSTest): + def test_bidirectional_literal_annotations(self, decl_base): + """test the 'string cleanup' function in orm/util.py, where + we receive a string annotation like:: + + "Mapped[List[B]]" + + Which then fails to evaluate because we don't have "B" yet. + The annotation is converted on the fly to:: + + 'Mapped[List["B"]]' + + so that when we evaluated it, we get ``Mapped[List["B"]]`` and + can extract "B" as a string. + + """ + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + bs: Mapped[List[B]] = relationship(back_populates="a") + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + + a: Mapped[A] = relationship( + back_populates="bs", primaryjoin=a_id == A.id + ) + + a1 = A(data="data") + b1 = B() + a1.bs.append(b1) + is_(a1, b1.a) diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index d6d229f79..058e1735b 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -190,6 +190,7 @@ class SelectableTest(QueryTest, AssertsCompiledSQL): }, ], ), + argnames="cols, expected", ) def test_column_descriptions(self, cols, expected): User, Address = self.classes("User", "Address") @@ -211,8 +212,13 @@ class SelectableTest(QueryTest, AssertsCompiledSQL): ) stmt = select(*cols) + eq_(stmt.column_descriptions, expected) + if stmt._propagate_attrs: + stmt = select(*cols).from_statement(stmt) + eq_(stmt.column_descriptions, expected) + @testing.combinations(insert, update, delete, argnames="dml_construct") @testing.combinations( ( diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 79b20e285..4cecac0de 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -5,7 +5,9 @@ from unittest.mock import Mock import sqlalchemy as sa from sqlalchemy import delete from sqlalchemy import event +from sqlalchemy import exc as sa_exc from sqlalchemy import ForeignKey +from sqlalchemy import insert from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import literal_column @@ -42,6 +44,7 @@ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_not +from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column @@ -236,6 +239,84 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): ), ) + def test_override_parameters_executesingle(self): + User = self.classes.User + + sess = Session(testing.db, future=True) + + @event.listens_for(sess, "do_orm_execute") + def one(ctx): + return ctx.invoke_statement(params={"name": "overridden"}) + + orig_params = {"id": 18, "name": "original"} + with self.sql_execution_asserter() as asserter: + sess.execute(insert(User), orig_params) + asserter.assert_( + CompiledSQL( + "INSERT INTO users (id, name) VALUES (:id, :name)", + [{"id": 18, "name": "overridden"}], + ) + ) + # orig params weren't mutated + eq_(orig_params, {"id": 18, "name": "original"}) + + def test_override_parameters_executemany(self): + User = self.classes.User + + sess = Session(testing.db, future=True) + + @event.listens_for(sess, "do_orm_execute") + def one(ctx): + return ctx.invoke_statement( + params=[{"name": "overridden1"}, {"name": "overridden2"}] + ) + + orig_params = [ + {"id": 18, "name": "original1"}, + {"id": 19, "name": "original2"}, + ] + with self.sql_execution_asserter() as asserter: + sess.execute(insert(User), orig_params) + asserter.assert_( + CompiledSQL( + "INSERT INTO users (id, name) VALUES (:id, :name)", + [ + {"id": 18, "name": "overridden1"}, + {"id": 19, "name": "overridden2"}, + ], + ) + ) + # orig params weren't mutated + eq_( + orig_params, + [{"id": 18, "name": "original1"}, {"id": 19, "name": "original2"}], + ) + + def test_override_parameters_executemany_mismatch(self): + User = self.classes.User + + sess = Session(testing.db, future=True) + + @event.listens_for(sess, "do_orm_execute") + def one(ctx): + return ctx.invoke_statement( + params=[{"name": "overridden1"}, {"name": "overridden2"}] + ) + + orig_params = [ + {"id": 18, "name": "original1"}, + {"id": 19, "name": "original2"}, + {"id": 20, "name": "original3"}, + ] + with expect_raises_message( + sa_exc.InvalidRequestError, + r"Can't apply executemany parameters to statement; number " + r"of parameter sets passed to Session.execute\(\) \(3\) does " + r"not match number of parameter sets given to " + r"ORMExecuteState.invoke_statement\(\) \(2\)", + ): + sess.execute(insert(User), orig_params) + def test_chained_events_one(self): sess = Session(testing.db, future=True) diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py index a4250e375..c006babc8 100644 --- a/test/orm/test_pickled.py +++ b/test/orm/test_pickled.py @@ -11,7 +11,6 @@ from sqlalchemy.orm import aliased from sqlalchemy.orm import attributes from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import exc as orm_exc -from sqlalchemy.orm import instrumentation from sqlalchemy.orm import lazyload from sqlalchemy.orm import relationship from sqlalchemy.orm import state as sa_state @@ -410,73 +409,6 @@ class PickleTest(fixtures.MappedTest): u2 = loads(dumps(u1)) eq_(u1, u2) - def test_09_pickle(self): - users = self.tables.users - self.mapper_registry.map_imperatively(User, users) - sess = fixture_session() - sess.add(User(id=1, name="ed")) - sess.commit() - sess.close() - - inst = User(id=1, name="ed") - del inst._sa_instance_state - - state = sa_state.InstanceState.__new__(sa_state.InstanceState) - state_09 = { - "class_": User, - "modified": False, - "committed_state": {}, - "instance": inst, - "callables": {"name": state, "id": state}, - "key": (User, (1,)), - "expired": True, - } - manager = instrumentation._SerializeManager.__new__( - instrumentation._SerializeManager - ) - manager.class_ = User - state_09["manager"] = manager - state.__setstate__(state_09) - eq_(state.expired_attributes, {"name", "id"}) - - sess = fixture_session() - sess.add(inst) - eq_(inst.name, "ed") - # test identity_token expansion - eq_(sa.inspect(inst).key, (User, (1,), None)) - - def test_11_pickle(self): - users = self.tables.users - self.mapper_registry.map_imperatively(User, users) - sess = fixture_session() - u1 = User(id=1, name="ed") - sess.add(u1) - sess.commit() - - sess.close() - - manager = instrumentation._SerializeManager.__new__( - instrumentation._SerializeManager - ) - manager.class_ = User - - state_11 = { - "class_": User, - "modified": False, - "committed_state": {}, - "instance": u1, - "manager": manager, - "key": (User, (1,)), - "expired_attributes": set(), - "expired": True, - } - - state = sa_state.InstanceState.__new__(sa_state.InstanceState) - state.__setstate__(state_11) - - eq_(state.identity_token, None) - eq_(state.identity_key, (User, (1,), None)) - def test_state_info_pickle(self): users = self.tables.users self.mapper_registry.map_imperatively(User, users) diff --git a/test/orm/test_scoping.py b/test/orm/test_scoping.py index f2d7d8569..33e66d52f 100644 --- a/test/orm/test_scoping.py +++ b/test/orm/test_scoping.py @@ -5,6 +5,7 @@ from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy import util from sqlalchemy.orm import query from sqlalchemy.orm import relationship from sqlalchemy.orm import scoped_session @@ -158,7 +159,7 @@ class ScopedSessionTest(fixtures.MappedTest): populate_existing=False, with_for_update=None, identity_token=None, - execution_options=None, + execution_options=util.EMPTY_DICT, ), ], ) |
