summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/base/test_concurrency_py3k.py103
-rwxr-xr-xtest/conftest.py5
-rw-r--r--test/dialect/postgresql/test_dialect.py4
-rw-r--r--test/dialect/postgresql/test_query.py7
-rw-r--r--test/dialect/postgresql/test_types.py17
-rw-r--r--test/engine/test_logging.py4
-rw-r--r--test/engine/test_reconnect.py16
-rw-r--r--test/engine/test_transaction.py7
-rw-r--r--test/ext/asyncio/__init__.py0
-rw-r--r--test/ext/asyncio/test_engine_py3k.py340
-rw-r--r--test/ext/asyncio/test_session_py3k.py200
-rw-r--r--test/orm/test_update_delete.py1
-rw-r--r--test/requirements.py22
-rw-r--r--test/sql/test_defaults.py2
14 files changed, 691 insertions, 37 deletions
diff --git a/test/base/test_concurrency_py3k.py b/test/base/test_concurrency_py3k.py
new file mode 100644
index 000000000..10b89291e
--- /dev/null
+++ b/test/base/test_concurrency_py3k.py
@@ -0,0 +1,103 @@
+from sqlalchemy import exc
+from sqlalchemy.testing import async_test
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
+from sqlalchemy.util import await_fallback
+from sqlalchemy.util import await_only
+from sqlalchemy.util import greenlet_spawn
+
+
+async def run1():
+ return 1
+
+
+async def run2():
+ return 2
+
+
+def go(*fns):
+ return sum(await_only(fn()) for fn in fns)
+
+
+class TestAsyncioCompat(fixtures.TestBase):
+ @async_test
+ async def test_ok(self):
+
+ eq_(await greenlet_spawn(go, run1, run2), 3)
+
+ @async_test
+ async def test_async_error(self):
+ async def err():
+ raise ValueError("an error")
+
+ with expect_raises_message(ValueError, "an error"):
+ await greenlet_spawn(go, run1, err)
+
+ @async_test
+ async def test_sync_error(self):
+ def go():
+ await_only(run1())
+ raise ValueError("sync error")
+
+ with expect_raises_message(ValueError, "sync error"):
+ await greenlet_spawn(go)
+
+ def test_await_fallback_no_greenlet(self):
+ to_await = run1()
+ await_fallback(to_await)
+
+ def test_await_only_no_greenlet(self):
+ to_await = run1()
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ r"greenlet_spawn has not been called; can't call await_\(\) here.",
+ ):
+ await_only(to_await)
+
+ # ensure no warning
+ await_fallback(to_await)
+
+ @async_test
+ async def test_await_fallback_error(self):
+ to_await = run1()
+
+ await to_await
+
+ async def inner_await():
+ nonlocal to_await
+ to_await = run1()
+ await_fallback(to_await)
+
+ def go():
+ await_fallback(inner_await())
+
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "greenlet_spawn has not been called and asyncio event loop",
+ ):
+ await greenlet_spawn(go)
+
+ await to_await
+
+ @async_test
+ async def test_await_only_error(self):
+ to_await = run1()
+
+ await to_await
+
+ async def inner_await():
+ nonlocal to_await
+ to_await = run1()
+ await_only(to_await)
+
+ def go():
+ await_only(inner_await())
+
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ r"greenlet_spawn has not been called; can't call await_\(\) here.",
+ ):
+ await greenlet_spawn(go)
+
+ await to_await
diff --git a/test/conftest.py b/test/conftest.py
index 5c6b89fde..92d3e0776 100755
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -11,6 +11,11 @@ import sys
import pytest
+
+collect_ignore_glob = []
+if sys.version_info[0] < 3:
+ collect_ignore_glob.append("*_py3k.py")
+
pytest.register_assert_rewrite("sqlalchemy.testing.assertions")
diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py
index f6aba550e..57c243442 100644
--- a/test/dialect/postgresql/test_dialect.py
+++ b/test/dialect/postgresql/test_dialect.py
@@ -937,9 +937,7 @@ $$ LANGUAGE plpgsql;
stmt = text("select cast('hi' as char) as hi").columns(hi=Numeric)
assert_raises(exc.InvalidRequestError, connection.execute, stmt)
- @testing.only_if(
- "postgresql >= 8.2", "requires standard_conforming_strings"
- )
+ @testing.only_on("postgresql+psycopg2")
def test_serial_integer(self):
class BITD(TypeDecorator):
impl = Integer
diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py
index ffd32813c..5ab65f9e3 100644
--- a/test/dialect/postgresql/test_query.py
+++ b/test/dialect/postgresql/test_query.py
@@ -738,17 +738,14 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
def teardown_class(cls):
metadata.drop_all()
- @testing.fails_on("postgresql+pg8000", "uses positional")
+ @testing.requires.pyformat_paramstyle
def test_expression_pyformat(self):
self.assert_compile(
matchtable.c.title.match("somstr"),
"matchtable.title @@ to_tsquery(%(title_1)s" ")",
)
- @testing.fails_on("postgresql+psycopg2", "uses pyformat")
- @testing.fails_on("postgresql+pypostgresql", "uses pyformat")
- @testing.fails_on("postgresql+pygresql", "uses pyformat")
- @testing.fails_on("postgresql+psycopg2cffi", "uses pyformat")
+ @testing.requires.format_paramstyle
def test_expression_positional(self):
self.assert_compile(
matchtable.c.title.match("somstr"),
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 95486b197..503477833 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -27,6 +27,7 @@ from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import Text
from sqlalchemy import text
+from sqlalchemy import type_coerce
from sqlalchemy import TypeDecorator
from sqlalchemy import types
from sqlalchemy import Unicode
@@ -774,7 +775,12 @@ class RegClassTest(fixtures.TestBase):
regclass = cast("pg_class", postgresql.REGCLASS)
oid = self._scalar(cast(regclass, postgresql.OID))
assert isinstance(oid, int)
- eq_(self._scalar(cast(oid, postgresql.REGCLASS)), "pg_class")
+ eq_(
+ self._scalar(
+ cast(type_coerce(oid, postgresql.OID), postgresql.REGCLASS)
+ ),
+ "pg_class",
+ )
def test_cast_whereclause(self):
pga = Table(
@@ -1801,10 +1807,13 @@ class ArrayEnum(fixtures.TestBase):
testing.db,
)
+ @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
@testing.combinations(
- sqltypes.ARRAY, postgresql.ARRAY, _ArrayOfEnum, argnames="array_cls"
+ sqltypes.ARRAY,
+ postgresql.ARRAY,
+ (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
+ argnames="array_cls",
)
- @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
@testing.provide_metadata
def test_array_of_enums(self, array_cls, enum_cls, connection):
tbl = Table(
@@ -1845,6 +1854,8 @@ class ArrayEnum(fixtures.TestBase):
sel = select(tbl.c.pyenum_col).order_by(tbl.c.id.desc())
eq_(connection.scalar(sel), [MyEnum.a])
+ self.metadata.drop_all(connection)
+
class ArrayJSON(fixtures.TestBase):
__backend__ = True
diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py
index af6bc1d36..624fa9005 100644
--- a/test/engine/test_logging.py
+++ b/test/engine/test_logging.py
@@ -10,8 +10,8 @@ from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import util
from sqlalchemy.sql import util as sql_util
+from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
-from sqlalchemy.testing import assert_raises_return
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import eq_regex
@@ -104,7 +104,7 @@ class LogParamsTest(fixtures.TestBase):
def test_log_positional_array(self):
with self.eng.connect() as conn:
- exc_info = assert_raises_return(
+ exc_info = assert_raises(
tsa.exc.DBAPIError,
conn.execute,
tsa.text("SELECT * FROM foo WHERE id IN :foo AND bar=:bar"),
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py
index d91105f41..48eb485cb 100644
--- a/test/engine/test_reconnect.py
+++ b/test/engine/test_reconnect.py
@@ -1356,7 +1356,14 @@ class InvalidateDuringResultTest(fixtures.TestBase):
"cx_oracle 6 doesn't allow a close like this due to open cursors",
)
@testing.fails_if(
- ["+mysqlconnector", "+mysqldb", "+cymysql", "+pymysql", "+pg8000"],
+ [
+ "+mysqlconnector",
+ "+mysqldb",
+ "+cymysql",
+ "+pymysql",
+ "+pg8000",
+ "+asyncpg",
+ ],
"Buffers the result set and doesn't check for connection close",
)
def test_invalidate_on_results(self):
@@ -1365,5 +1372,8 @@ class InvalidateDuringResultTest(fixtures.TestBase):
for x in range(20):
result.fetchone()
self.engine.test_shutdown()
- _assert_invalidated(result.fetchone)
- assert conn.invalidated
+ try:
+ _assert_invalidated(result.fetchone)
+ assert conn.invalidated
+ finally:
+ conn.invalidate()
diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py
index 8981028d2..cd144e45f 100644
--- a/test/engine/test_transaction.py
+++ b/test/engine/test_transaction.py
@@ -461,11 +461,8 @@ class TransactionTest(fixtures.TestBase):
assert not savepoint.is_active
if util.py3k:
- # driver error
- assert exc_.__cause__
-
- # and that's it, no other context
- assert not exc_.__cause__.__context__
+ # ensure cause comes from the DBAPI
+ assert isinstance(exc_.__cause__, testing.db.dialect.dbapi.Error)
def test_retains_through_options(self, local_connection):
connection = local_connection
diff --git a/test/ext/asyncio/__init__.py b/test/ext/asyncio/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/test/ext/asyncio/__init__.py
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py
new file mode 100644
index 000000000..ec513cb64
--- /dev/null
+++ b/test/ext/asyncio/test_engine_py3k.py
@@ -0,0 +1,340 @@
+from sqlalchemy import Column
+from sqlalchemy import delete
+from sqlalchemy import exc
+from sqlalchemy import func
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import testing
+from sqlalchemy import union_all
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.asyncio import exc as asyncio_exc
+from sqlalchemy.testing import async_test
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.asyncio import assert_raises_message_async
+
+
+class EngineFixture(fixtures.TablesTest):
+ __requires__ = ("async_dialect",)
+
+ @testing.fixture
+ def async_engine(self):
+ return create_async_engine(testing.db.url)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "users",
+ metadata,
+ Column("user_id", Integer, primary_key=True, autoincrement=False),
+ Column("user_name", String(20)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ users = cls.tables.users
+ with connection.begin():
+ connection.execute(
+ users.insert(),
+ [
+ {"user_id": i, "user_name": "name%d" % i}
+ for i in range(1, 20)
+ ],
+ )
+
+
+class AsyncEngineTest(EngineFixture):
+ __backend__ = True
+
+ @async_test
+ async def test_connect_ctxmanager(self, async_engine):
+ async with async_engine.connect() as conn:
+ result = await conn.execute(select(1))
+ eq_(result.scalar(), 1)
+
+ @async_test
+ async def test_connect_plain(self, async_engine):
+ conn = await async_engine.connect()
+ try:
+ result = await conn.execute(select(1))
+ eq_(result.scalar(), 1)
+ finally:
+ await conn.close()
+
+ @async_test
+ async def test_connection_not_started(self, async_engine):
+
+ conn = async_engine.connect()
+ testing.assert_raises_message(
+ asyncio_exc.AsyncContextNotStarted,
+ "AsyncConnection context has not been started and "
+ "object has not been awaited.",
+ conn.begin,
+ )
+
+ @async_test
+ async def test_transaction_commit(self, async_engine):
+ users = self.tables.users
+
+ async with async_engine.begin() as conn:
+ await conn.execute(delete(users))
+
+ async with async_engine.connect() as conn:
+ eq_(await conn.scalar(select(func.count(users.c.user_id))), 0)
+
+ @async_test
+ async def test_savepoint_rollback_noctx(self, async_engine):
+ users = self.tables.users
+
+ async with async_engine.begin() as conn:
+
+ savepoint = await conn.begin_nested()
+ await conn.execute(delete(users))
+ await savepoint.rollback()
+
+ async with async_engine.connect() as conn:
+ eq_(await conn.scalar(select(func.count(users.c.user_id))), 19)
+
+ @async_test
+ async def test_savepoint_commit_noctx(self, async_engine):
+ users = self.tables.users
+
+ async with async_engine.begin() as conn:
+
+ savepoint = await conn.begin_nested()
+ await conn.execute(delete(users))
+ await savepoint.commit()
+
+ async with async_engine.connect() as conn:
+ eq_(await conn.scalar(select(func.count(users.c.user_id))), 0)
+
+ @async_test
+ async def test_transaction_rollback(self, async_engine):
+ users = self.tables.users
+
+ async with async_engine.connect() as conn:
+ trans = conn.begin()
+ await trans.start()
+ await conn.execute(delete(users))
+ await trans.rollback()
+
+ async with async_engine.connect() as conn:
+ eq_(await conn.scalar(select(func.count(users.c.user_id))), 19)
+
+ @async_test
+ async def test_conn_transaction_not_started(self, async_engine):
+
+ async with async_engine.connect() as conn:
+ trans = conn.begin()
+ await assert_raises_message_async(
+ asyncio_exc.AsyncContextNotStarted,
+ "AsyncTransaction context has not been started "
+ "and object has not been awaited.",
+ trans.rollback(),
+ )
+
+
+class AsyncResultTest(EngineFixture):
+ @testing.combinations(
+ (None,), ("scalars",), ("mappings",), argnames="filter_"
+ )
+ @async_test
+ async def test_all(self, async_engine, filter_):
+ users = self.tables.users
+ async with async_engine.connect() as conn:
+ result = await conn.stream(select(users))
+
+ if filter_ == "mappings":
+ result = result.mappings()
+ elif filter_ == "scalars":
+ result = result.scalars(1)
+
+ all_ = await result.all()
+ if filter_ == "mappings":
+ eq_(
+ all_,
+ [
+ {"user_id": i, "user_name": "name%d" % i}
+ for i in range(1, 20)
+ ],
+ )
+ elif filter_ == "scalars":
+ eq_(
+ all_, ["name%d" % i for i in range(1, 20)],
+ )
+ else:
+ eq_(all_, [(i, "name%d" % i) for i in range(1, 20)])
+
+ @testing.combinations(
+ (None,), ("scalars",), ("mappings",), argnames="filter_"
+ )
+ @async_test
+ async def test_aiter(self, async_engine, filter_):
+ users = self.tables.users
+ async with async_engine.connect() as conn:
+ result = await conn.stream(select(users))
+
+ if filter_ == "mappings":
+ result = result.mappings()
+ elif filter_ == "scalars":
+ result = result.scalars(1)
+
+ rows = []
+
+ async for row in result:
+ rows.append(row)
+
+ if filter_ == "mappings":
+ eq_(
+ rows,
+ [
+ {"user_id": i, "user_name": "name%d" % i}
+ for i in range(1, 20)
+ ],
+ )
+ elif filter_ == "scalars":
+ eq_(
+ rows, ["name%d" % i for i in range(1, 20)],
+ )
+ else:
+ eq_(rows, [(i, "name%d" % i) for i in range(1, 20)])
+
+ @testing.combinations((None,), ("mappings",), argnames="filter_")
+ @async_test
+ async def test_keys(self, async_engine, filter_):
+ users = self.tables.users
+ async with async_engine.connect() as conn:
+ result = await conn.stream(select(users))
+
+ if filter_ == "mappings":
+ result = result.mappings()
+
+ eq_(result.keys(), ["user_id", "user_name"])
+
+ @async_test
+ async def test_unique_all(self, async_engine):
+ users = self.tables.users
+ async with async_engine.connect() as conn:
+ result = await conn.stream(
+ union_all(select(users), select(users)).order_by(
+ users.c.user_id
+ )
+ )
+
+ all_ = await result.unique().all()
+ eq_(all_, [(i, "name%d" % i) for i in range(1, 20)])
+
+ @async_test
+ async def test_columns_all(self, async_engine):
+ users = self.tables.users
+ async with async_engine.connect() as conn:
+ result = await conn.stream(select(users))
+
+ all_ = await result.columns(1).all()
+ eq_(all_, [("name%d" % i,) for i in range(1, 20)])
+
+ @testing.combinations(
+ (None,), ("scalars",), ("mappings",), argnames="filter_"
+ )
+ @async_test
+ async def test_partitions(self, async_engine, filter_):
+ users = self.tables.users
+ async with async_engine.connect() as conn:
+ result = await conn.stream(select(users))
+
+ if filter_ == "mappings":
+ result = result.mappings()
+ elif filter_ == "scalars":
+ result = result.scalars(1)
+
+ check_result = []
+ async for partition in result.partitions(5):
+ check_result.append(partition)
+
+ if filter_ == "mappings":
+ eq_(
+ check_result,
+ [
+ [
+ {"user_id": i, "user_name": "name%d" % i}
+ for i in range(a, b)
+ ]
+ for (a, b) in [(1, 6), (6, 11), (11, 16), (16, 20)]
+ ],
+ )
+ elif filter_ == "scalars":
+ eq_(
+ check_result,
+ [
+ ["name%d" % i for i in range(a, b)]
+ for (a, b) in [(1, 6), (6, 11), (11, 16), (16, 20)]
+ ],
+ )
+ else:
+ eq_(
+ check_result,
+ [
+ [(i, "name%d" % i) for i in range(a, b)]
+ for (a, b) in [(1, 6), (6, 11), (11, 16), (16, 20)]
+ ],
+ )
+
+ @testing.combinations(
+ (None,), ("scalars",), ("mappings",), argnames="filter_"
+ )
+ @async_test
+ async def test_one_success(self, async_engine, filter_):
+ users = self.tables.users
+ async with async_engine.connect() as conn:
+ result = await conn.stream(
+ select(users).limit(1).order_by(users.c.user_name)
+ )
+
+ if filter_ == "mappings":
+ result = result.mappings()
+ elif filter_ == "scalars":
+ result = result.scalars()
+ u1 = await result.one()
+
+ if filter_ == "mappings":
+ eq_(u1, {"user_id": 1, "user_name": "name%d" % 1})
+ elif filter_ == "scalars":
+ eq_(u1, 1)
+ else:
+ eq_(u1, (1, "name%d" % 1))
+
+ @async_test
+ async def test_one_no_result(self, async_engine):
+ users = self.tables.users
+ async with async_engine.connect() as conn:
+ result = await conn.stream(
+ select(users).where(users.c.user_name == "nonexistent")
+ )
+
+ async def go():
+ await result.one()
+
+ await assert_raises_message_async(
+ exc.NoResultFound,
+ "No row was found when one was required",
+ go(),
+ )
+
+ @async_test
+ async def test_one_multi_result(self, async_engine):
+ users = self.tables.users
+ async with async_engine.connect() as conn:
+ result = await conn.stream(
+ select(users).where(users.c.user_name.in_(["name3", "name5"]))
+ )
+
+ async def go():
+ await result.one()
+
+ await assert_raises_message_async(
+ exc.MultipleResultsFound,
+ "Multiple rows were found when exactly one was required",
+ go(),
+ )
diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py
new file mode 100644
index 000000000..e8caaca3e
--- /dev/null
+++ b/test/ext/asyncio/test_session_py3k.py
@@ -0,0 +1,200 @@
+from sqlalchemy import exc
+from sqlalchemy import func
+from sqlalchemy import select
+from sqlalchemy import testing
+from sqlalchemy import update
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.orm import selectinload
+from sqlalchemy.testing import async_test
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import is_
+from ...orm import _fixtures
+
+
+class AsyncFixture(_fixtures.FixtureTest):
+ __requires__ = ("async_dialect",)
+
+ @classmethod
+ def setup_mappers(cls):
+ cls._setup_stock_mapping()
+
+ @testing.fixture
+ def async_engine(self):
+ return create_async_engine(testing.db.url)
+
+ @testing.fixture
+ def async_session(self, async_engine):
+ return AsyncSession(async_engine)
+
+
+class AsyncSessionTest(AsyncFixture):
+ def test_requires_async_engine(self, async_engine):
+ testing.assert_raises_message(
+ exc.ArgumentError,
+ "AsyncEngine expected, got Engine",
+ AsyncSession,
+ bind=async_engine.sync_engine,
+ )
+
+
+class AsyncSessionQueryTest(AsyncFixture):
+ @async_test
+ async def test_execute(self, async_session):
+ User = self.classes.User
+
+ stmt = (
+ select(User)
+ .options(selectinload(User.addresses))
+ .order_by(User.id)
+ )
+
+ result = await async_session.execute(stmt)
+ eq_(result.scalars().all(), self.static.user_address_result)
+
+ @async_test
+ async def test_stream_partitions(self, async_session):
+ User = self.classes.User
+
+ stmt = (
+ select(User)
+ .options(selectinload(User.addresses))
+ .order_by(User.id)
+ )
+
+ result = await async_session.stream(stmt)
+
+ assert_result = []
+ async for partition in result.scalars().partitions(3):
+ assert_result.append(partition)
+
+ eq_(
+ assert_result,
+ [
+ self.static.user_address_result[0:3],
+ self.static.user_address_result[3:],
+ ],
+ )
+
+
+class AsyncSessionTransactionTest(AsyncFixture):
+ run_inserts = None
+
+ @async_test
+ async def test_trans(self, async_session, async_engine):
+ async with async_engine.connect() as outer_conn:
+
+ User = self.classes.User
+
+ async with async_session.begin():
+
+ eq_(await outer_conn.scalar(select(func.count(User.id))), 0)
+
+ u1 = User(name="u1")
+
+ async_session.add(u1)
+
+ result = await async_session.execute(select(User))
+ eq_(result.scalar(), u1)
+
+ eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
+
+ @async_test
+ async def test_commit_as_you_go(self, async_session, async_engine):
+ async with async_engine.connect() as outer_conn:
+
+ User = self.classes.User
+
+ eq_(await outer_conn.scalar(select(func.count(User.id))), 0)
+
+ u1 = User(name="u1")
+
+ async_session.add(u1)
+
+ result = await async_session.execute(select(User))
+ eq_(result.scalar(), u1)
+
+ await async_session.commit()
+
+ eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
+
+ @async_test
+ async def test_trans_noctx(self, async_session, async_engine):
+ async with async_engine.connect() as outer_conn:
+
+ User = self.classes.User
+
+ trans = await async_session.begin()
+ try:
+ eq_(await outer_conn.scalar(select(func.count(User.id))), 0)
+
+ u1 = User(name="u1")
+
+ async_session.add(u1)
+
+ result = await async_session.execute(select(User))
+ eq_(result.scalar(), u1)
+ finally:
+ await trans.commit()
+
+ eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
+
+ @async_test
+ async def test_flush(self, async_session):
+ User = self.classes.User
+
+ async with async_session.begin():
+ u1 = User(name="u1")
+
+ async_session.add(u1)
+
+ conn = await async_session.connection()
+
+ eq_(await conn.scalar(select(func.count(User.id))), 0)
+
+ await async_session.flush()
+
+ eq_(await conn.scalar(select(func.count(User.id))), 1)
+
+ @async_test
+ async def test_refresh(self, async_session):
+ User = self.classes.User
+
+ async with async_session.begin():
+ u1 = User(name="u1")
+
+ async_session.add(u1)
+ await async_session.flush()
+
+ conn = await async_session.connection()
+
+ await conn.execute(
+ update(User)
+ .values(name="u2")
+ .execution_options(synchronize_session=None)
+ )
+
+ eq_(u1.name, "u1")
+
+ await async_session.refresh(u1)
+
+ eq_(u1.name, "u2")
+
+ eq_(await conn.scalar(select(func.count(User.id))), 1)
+
+ @async_test
+ async def test_merge(self, async_session):
+ User = self.classes.User
+
+ async with async_session.begin():
+ u1 = User(id=1, name="u1")
+
+ async_session.add(u1)
+
+ async with async_session.begin():
+ new_u = User(id=1, name="new u1")
+
+ new_u_merged = await async_session.merge(new_u)
+
+ is_(new_u_merged, u1)
+ eq_(u1.name, "new u1")
diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py
index 75dca1c99..047ef25ae 100644
--- a/test/orm/test_update_delete.py
+++ b/test/orm/test_update_delete.py
@@ -1421,6 +1421,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest):
# this would work with Firebird if you do literal_column('1')
# instead
case_stmt = case([(Document.title.in_(subq), True)], else_=False)
+
s.query(Document).update(
{"flag": case_stmt}, synchronize_session=False
)
diff --git a/test/requirements.py b/test/requirements.py
index 28f955fa5..fdb7c2ff3 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -198,7 +198,7 @@ class DefaultRequirements(SuiteRequirements):
"mysql+pymysql",
"mysql+cymysql",
"mysql+mysqlconnector",
- "postgresql",
+ "postgresql+pg8000",
]
)
@@ -1163,20 +1163,6 @@ class DefaultRequirements(SuiteRequirements):
"with only four decimal places",
),
(
- "mssql+pyodbc",
- None,
- None,
- "mssql+pyodbc has FP inaccuracy even with "
- "only four decimal places ",
- ),
- (
- "mssql+pymssql",
- None,
- None,
- "mssql+pymssql has FP inaccuracy even with "
- "only four decimal places ",
- ),
- (
"postgresql+pg8000",
None,
None,
@@ -1281,6 +1267,12 @@ class DefaultRequirements(SuiteRequirements):
return only_if(check_range_types)
@property
+ def async_dialect(self):
+ """dialect makes use of await_() to invoke operations on the DBAPI."""
+
+ return only_on(["postgresql+asyncpg"])
+
+ @property
def oracle_test_dblink(self):
return skip_if(
lambda config: not config.file_config.has_option(
diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py
index 676c46db6..aa1c0d48d 100644
--- a/test/sql/test_defaults.py
+++ b/test/sql/test_defaults.py
@@ -948,7 +948,7 @@ class PKDefaultTest(fixtures.TablesTest):
metadata,
Column(
"date_id",
- DateTime,
+ DateTime(timezone=True),
default=text("current_timestamp"),
primary_key=True,
),