summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-07 12:37:23 -0400
committermike bayer <mike_mp@zzzcomputing.com>2022-04-12 02:09:50 +0000
commitaa9cd878e8249a4a758c7f968e929e92fede42a5 (patch)
tree1be1c9dc24dd247a150be55d65bfc56ebaf111bc /test
parent98eae4e181cb2d1bbc67ec834bfad29dcba7f461 (diff)
downloadsqlalchemy-aa9cd878e8249a4a758c7f968e929e92fede42a5.tar.gz
pep-484: session, instancestate, etc
Also adds some fixes to annotation-based mapping that have come up, as well as starts to add more pep-484 test cases Change-Id: Ia722bbbc7967a11b23b66c8084eb61df9d233fee
Diffstat (limited to 'test')
-rw-r--r--test/base/test_utils.py47
-rw-r--r--test/ext/mypy/plain_files/session.py50
-rw-r--r--test/orm/declarative/test_tm_future_annotations.py49
-rw-r--r--test/orm/test_core_compilation.py6
-rw-r--r--test/orm/test_events.py81
-rw-r--r--test/orm/test_pickled.py68
-rw-r--r--test/orm/test_scoping.py3
7 files changed, 222 insertions, 82 deletions
diff --git a/test/base/test_utils.py b/test/base/test_utils.py
index e22340da6..c5a47ddf9 100644
--- a/test/base/test_utils.py
+++ b/test/base/test_utils.py
@@ -27,6 +27,7 @@ from sqlalchemy.testing.util import gc_collect
from sqlalchemy.testing.util import picklers
from sqlalchemy.util import classproperty
from sqlalchemy.util import compat
+from sqlalchemy.util import FastIntFlag
from sqlalchemy.util import get_callable_argspec
from sqlalchemy.util import langhelpers
from sqlalchemy.util import preloaded
@@ -2300,6 +2301,20 @@ class SymbolTest(fixtures.TestBase):
assert sym1 is not sym3
assert sym1 != sym3
+ def test_fast_int_flag(self):
+ class Enum(FastIntFlag):
+ sym1 = 1
+ sym2 = 2
+
+ sym3 = 3
+
+ assert Enum.sym1 is not Enum.sym3
+ assert Enum.sym1 != Enum.sym3
+
+ assert Enum.sym1.name == "sym1"
+
+ eq_(list(Enum), [Enum.sym1, Enum.sym2, Enum.sym3])
+
def test_pickle(self):
sym1 = util.symbol("foo")
sym2 = util.symbol("foo")
@@ -2338,17 +2353,19 @@ class SymbolTest(fixtures.TestBase):
assert (sym1 | sym2) & (sym2 | sym4)
def test_parser(self):
- sym1 = util.symbol("sym1", canonical=1)
- sym2 = util.symbol("sym2", canonical=2)
- sym3 = util.symbol("sym3", canonical=4)
- sym4 = util.symbol("sym4", canonical=8)
+ class MyEnum(FastIntFlag):
+ sym1 = 1
+ sym2 = 2
+ sym3 = 4
+ sym4 = 8
+ sym1, sym2, sym3, sym4 = tuple(MyEnum)
lookup_one = {sym1: [], sym2: [True], sym3: [False], sym4: [None]}
lookup_two = {sym1: [], sym2: [True], sym3: [False]}
lookup_three = {sym1: [], sym2: ["symbol2"], sym3: []}
is_(
- util.symbol.parse_user_argument(
+ langhelpers.parse_user_argument_for_enum(
"sym2", lookup_one, "some_name", resolve_symbol_names=True
),
sym2,
@@ -2357,35 +2374,41 @@ class SymbolTest(fixtures.TestBase):
assert_raises_message(
exc.ArgumentError,
"Invalid value for 'some_name': 'sym2'",
- util.symbol.parse_user_argument,
+ langhelpers.parse_user_argument_for_enum,
"sym2",
lookup_one,
"some_name",
)
is_(
- util.symbol.parse_user_argument(
+ langhelpers.parse_user_argument_for_enum(
True, lookup_one, "some_name", resolve_symbol_names=False
),
sym2,
)
is_(
- util.symbol.parse_user_argument(sym2, lookup_one, "some_name"),
+ langhelpers.parse_user_argument_for_enum(
+ sym2, lookup_one, "some_name"
+ ),
sym2,
)
is_(
- util.symbol.parse_user_argument(None, lookup_one, "some_name"),
+ langhelpers.parse_user_argument_for_enum(
+ None, lookup_one, "some_name"
+ ),
sym4,
)
is_(
- util.symbol.parse_user_argument(None, lookup_two, "some_name"),
+ langhelpers.parse_user_argument_for_enum(
+ None, lookup_two, "some_name"
+ ),
None,
)
is_(
- util.symbol.parse_user_argument(
+ langhelpers.parse_user_argument_for_enum(
"symbol2", lookup_three, "some_name"
),
sym2,
@@ -2394,7 +2417,7 @@ class SymbolTest(fixtures.TestBase):
assert_raises_message(
exc.ArgumentError,
"Invalid value for 'some_name': 'foo'",
- util.symbol.parse_user_argument,
+ langhelpers.parse_user_argument_for_enum,
"foo",
lookup_three,
"some_name",
diff --git a/test/ext/mypy/plain_files/session.py b/test/ext/mypy/plain_files/session.py
new file mode 100644
index 000000000..24c685e84
--- /dev/null
+++ b/test/ext/mypy/plain_files/session.py
@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+from typing import List
+from typing import Sequence
+
+from sqlalchemy import create_engine
+from sqlalchemy import ForeignKey
+from sqlalchemy import select
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
+
+
+class Base(DeclarativeBase):
+ pass
+
+
+class User(Base):
+ __tablename__ = "user"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ name: Mapped[str]
+ addresses: Mapped[List[Address]] = relationship(back_populates="user")
+
+
+class Address(Base):
+ __tablename__ = "address"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ user_id = mapped_column(ForeignKey("user.id"))
+ email: Mapped[str]
+
+ user: Mapped[User] = relationship(back_populates="addresses")
+
+
+e = create_engine("sqlite://")
+Base.metadata.create_all(e)
+
+with Session(e) as sess:
+ u1 = User(name="u1")
+ sess.add(u1)
+ sess.add_all([Address(user=u1, email="e1"), Address(user=u1, email="e2")])
+ sess.commit()
+
+with Session(e) as sess:
+ users: Sequence[User] = sess.scalars(
+ select(User), execution_options={"stream_results": False}
+ ).all()
diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py
index c7022dc31..f8abd686a 100644
--- a/test/orm/declarative/test_tm_future_annotations.py
+++ b/test/orm/declarative/test_tm_future_annotations.py
@@ -1,9 +1,56 @@
from __future__ import annotations
+from typing import List
+
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+from sqlalchemy.testing import is_
from .test_typed_mapping import MappedColumnTest # noqa
-from .test_typed_mapping import RelationshipLHSTest # noqa
+from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest
"""runs the annotation-sensitive tests from test_typed_mappings while
having ``from __future__ import annotations`` in effect.
"""
+
+
+class RelationshipLHSTest(_RelationshipLHSTest):
+ def test_bidirectional_literal_annotations(self, decl_base):
+ """test the 'string cleanup' function in orm/util.py, where
+ we receive a string annotation like::
+
+ "Mapped[List[B]]"
+
+ Which then fails to evaluate because we don't have "B" yet.
+ The annotation is converted on the fly to::
+
+ 'Mapped[List["B"]]'
+
+ so that when we evaluated it, we get ``Mapped[List["B"]]`` and
+ can extract "B" as a string.
+
+ """
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[str] = mapped_column()
+ bs: Mapped[List[B]] = relationship(back_populates="a")
+
+ class B(decl_base):
+ __tablename__ = "b"
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+
+ a: Mapped[A] = relationship(
+ back_populates="bs", primaryjoin=a_id == A.id
+ )
+
+ a1 = A(data="data")
+ b1 = B()
+ a1.bs.append(b1)
+ is_(a1, b1.a)
diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py
index d6d229f79..058e1735b 100644
--- a/test/orm/test_core_compilation.py
+++ b/test/orm/test_core_compilation.py
@@ -190,6 +190,7 @@ class SelectableTest(QueryTest, AssertsCompiledSQL):
},
],
),
+ argnames="cols, expected",
)
def test_column_descriptions(self, cols, expected):
User, Address = self.classes("User", "Address")
@@ -211,8 +212,13 @@ class SelectableTest(QueryTest, AssertsCompiledSQL):
)
stmt = select(*cols)
+
eq_(stmt.column_descriptions, expected)
+ if stmt._propagate_attrs:
+ stmt = select(*cols).from_statement(stmt)
+ eq_(stmt.column_descriptions, expected)
+
@testing.combinations(insert, update, delete, argnames="dml_construct")
@testing.combinations(
(
diff --git a/test/orm/test_events.py b/test/orm/test_events.py
index 79b20e285..4cecac0de 100644
--- a/test/orm/test_events.py
+++ b/test/orm/test_events.py
@@ -5,7 +5,9 @@ from unittest.mock import Mock
import sqlalchemy as sa
from sqlalchemy import delete
from sqlalchemy import event
+from sqlalchemy import exc as sa_exc
from sqlalchemy import ForeignKey
+from sqlalchemy import insert
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import literal_column
@@ -42,6 +44,7 @@ from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_not
+from sqlalchemy.testing.assertions import expect_raises_message
from sqlalchemy.testing.assertsql import CompiledSQL
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
@@ -236,6 +239,84 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
),
)
+ def test_override_parameters_executesingle(self):
+ User = self.classes.User
+
+ sess = Session(testing.db, future=True)
+
+ @event.listens_for(sess, "do_orm_execute")
+ def one(ctx):
+ return ctx.invoke_statement(params={"name": "overridden"})
+
+ orig_params = {"id": 18, "name": "original"}
+ with self.sql_execution_asserter() as asserter:
+ sess.execute(insert(User), orig_params)
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO users (id, name) VALUES (:id, :name)",
+ [{"id": 18, "name": "overridden"}],
+ )
+ )
+ # orig params weren't mutated
+ eq_(orig_params, {"id": 18, "name": "original"})
+
+ def test_override_parameters_executemany(self):
+ User = self.classes.User
+
+ sess = Session(testing.db, future=True)
+
+ @event.listens_for(sess, "do_orm_execute")
+ def one(ctx):
+ return ctx.invoke_statement(
+ params=[{"name": "overridden1"}, {"name": "overridden2"}]
+ )
+
+ orig_params = [
+ {"id": 18, "name": "original1"},
+ {"id": 19, "name": "original2"},
+ ]
+ with self.sql_execution_asserter() as asserter:
+ sess.execute(insert(User), orig_params)
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO users (id, name) VALUES (:id, :name)",
+ [
+ {"id": 18, "name": "overridden1"},
+ {"id": 19, "name": "overridden2"},
+ ],
+ )
+ )
+ # orig params weren't mutated
+ eq_(
+ orig_params,
+ [{"id": 18, "name": "original1"}, {"id": 19, "name": "original2"}],
+ )
+
+ def test_override_parameters_executemany_mismatch(self):
+ User = self.classes.User
+
+ sess = Session(testing.db, future=True)
+
+ @event.listens_for(sess, "do_orm_execute")
+ def one(ctx):
+ return ctx.invoke_statement(
+ params=[{"name": "overridden1"}, {"name": "overridden2"}]
+ )
+
+ orig_params = [
+ {"id": 18, "name": "original1"},
+ {"id": 19, "name": "original2"},
+ {"id": 20, "name": "original3"},
+ ]
+ with expect_raises_message(
+ sa_exc.InvalidRequestError,
+ r"Can't apply executemany parameters to statement; number "
+ r"of parameter sets passed to Session.execute\(\) \(3\) does "
+ r"not match number of parameter sets given to "
+ r"ORMExecuteState.invoke_statement\(\) \(2\)",
+ ):
+ sess.execute(insert(User), orig_params)
+
def test_chained_events_one(self):
sess = Session(testing.db, future=True)
diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py
index a4250e375..c006babc8 100644
--- a/test/orm/test_pickled.py
+++ b/test/orm/test_pickled.py
@@ -11,7 +11,6 @@ from sqlalchemy.orm import aliased
from sqlalchemy.orm import attributes
from sqlalchemy.orm import clear_mappers
from sqlalchemy.orm import exc as orm_exc
-from sqlalchemy.orm import instrumentation
from sqlalchemy.orm import lazyload
from sqlalchemy.orm import relationship
from sqlalchemy.orm import state as sa_state
@@ -410,73 +409,6 @@ class PickleTest(fixtures.MappedTest):
u2 = loads(dumps(u1))
eq_(u1, u2)
- def test_09_pickle(self):
- users = self.tables.users
- self.mapper_registry.map_imperatively(User, users)
- sess = fixture_session()
- sess.add(User(id=1, name="ed"))
- sess.commit()
- sess.close()
-
- inst = User(id=1, name="ed")
- del inst._sa_instance_state
-
- state = sa_state.InstanceState.__new__(sa_state.InstanceState)
- state_09 = {
- "class_": User,
- "modified": False,
- "committed_state": {},
- "instance": inst,
- "callables": {"name": state, "id": state},
- "key": (User, (1,)),
- "expired": True,
- }
- manager = instrumentation._SerializeManager.__new__(
- instrumentation._SerializeManager
- )
- manager.class_ = User
- state_09["manager"] = manager
- state.__setstate__(state_09)
- eq_(state.expired_attributes, {"name", "id"})
-
- sess = fixture_session()
- sess.add(inst)
- eq_(inst.name, "ed")
- # test identity_token expansion
- eq_(sa.inspect(inst).key, (User, (1,), None))
-
- def test_11_pickle(self):
- users = self.tables.users
- self.mapper_registry.map_imperatively(User, users)
- sess = fixture_session()
- u1 = User(id=1, name="ed")
- sess.add(u1)
- sess.commit()
-
- sess.close()
-
- manager = instrumentation._SerializeManager.__new__(
- instrumentation._SerializeManager
- )
- manager.class_ = User
-
- state_11 = {
- "class_": User,
- "modified": False,
- "committed_state": {},
- "instance": u1,
- "manager": manager,
- "key": (User, (1,)),
- "expired_attributes": set(),
- "expired": True,
- }
-
- state = sa_state.InstanceState.__new__(sa_state.InstanceState)
- state.__setstate__(state_11)
-
- eq_(state.identity_token, None)
- eq_(state.identity_key, (User, (1,), None))
-
def test_state_info_pickle(self):
users = self.tables.users
self.mapper_registry.map_imperatively(User, users)
diff --git a/test/orm/test_scoping.py b/test/orm/test_scoping.py
index f2d7d8569..33e66d52f 100644
--- a/test/orm/test_scoping.py
+++ b/test/orm/test_scoping.py
@@ -5,6 +5,7 @@ from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
+from sqlalchemy import util
from sqlalchemy.orm import query
from sqlalchemy.orm import relationship
from sqlalchemy.orm import scoped_session
@@ -158,7 +159,7 @@ class ScopedSessionTest(fixtures.MappedTest):
populate_existing=False,
with_for_update=None,
identity_token=None,
- execution_options=None,
+ execution_options=util.EMPTY_DICT,
),
],
)