diff options
-rw-r--r-- | README.unittests.rst | 8 | ||||
-rw-r--r-- | doc/build/changelog/changelog_08.rst | 9 | ||||
-rw-r--r-- | doc/build/changelog/changelog_09.rst | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/__init__.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/mock.py | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/util/__init__.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/util/compat.py | 1 | ||||
-rw-r--r-- | setup.py | 3 | ||||
-rw-r--r-- | test/aaa_profiling/test_resultset.py | 5 | ||||
-rw-r--r-- | test/base/test_events.py | 136 | ||||
-rw-r--r-- | test/dialect/postgresql/test_dialect.py | 21 | ||||
-rw-r--r-- | test/dialect/test_mxodbc.py | 77 | ||||
-rw-r--r-- | test/engine/test_ddlemit.py | 30 | ||||
-rw-r--r-- | test/engine/test_execute.py | 24 | ||||
-rw-r--r-- | test/engine/test_parseconnect.py | 64 | ||||
-rw-r--r-- | test/engine/test_pool.py | 149 | ||||
-rw-r--r-- | test/engine/test_reconnect.py | 443 |
17 files changed, 502 insertions, 497 deletions
diff --git a/README.unittests.rst b/README.unittests.rst index ae7189854..7d052cfd7 100644 --- a/README.unittests.rst +++ b/README.unittests.rst @@ -7,12 +7,18 @@ module. If running on Python 2.4, pysqlite must be installed. Unit tests are run using nose. Nose is available at:: - http://pypi.python.org/pypi/nose/ + https://pypi.python.org/pypi/nose/ SQLAlchemy implements a nose plugin that must be present when tests are run. This plugin is invoked when the test runner script provided with SQLAlchemy is used. +The test suite as of version 0.8.2 also requires the mock library. While +mock is part of the Python standard library as of 3.3, previous versions +will need to have it installed, and is available at:: + + https://pypi.python.org/pypi/mock + **NOTE:** - the nose plugin is no longer installed by setuptools as of version 0.7 ! Use "python setup.py test" or "./sqla_nose.py". diff --git a/doc/build/changelog/changelog_08.rst b/doc/build/changelog/changelog_08.rst index c0e430ad6..d83d0618e 100644 --- a/doc/build/changelog/changelog_08.rst +++ b/doc/build/changelog/changelog_08.rst @@ -7,6 +7,15 @@ :version: 0.8.2 .. change:: + :tags: requirements + + The Python `mock <https://pypi.python.org/pypi/mock>`_ library + is now required in order to run the unit test suite. While part + of the standard library as of Python 3.3, previous Python installations + will need to install this in order to run unit tests or to + use the ``sqlalchemy.testing`` package for external dialects. + + .. change:: :tags: bug, orm :tickets: 2750 diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst index af3081687..466444278 100644 --- a/doc/build/changelog/changelog_09.rst +++ b/doc/build/changelog/changelog_09.rst @@ -7,6 +7,16 @@ :version: 0.9.0 .. change:: + :tags: requirements + + The Python `mock <https://pypi.python.org/pypi/mock>`_ library + is now required in order to run the unit test suite. While part + of the standard library as of Python 3.3, previous Python installations + will need to install this in order to run unit tests or to + use the ``sqlalchemy.testing`` package for external dialects. + This applies to 0.8.2 as well. + + .. change:: :tags: bug, orm :tickets: 2750 diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index d5522213d..a87829499 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -18,3 +18,5 @@ from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict crashes = skip from .config import db, requirements as requires + +from . import mock
\ No newline at end of file diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py new file mode 100644 index 000000000..650962384 --- /dev/null +++ b/lib/sqlalchemy/testing/mock.py @@ -0,0 +1,15 @@ +"""Import stub for mock library. +""" +from __future__ import absolute_import +from ..util import py33 + +if py33: + from unittest.mock import MagicMock, Mock, call +else: + try: + from mock import MagicMock, Mock, call + except ImportError: + raise ImportError( + "SQLAlchemy's test suite requires the " + "'mock' library as of 0.8.2.") + diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 687abb39a..739caefe0 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -5,7 +5,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from .compat import callable, cmp, reduce, \ - threading, py3k, py2k, jython, pypy, cpython, win32, \ + threading, py3k, py33, py2k, jython, pypy, cpython, win32, \ pickle, dottedgetter, parse_qsl, namedtuple, next, WeakSet, reraise, \ raise_from_cause, text_type, string_types, int_types, binary_type, \ quote_plus, with_metaclass, print_, itertools_filterfalse, u, ue, b,\ diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index fea22873c..d866534ab 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -13,6 +13,7 @@ try: except ImportError: import dummy_threading as threading +py33 = sys.version_info >= (3, 3) py32 = sys.version_info >= (3, 2) py3k = sys.version_info >= (3, 0) py2k = sys.version_info < (3, 0) @@ -119,8 +119,7 @@ def run_setup(with_cext): package_dir={'': 'lib'}, license="MIT License", cmdclass=cmdclass, - - tests_require=['nose >= 0.11'], + tests_require=['nose >= 0.11', 'mock'], test_suite="sqla_nose", long_description=readme, classifiers=[ diff --git a/test/aaa_profiling/test_resultset.py b/test/aaa_profiling/test_resultset.py index 27e60410d..bbd8c4dba 100644 --- a/test/aaa_profiling/test_resultset.py +++ b/test/aaa_profiling/test_resultset.py @@ -3,6 +3,9 @@ from sqlalchemy.testing import fixtures, AssertsExecutionResults, profiling from sqlalchemy import testing from sqlalchemy.testing import eq_ from sqlalchemy.util import u +from sqlalchemy.engine.result import RowProxy +import sys + NUM_FIELDS = 10 NUM_RECORDS = 1000 @@ -80,7 +83,6 @@ class RowProxyTest(fixtures.TestBase): __requires__ = 'cpython', def _rowproxy_fixture(self, keys, processors, row): - from sqlalchemy.engine.result import RowProxy class MockMeta(object): def __init__(self): pass @@ -96,7 +98,6 @@ class RowProxyTest(fixtures.TestBase): return RowProxy(metadata, row, processors, keymap) def _test_getitem_value_refcounts(self, seq_factory): - import sys col1, col2 = object(), object() def proc1(value): return value diff --git a/test/base/test_events.py b/test/base/test_events.py index 7cfb5fa7d..20bfa62ff 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -5,6 +5,8 @@ from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, \ from sqlalchemy import event, exc from sqlalchemy.testing import fixtures from sqlalchemy.testing.util import gc_collect +from sqlalchemy.testing.mock import Mock, call + class EventsTest(fixtures.TestBase): """Test class- and instance-level event registration.""" @@ -190,7 +192,7 @@ class ClsLevelListenTest(fixtures.TestBase): def test_lis_subcalss_lis(self): @event.listens_for(self.TargetOne, "event_one") def handler1(x, y): - print('handler1') + pass class SubTarget(self.TargetOne): pass @@ -207,7 +209,7 @@ class ClsLevelListenTest(fixtures.TestBase): def test_lis_multisub_lis(self): @event.listens_for(self.TargetOne, "event_one") def handler1(x, y): - print('handler1') + pass class SubTarget(self.TargetOne): pass @@ -411,12 +413,8 @@ class ListenOverrideTest(fixtures.TestBase): event._remove_dispatcher(self.Target.__dict__['dispatch'].events) def test_listen_override(self): - result = [] - def listen_one(x): - result.append(x) - - def listen_two(x, y): - result.append((x, y)) + listen_one = Mock() + listen_two = Mock() event.listen(self.Target, "event_one", listen_one, add=True) event.listen(self.Target, "event_one", listen_two) @@ -425,10 +423,13 @@ class ListenOverrideTest(fixtures.TestBase): t1.dispatch.event_one(5, 7) t1.dispatch.event_one(10, 5) - eq_(result, - [ - 12, (5, 7), 15, (10, 5) - ] + eq_( + listen_one.mock_calls, + [call(12), call(15)] + ) + eq_( + listen_two.mock_calls, + [call(5, 7), call(10, 5)] ) class PropagateTest(fixtures.TestBase): @@ -446,12 +447,8 @@ class PropagateTest(fixtures.TestBase): def test_propagate(self): - result = [] - def listen_one(target, arg): - result.append((target, arg)) - - def listen_two(target, arg): - result.append((target, arg)) + listen_one = Mock() + listen_two = Mock() t1 = self.Target() @@ -464,7 +461,15 @@ class PropagateTest(fixtures.TestBase): t2.dispatch.event_one(t2, 1) t2.dispatch.event_two(t2, 2) - eq_(result, [(t2, 1)]) + + eq_( + listen_one.mock_calls, + [call(t2, 1)] + ) + eq_( + listen_two.mock_calls, + [] + ) class JoinTest(fixtures.TestBase): def setUp(self): @@ -497,12 +502,6 @@ class JoinTest(fixtures.TestBase): if 'dispatch' in cls.__dict__: event._remove_dispatcher(cls.__dict__['dispatch'].events) - def _listener(self): - canary = [] - def listen(target, arg): - canary.append((target, arg)) - return listen, canary - def test_neither(self): element = self.TargetFactory().create() element.run_event(1) @@ -510,22 +509,22 @@ class JoinTest(fixtures.TestBase): element.run_event(3) def test_parent_class_only(self): - _listener, canary = self._listener() + l1 = Mock() - event.listen(self.TargetFactory, "event_one", _listener) + event.listen(self.TargetFactory, "event_one", l1) element = self.TargetFactory().create() element.run_event(1) element.run_event(2) element.run_event(3) eq_( - canary, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) def test_parent_class_child_class(self): - l1, c1 = self._listener() - l2, c2 = self._listener() + l1 = Mock() + l2 = Mock() event.listen(self.TargetFactory, "event_one", l1) event.listen(self.TargetElement, "event_one", l2) @@ -535,17 +534,17 @@ class JoinTest(fixtures.TestBase): element.run_event(2) element.run_event(3) eq_( - c1, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) eq_( - c2, - [(element, 1), (element, 2), (element, 3)] + l2.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) def test_parent_class_child_instance_apply_after(self): - l1, c1 = self._listener() - l2, c2 = self._listener() + l1 = Mock() + l2 = Mock() event.listen(self.TargetFactory, "event_one", l1) element = self.TargetFactory().create() @@ -557,17 +556,17 @@ class JoinTest(fixtures.TestBase): element.run_event(3) eq_( - c1, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) eq_( - c2, - [(element, 2), (element, 3)] + l2.mock_calls, + [call(element, 2), call(element, 3)] ) def test_parent_class_child_instance_apply_before(self): - l1, c1 = self._listener() - l2, c2 = self._listener() + l1 = Mock() + l2 = Mock() event.listen(self.TargetFactory, "event_one", l1) element = self.TargetFactory().create() @@ -579,17 +578,17 @@ class JoinTest(fixtures.TestBase): element.run_event(3) eq_( - c1, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) eq_( - c2, - [(element, 1), (element, 2), (element, 3)] + l2.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) def test_parent_instance_child_class_apply_before(self): - l1, c1 = self._listener() - l2, c2 = self._listener() + l1 = Mock() + l2 = Mock() event.listen(self.TargetElement, "event_one", l2) @@ -603,17 +602,18 @@ class JoinTest(fixtures.TestBase): element.run_event(3) eq_( - c1, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) eq_( - c2, - [(element, 1), (element, 2), (element, 3)] + l2.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) + def test_parent_instance_child_class_apply_after(self): - l1, c1 = self._listener() - l2, c2 = self._listener() + l1 = Mock() + l2 = Mock() event.listen(self.TargetElement, "event_one", l2) @@ -632,18 +632,16 @@ class JoinTest(fixtures.TestBase): # this can be changed to be "live" at the cost # of performance. eq_( - c1, - [] - #(element, 2), (element, 3)] + l1.mock_calls, [] ) eq_( - c2, - [(element, 1), (element, 2), (element, 3)] + l2.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) def test_parent_instance_child_instance_apply_before(self): - l1, c1 = self._listener() - l2, c2 = self._listener() + l1 = Mock() + l2 = Mock() factory = self.TargetFactory() event.listen(factory, "event_one", l1) @@ -656,16 +654,16 @@ class JoinTest(fixtures.TestBase): element.run_event(3) eq_( - c1, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) eq_( - c2, - [(element, 1), (element, 2), (element, 3)] + l2.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) def test_parent_events_child_no_events(self): - l1, c1 = self._listener() + l1 = Mock() factory = self.TargetFactory() event.listen(self.TargetElement, "event_one", l1) @@ -676,6 +674,6 @@ class JoinTest(fixtures.TestBase): element.run_event(3) eq_( - c1, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 86ce91dc9..1fc239cb7 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -16,6 +16,7 @@ from sqlalchemy import exc, schema from sqlalchemy.dialects.postgresql import base as postgresql import logging import logging.handlers +from sqlalchemy.testing.mock import Mock class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): @@ -37,18 +38,12 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): 'The JDBC driver handles the version parsing') def test_version_parsing(self): - - class MockConn(object): - - def __init__(self, res): - self.res = res - - def execute(self, str): - return self - - def scalar(self): - return self.res - + def mock_conn(res): + return Mock( + execute=Mock( + return_value=Mock(scalar=Mock(return_value=res)) + ) + ) for string, version in \ [('PostgreSQL 8.3.8 on i686-redhat-linux-gnu, compiled by ' @@ -59,7 +54,7 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): ('EnterpriseDB 9.1.2.2 on x86_64-unknown-linux-gnu, ' 'compiled by gcc (GCC) 4.1.2 20080704 (Red Hat 4.1.2-50), ' '64-bit', (9, 1, 2))]: - eq_(testing.db.dialect._get_server_version_info(MockConn(string)), + eq_(testing.db.dialect._get_server_version_info(mock_conn(string)), version) @testing.only_on('postgresql+psycopg2', 'psycopg2-specific feature') diff --git a/test/dialect/test_mxodbc.py b/test/dialect/test_mxodbc.py index 32cad4168..e46de9149 100644 --- a/test/dialect/test_mxodbc.py +++ b/test/dialect/test_mxodbc.py @@ -2,75 +2,48 @@ from sqlalchemy import * from sqlalchemy.testing import eq_ from sqlalchemy.testing import engines from sqlalchemy.testing import fixtures - -# TODO: we should probably build mock bases for -# these to share with test_reconnect, test_parseconnect -class MockDBAPI(object): - paramstyle = 'qmark' - def __init__(self): - self.log = [] - def connect(self, *args, **kwargs): - return MockConnection(self) - -class MockConnection(object): - def __init__(self, parent): - self.parent = parent - def cursor(self): - return MockCursor(self) - def close(self): - pass - def rollback(self): - pass - def commit(self): - pass - -class MockCursor(object): - description = None - rowcount = None - def __init__(self, parent): - self.parent = parent - def execute(self, *args, **kwargs): - if kwargs.get('direct', False): - self.executedirect() - else: - self.parent.parent.log.append('execute') - def executedirect(self, *args, **kwargs): - self.parent.parent.log.append('executedirect') - def close(self): - pass +from sqlalchemy.testing.mock import Mock + +def mock_dbapi(): + return Mock(paramstyle='qmark', + connect=Mock( + return_value=Mock( + cursor=Mock( + return_value=Mock( + description=None, + rowcount=None) + ) + ) + ) + ) class MxODBCTest(fixtures.TestBase): def test_native_odbc_execute(self): t1 = Table('t1', MetaData(), Column('c1', Integer)) - dbapi = MockDBAPI() + dbapi = mock_dbapi() + engine = engines.testing_engine('mssql+mxodbc://localhost', options={'module': dbapi, '_initialize': False}) conn = engine.connect() # crud: uses execute - conn.execute(t1.insert().values(c1='foo')) conn.execute(t1.delete().where(t1.c.c1 == 'foo')) - conn.execute(t1.update().where(t1.c.c1 == 'foo').values(c1='bar' - )) + conn.execute(t1.update().where(t1.c.c1 == 'foo').values(c1='bar')) # select: uses executedirect - conn.execute(t1.select()) # manual flagging - conn.execution_options(native_odbc_execute=True).\ execute(t1.select()) conn.execution_options(native_odbc_execute=False).\ - execute(t1.insert().values(c1='foo' - )) - eq_(dbapi.log, [ - 'executedirect', - 'executedirect', - 'executedirect', - 'executedirect', - 'execute', - 'executedirect', - ]) + execute(t1.insert().values(c1='foo')) + + eq_( + [c[2] for c in + dbapi.connect.return_value.cursor.return_value.execute.mock_calls], + [{'direct': True}, {'direct': True}, {'direct': True}, + {'direct': True}, {'direct': False}, {'direct': True}] + ) diff --git a/test/engine/test_ddlemit.py b/test/engine/test_ddlemit.py index deaf09cf7..e773d0ced 100644 --- a/test/engine/test_ddlemit.py +++ b/test/engine/test_ddlemit.py @@ -3,28 +3,19 @@ from sqlalchemy.engine.ddl import SchemaGenerator, SchemaDropper from sqlalchemy.engine import default from sqlalchemy import MetaData, Table, Column, Integer, Sequence from sqlalchemy import schema +from sqlalchemy.testing.mock import Mock class EmitDDLTest(fixtures.TestBase): def _mock_connection(self, item_exists): - _canary = [] + def has_item(connection, name, schema): + return item_exists(name) - class MockDialect(default.DefaultDialect): - supports_sequences = True - - def has_table(self, connection, name, schema): - return item_exists(name) - - def has_sequence(self, connection, name, schema): - return item_exists(name) - - class MockConnection(object): - dialect = MockDialect() - canary = _canary - - def execute(self, item): - _canary.append(item) - - return MockConnection() + return Mock(dialect=Mock( + supports_sequences=True, + has_table=Mock(side_effect=has_item), + has_sequence=Mock(side_effect=has_item) + ) + ) def _mock_create_fixture(self, checkfirst, tables, item_exists=lambda item: False): @@ -176,7 +167,8 @@ class EmitDDLTest(fixtures.TestBase): def _assert_ddl(self, ddl_cls, elements, generator, argument): generator.traverse_single(argument) - for c in generator.connection.canary: + for call_ in generator.connection.execute.mock_calls: + c = call_[1][0] assert isinstance(c, ddl_cls) assert c.element in elements, "element %r was not expected"\ % c.element diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 1c577730b..9795e4c10 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -19,6 +19,8 @@ from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam from sqlalchemy.engine import result as _result, default from sqlalchemy.engine.base import Connection, Engine from sqlalchemy.testing import fixtures +from sqlalchemy.testing.mock import Mock, call + users, metadata, users_autoinc = None, None, None class ExecuteTest(fixtures.TestBase): @@ -455,20 +457,22 @@ class ConvenienceExecuteTest(fixtures.TablesTest): def test_transaction_engine_ctx_begin_fails(self): engine = engines.testing_engine() - class MockConnection(Connection): - closed = False - def begin(self): - raise Exception("boom") - - def close(self): - MockConnection.closed = True - engine._connection_cls = MockConnection - fn = self._trans_fn() + + mock_connection = Mock( + return_value=Mock( + begin=Mock(side_effect=Exception("boom")) + ) + ) + engine._connection_cls = mock_connection assert_raises( Exception, engine.begin ) - assert MockConnection.closed + + eq_( + mock_connection.return_value.close.mock_calls, + [call()] + ) def test_transaction_engine_ctx_rollback(self): fn = self._trans_rollback_fn() diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py index 73bdc76c4..106bd0782 100644 --- a/test/engine/test_parseconnect.py +++ b/test/engine/test_parseconnect.py @@ -7,6 +7,8 @@ from sqlalchemy.engine.default import DefaultDialect import sqlalchemy as tsa from sqlalchemy.testing import fixtures from sqlalchemy import testing +from sqlalchemy.testing.mock import Mock + class ParseConnectTest(fixtures.TestBase): def test_rfc1738(self): @@ -250,20 +252,17 @@ pool_timeout=10 every backend. """ - # pretend pysqlite throws the - # "Cannot operate on a closed database." error - # on connect. IRL we'd be getting Oracle's "shutdown in progress" e = create_engine('sqlite://') sqlite3 = e.dialect.dbapi - class ThrowOnConnect(MockDBAPI): - dbapi = sqlite3 - Error = sqlite3.Error - ProgrammingError = sqlite3.ProgrammingError - def connect(self, *args, **kw): - raise sqlite3.ProgrammingError("Cannot operate on a closed database.") + + dbapi = MockDBAPI() + dbapi.Error = sqlite3.Error, + dbapi.ProgrammingError = sqlite3.ProgrammingError + dbapi.connect = Mock(side_effect=sqlite3.ProgrammingError( + "Cannot operate on a closed database.")) try: - create_engine('sqlite://', module=ThrowOnConnect()).connect() + create_engine('sqlite://', module=dbapi).connect() assert False except tsa.exc.DBAPIError as de: assert de.connection_invalidated @@ -354,36 +353,23 @@ class MockDialect(DefaultDialect): def dbapi(cls, **kw): return MockDBAPI() -class MockDBAPI(object): - version_info = sqlite_version_info = 99, 9, 9 - sqlite_version = '99.9.9' - - def __init__(self, **kwargs): - self.kwargs = kwargs - self.paramstyle = 'named' - - def connect(self, *args, **kwargs): - for k in self.kwargs: +def MockDBAPI(**assert_kwargs): + connection = Mock(get_server_version_info=Mock(return_value='5.0')) + def connect(*args, **kwargs): + for k in assert_kwargs: assert k in kwargs, 'key %s not present in dictionary' % k - assert kwargs[k] == self.kwargs[k], \ - 'value %s does not match %s' % (kwargs[k], - self.kwargs[k]) - return MockConnection() - - -class MockConnection(object): - def get_server_info(self): - return '5.0' - - def close(self): - pass - - def cursor(self): - return MockCursor() - -class MockCursor(object): - def close(self): - pass + eq_( + kwargs[k], assert_kwargs[k] + ) + return connection + + return Mock( + sqlite_version_info=(99, 9, 9,), + version_info=(99, 9, 9,), + sqlite_version='99.9.9', + paramstyle='named', + connect=Mock(side_effect=connect) + ) mock_dbapi = MockDBAPI() mock_sqlite_dbapi = msd = MockDBAPI() diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 583978465..981df6dd0 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -4,37 +4,21 @@ from sqlalchemy import pool, select, event import sqlalchemy as tsa from sqlalchemy import testing from sqlalchemy.testing.util import gc_collect, lazy_gc -from sqlalchemy.testing import eq_, assert_raises +from sqlalchemy.testing import eq_, assert_raises, is_not_ from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing import fixtures -mcid = 1 -class MockDBAPI(object): - throw_error = False - def connect(self, *args, **kwargs): - if self.throw_error: - raise Exception("couldnt connect !") - delay = kwargs.pop('delay', 0) - if delay: - time.sleep(delay) - return MockConnection() -class MockConnection(object): - closed = False - def __init__(self): - global mcid - self.id = mcid - mcid += 1 - def close(self): - self.closed = True - def rollback(self): - pass - def cursor(self): - return MockCursor() -class MockCursor(object): - def execute(self, *args, **kw): - pass - def close(self): - pass +from sqlalchemy.testing.mock import Mock, call + +def MockDBAPI(): + def cursor(): + while True: + yield Mock() + def connect(): + while True: + yield Mock(cursor=Mock(side_effect=cursor())) + + return Mock(connect=Mock(side_effect=connect())) class PoolTestBase(fixtures.TestBase): def setup(self): @@ -71,11 +55,9 @@ class PoolTest(PoolTestBase): assert c4 is not c5 def test_manager_with_key(self): - class NoKws(object): - def connect(self, arg): - return MockConnection() - manager = pool.manage(NoKws(), use_threadlocal=True) + dbapi = MockDBAPI() + manager = pool.manage(dbapi, use_threadlocal=True) c1 = manager.connect('foo.db', sa_pool_key="a") c2 = manager.connect('foo.db', sa_pool_key="b") @@ -83,9 +65,14 @@ class PoolTest(PoolTestBase): assert c1.cursor() is not None assert c1 is not c2 - assert c1 is c3 - + assert c1 is c3 + eq_(dbapi.connect.mock_calls, + [ + call("foo.db"), + call("foo.db"), + ] + ) def test_bad_args(self): @@ -127,7 +114,7 @@ class PoolTest(PoolTestBase): p = cls(creator=mock_dbapi.connect) conn = p.connect() conn.close() - mock_dbapi.throw_error = True + mock_dbapi.connect.side_effect = Exception("error!") p.dispose() p.recreate() @@ -211,9 +198,9 @@ class PoolTest(PoolTestBase): self.assert_('foo2' in c.info) c2 = p.connect() - self.assert_(c.connection is not c2.connection) - self.assert_(not c2.info) - self.assert_('foo2' in c.info) + is_not_(c.connection, c2.connection) + assert not c2.info + assert 'foo2' in c.info class PoolDialectTest(PoolTestBase): @@ -945,19 +932,24 @@ class QueuePoolTest(PoolTestBase): def test_dispose_closes_pooled(self): dbapi = MockDBAPI() - def creator(): - return dbapi.connect() - p = pool.QueuePool(creator=creator, + p = pool.QueuePool(creator=dbapi.connect, pool_size=2, timeout=None, max_overflow=0) c1 = p.connect() c2 = p.connect() - conns = [c1.connection, c2.connection] + c1_con = c1.connection + c2_con = c2.connection + c1.close() - eq_([c.closed for c in conns], [False, False]) + + eq_(c1_con.close.call_count, 0) + eq_(c2_con.close.call_count, 0) + p.dispose() - eq_([c.closed for c in conns], [True, False]) + + eq_(c1_con.close.call_count, 1) + eq_(c2_con.close.call_count, 0) # currently, if a ConnectionFairy is closed # after the pool has been disposed, there's no @@ -965,11 +957,12 @@ class QueuePoolTest(PoolTestBase): # immediately - it just gets returned to the # pool normally... c2.close() - eq_([c.closed for c in conns], [True, False]) + eq_(c1_con.close.call_count, 1) + eq_(c2_con.close.call_count, 0) # ...and that's the one we'll get back next. c3 = p.connect() - assert c3.connection is conns[1] + assert c3.connection is c2_con def test_no_overflow(self): self._test_overflow(40, 0) @@ -1010,12 +1003,21 @@ class QueuePoolTest(PoolTestBase): return c for j in range(5): + # open 4 conns at a time. each time this + # will yield two pooled connections + two + # overflow connections. conns = [_conn() for i in range(4)] for c in conns: c.close() - still_opened = len([c for c in strong_refs if not c.closed]) - eq_(still_opened, 2) + # doing that for a total of 5 times yields + # ten overflow connections closed plus the + # two pooled connections unclosed. + + eq_( + set([c.close.call_count for c in strong_refs]), + set([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0]) + ) @testing.requires.predictable_gc def test_weakref_kaboom(self): @@ -1108,18 +1110,30 @@ class QueuePoolTest(PoolTestBase): dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0) c1 = p.connect() c1.detach() - c_id = c1.connection.id - c2 = p.connect() - assert c2.connection.id != c1.connection.id - dbapi.raise_error = True - c2.invalidate() - c2 = None c2 = p.connect() - assert c2.connection.id != c1.connection.id - con = c1.connection - assert not con.closed + eq_(dbapi.connect.mock_calls, [call("foo.db"), call("foo.db")]) + + c1_con = c1.connection + assert c1_con is not None + eq_(c1_con.close.call_count, 0) c1.close() - assert con.closed + eq_(c1_con.close.call_count, 1) + + def test_detach_via_invalidate(self): + dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0) + + c1 = p.connect() + c1_con = c1.connection + c1.invalidate() + assert c1.connection is None + eq_(c1_con.close.call_count, 1) + + c2 = p.connect() + assert c2.connection is not c1_con + c2_con = c2.connection + + c2.close() + eq_(c2_con.close.call_count, 0) def test_threadfairy(self): p = self._queuepool_fixture(pool_size=3, max_overflow=-1, use_threadlocal=True) @@ -1141,8 +1155,13 @@ class SingletonThreadPoolTest(PoolTestBase): been called.""" dbapi = MockDBAPI() - p = pool.SingletonThreadPool(creator=dbapi.connect, - pool_size=3) + + lock = threading.Lock() + def creator(): + # the mock iterator isn't threadsafe... + with lock: + return dbapi.connect() + p = pool.SingletonThreadPool(creator=creator, pool_size=3) if strong_refs: sr = set() @@ -1172,7 +1191,7 @@ class SingletonThreadPoolTest(PoolTestBase): assert len(p._all_conns) == 3 if strong_refs: - still_opened = len([c for c in sr if not c.closed]) + still_opened = len([c for c in sr if not c.close.call_count]) eq_(still_opened, 3) class AssertionPoolTest(PoolTestBase): @@ -1198,17 +1217,19 @@ class NullPoolTest(PoolTestBase): dbapi = MockDBAPI() p = pool.NullPool(creator=lambda: dbapi.connect('foo.db')) c1 = p.connect() - c_id = c1.connection.id + c1.close() c1 = None c1 = p.connect() - dbapi.raise_error = True c1.invalidate() c1 = None c1 = p.connect() - assert c1.connection.id != c_id + dbapi.connect.assert_has_calls([ + call('foo.db'), + call('foo.db')], + any_order=True) class StaticPoolTest(PoolTestBase): diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index ee3ff1459..86003bec6 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1,7 +1,6 @@ from sqlalchemy.testing import eq_, assert_raises, assert_raises_message import time -import weakref -from sqlalchemy import select, MetaData, Integer, String, pool, create_engine +from sqlalchemy import select, MetaData, Integer, String, create_engine, pool from sqlalchemy.testing.schema import Table, Column import sqlalchemy as tsa from sqlalchemy import testing @@ -10,6 +9,8 @@ from sqlalchemy.testing.util import gc_collect from sqlalchemy import exc, util from sqlalchemy.testing import fixtures from sqlalchemy.testing.engines import testing_engine +from sqlalchemy.testing import is_not_ +from sqlalchemy.testing.mock import Mock, call class MockError(Exception): pass @@ -17,93 +18,103 @@ class MockError(Exception): class MockDisconnect(MockError): pass -class MockDBAPI(object): - def __init__(self): - self.paramstyle = 'named' - self.connections = weakref.WeakKeyDictionary() - def connect(self, *args, **kwargs): - return MockConnection(self) - def shutdown(self, explode='execute'): - for c in self.connections: - c.explode = explode - Error = MockError - -class MockConnection(object): - def __init__(self, dbapi): - dbapi.connections[self] = True - self.explode = "" - def rollback(self): - if self.explode == 'rollback': +def mock_connection(): + def mock_cursor(): + def execute(*args, **kwargs): + if conn.explode == 'execute': + raise MockDisconnect("Lost the DB connection on execute") + elif conn.explode in ('execute_no_disconnect', ): + raise MockError( + "something broke on execute but we didn't lose the connection") + elif conn.explode in ('rollback', 'rollback_no_disconnect'): + raise MockError( + "something broke on execute but we didn't lose the connection") + elif args and "SELECT" in args[0]: + cursor.description = [('foo', None, None, None, None, None)] + else: + return + + def close(): + cursor.fetchall = cursor.fetchone = \ + Mock(side_effect=MockError("cursor closed")) + cursor = Mock( + execute=Mock(side_effect=execute), + close=Mock(side_effect=close) + ) + return cursor + + def cursor(): + while True: + yield mock_cursor() + + def rollback(): + if conn.explode == 'rollback': raise MockDisconnect("Lost the DB connection on rollback") - if self.explode == 'rollback_no_disconnect': + if conn.explode == 'rollback_no_disconnect': raise MockError( "something broke on rollback but we didn't lose the connection") else: return - def commit(self): - pass - def cursor(self): - return MockCursor(self) - def close(self): - pass - -class MockCursor(object): - def __init__(self, parent): - self.explode = parent.explode - self.description = () - self.closed = False - def execute(self, *args, **kwargs): - if self.explode == 'execute': - raise MockDisconnect("Lost the DB connection on execute") - elif self.explode in ('execute_no_disconnect', ): - raise MockError( - "something broke on execute but we didn't lose the connection") - elif self.explode in ('rollback', 'rollback_no_disconnect'): - raise MockError( - "something broke on execute but we didn't lose the connection") - elif args and "select" in args[0]: - self.description = [('foo', None, None, None, None, None)] - else: - return - def fetchall(self): - if self.closed: - raise MockError("cursor closed") - return [] - def fetchone(self): - if self.closed: - raise MockError("cursor closed") - return None - def close(self): - self.closed = True - -db, dbapi = None, None + + conn = Mock( + rollback=Mock(side_effect=rollback), + cursor=Mock(side_effect=cursor()) + ) + return conn + +def MockDBAPI(): + connections = [] + def connect(): + while True: + conn = mock_connection() + connections.append(conn) + yield conn + + def shutdown(explode='execute'): + for c in connections: + c.explode = explode + + def dispose(): + for c in connections: + c.explode = None + connections[:] = [] + + return Mock( + connect=Mock(side_effect=connect()), + shutdown=Mock(side_effect=shutdown), + dispose=Mock(side_effect=dispose), + paramstyle='named', + connections=connections, + Error=MockError + ) + + class MockReconnectTest(fixtures.TestBase): def setup(self): - global db, dbapi - dbapi = MockDBAPI() + self.dbapi = MockDBAPI() - # note - using straight create_engine here - # since we are testing gc - db = create_engine( + self.db = testing_engine( 'postgresql://foo:bar@localhost/test', - module=dbapi, _initialize=False) + options=dict(module=self.dbapi, _initialize=False)) + self.mock_connect = call(host='localhost', password='bar', + user='foo', database='test') # monkeypatch disconnect checker - db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect) + self.db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect) def teardown(self): - db.dispose() + self.dbapi.dispose() def test_reconnect(self): """test that an 'is_disconnect' condition will invalidate the connection, and additionally dispose the previous connection pool and recreate.""" - pid = id(db.pool) + db_pool = self.db.pool # make a connection - conn = db.connect() + conn = self.db.connect() # connection works @@ -112,21 +123,20 @@ class MockReconnectTest(fixtures.TestBase): # create a second connection within the pool, which we'll ensure # also goes away - conn2 = db.connect() + conn2 = self.db.connect() conn2.close() # two connections opened total now - assert len(dbapi.connections) == 2 + assert len(self.dbapi.connections) == 2 # set it to fail - dbapi.shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError: - pass + self.dbapi.shutdown() + assert_raises( + tsa.exc.DBAPIError, + conn.execute, select([1]) + ) # assert was invalidated @@ -136,31 +146,38 @@ class MockReconnectTest(fixtures.TestBase): # close shouldnt break conn.close() - assert id(db.pool) != pid + is_not_(self.db.pool, db_pool) # ensure all connections closed (pool was recycled) - gc_collect() - assert len(dbapi.connections) == 0 - conn = db.connect() + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], [call()]] + ) + + conn = self.db.connect() conn.execute(select([1])) conn.close() - assert len(dbapi.connections) == 1 + + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], [call()], []] + ) def test_invalidate_trans(self): - conn = db.connect() + conn = self.db.connect() trans = conn.begin() - dbapi.shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError: - pass + self.dbapi.shutdown() - # assert was invalidated + assert_raises( + tsa.exc.DBAPIError, + conn.execute, select([1]) + ) - gc_collect() - assert len(dbapi.connections) == 0 + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()]] + ) assert not conn.closed assert conn.invalidated assert trans.is_active @@ -170,28 +187,35 @@ class MockReconnectTest(fixtures.TestBase): conn.execute, select([1]) ) assert trans.is_active - try: - trans.commit() - assert False - except tsa.exc.InvalidRequestError as e: - assert str(e) \ - == "Can't reconnect until invalid transaction is "\ - "rolled back" + + assert_raises_message( + tsa.exc.InvalidRequestError, + "Can't reconnect until invalid transaction is " + "rolled back", + trans.commit + ) + assert trans.is_active trans.rollback() assert not trans.is_active conn.execute(select([1])) assert not conn.invalidated - assert len(dbapi.connections) == 1 + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], []] + ) def test_conn_reusable(self): - conn = db.connect() + conn = self.db.connect() conn.execute(select([1])) - assert len(dbapi.connections) == 1 + eq_( + self.dbapi.connect.mock_calls, + [self.mock_connect] + ) - dbapi.shutdown() + self.dbapi.shutdown() assert_raises( tsa.exc.DBAPIError, @@ -201,19 +225,24 @@ class MockReconnectTest(fixtures.TestBase): assert not conn.closed assert conn.invalidated - # ensure all connections closed (pool was recycled) - gc_collect() - assert len(dbapi.connections) == 0 + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()]] + ) # test reconnects conn.execute(select([1])) assert not conn.invalidated - assert len(dbapi.connections) == 1 + + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], []] + ) def test_invalidated_close(self): - conn = db.connect() + conn = self.db.connect() - dbapi.shutdown() + self.dbapi.shutdown() assert_raises( tsa.exc.DBAPIError, @@ -230,9 +259,9 @@ class MockReconnectTest(fixtures.TestBase): ) def test_noreconnect_execute_plus_closewresult(self): - conn = db.connect(close_with_result=True) + conn = self.db.connect(close_with_result=True) - dbapi.shutdown("execute_no_disconnect") + self.dbapi.shutdown("execute_no_disconnect") # raises error assert_raises_message( @@ -245,9 +274,9 @@ class MockReconnectTest(fixtures.TestBase): assert not conn.invalidated def test_noreconnect_rollback_plus_closewresult(self): - conn = db.connect(close_with_result=True) + conn = self.db.connect(close_with_result=True) - dbapi.shutdown("rollback_no_disconnect") + self.dbapi.shutdown("rollback_no_disconnect") # raises error assert_raises_message( @@ -266,13 +295,13 @@ class MockReconnectTest(fixtures.TestBase): ) def test_reconnect_on_reentrant(self): - conn = db.connect() + conn = self.db.connect() conn.execute(select([1])) - assert len(dbapi.connections) == 1 + assert len(self.dbapi.connections) == 1 - dbapi.shutdown("rollback") + self.dbapi.shutdown("rollback") # raises error assert_raises_message( @@ -285,9 +314,9 @@ class MockReconnectTest(fixtures.TestBase): assert conn.invalidated def test_reconnect_on_reentrant_plus_closewresult(self): - conn = db.connect(close_with_result=True) + conn = self.db.connect(close_with_result=True) - dbapi.shutdown("rollback") + self.dbapi.shutdown("rollback") # raises error assert_raises_message( @@ -306,10 +335,11 @@ class MockReconnectTest(fixtures.TestBase): ) def test_check_disconnect_no_cursor(self): - conn = db.connect() - result = conn.execute("select 1") + conn = self.db.connect() + result = conn.execute(select([1])) result.cursor.close() conn.close() + assert_raises_message( tsa.exc.DBAPIError, "cursor closed", @@ -319,60 +349,59 @@ class MockReconnectTest(fixtures.TestBase): class CursorErrTest(fixtures.TestBase): def setup(self): - global db, dbapi - - class MDBAPI(MockDBAPI): - def connect(self, *args, **kwargs): - return MConn(self) - - class MConn(MockConnection): - def cursor(self): - return MCursor(self) + def MockDBAPI(): + def cursor(): + while True: + yield Mock( + description=[], + close=Mock(side_effect=Exception("explode"))) + def connect(): + while True: + yield Mock(cursor=Mock(side_effect=cursor())) + + return Mock(connect=Mock(side_effect=connect())) - class MCursor(MockCursor): - def close(self): - raise Exception("explode") - - dbapi = MDBAPI() - - db = testing_engine( + dbapi = MockDBAPI() + self.db = testing_engine( 'postgresql://foo:bar@localhost/test', options=dict(module=dbapi, _initialize=False)) def test_cursor_explode(self): - conn = db.connect() + conn = self.db.connect() result = conn.execute("select foo") result.close() conn.close() def teardown(self): - db.dispose() + self.db.dispose() + + +def _assert_invalidated(fn, *args): + try: + fn(*args) + assert False + except tsa.exc.DBAPIError as e: + if not e.connection_invalidated: + raise -engine = None class RealReconnectTest(fixtures.TestBase): def setup(self): - global engine - engine = engines.reconnecting_engine() + self.engine = engines.reconnecting_engine() def teardown(self): - engine.dispose() + self.engine.dispose() @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_reconnect(self): - conn = engine.connect() + conn = self.engine.connect() eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed - engine.test_shutdown() + self.engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + _assert_invalidated(conn.execute, select([1])) assert not conn.closed assert conn.invalidated @@ -382,13 +411,9 @@ class RealReconnectTest(fixtures.TestBase): assert not conn.invalidated # one more time - engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + self.engine.test_shutdown() + _assert_invalidated(conn.execute, select([1])) + assert conn.invalidated eq_(conn.execute(select([1])).scalar(), 1) assert not conn.invalidated @@ -396,30 +421,22 @@ class RealReconnectTest(fixtures.TestBase): conn.close() def test_multiple_invalidate(self): - c1 = engine.connect() - c2 = engine.connect() + c1 = self.engine.connect() + c2 = self.engine.connect() eq_(c1.execute(select([1])).scalar(), 1) - p1 = engine.pool - engine.test_shutdown() + p1 = self.engine.pool + self.engine.test_shutdown() - try: - c1.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - assert e.connection_invalidated + _assert_invalidated(c1.execute, select([1])) - p2 = engine.pool + p2 = self.engine.pool - try: - c2.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - assert e.connection_invalidated + _assert_invalidated(c2.execute, select([1])) # pool isn't replaced - assert engine.pool is p2 + assert self.engine.pool is p2 def test_ensure_is_disconnect_gets_connection(self): @@ -430,37 +447,37 @@ class RealReconnectTest(fixtures.TestBase): # though MySQLdb we get a non-working cursor. # assert cursor is None - engine.dialect.is_disconnect = is_disconnect - conn = engine.connect() - engine.test_shutdown() + self.engine.dialect.is_disconnect = is_disconnect + conn = self.engine.connect() + self.engine.test_shutdown() assert_raises( tsa.exc.DBAPIError, conn.execute, select([1]) ) def test_rollback_on_invalid_plain(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin() conn.invalidate() trans.rollback() @testing.requires.two_phase_transactions def test_rollback_on_invalid_twophase(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin_twophase() conn.invalidate() trans.rollback() @testing.requires.savepoints def test_rollback_on_invalid_savepoint(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin() trans2 = conn.begin_nested() conn.invalidate() trans2.rollback() def test_invalidate_twice(self): - conn = engine.connect() + conn = self.engine.connect() conn.invalidate() conn.invalidate() @@ -503,12 +520,7 @@ class RealReconnectTest(fixtures.TestBase): eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + _assert_invalidated(conn.execute, select([1])) assert not conn.closed assert conn.invalidated eq_(conn.execute(select([1])).scalar(), 1) @@ -517,37 +529,27 @@ class RealReconnectTest(fixtures.TestBase): @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_close(self): - conn = engine.connect() + conn = self.engine.connect() eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed - engine.test_shutdown() + self.engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + _assert_invalidated(conn.execute, select([1])) conn.close() - conn = engine.connect() + conn = self.engine.connect() eq_(conn.execute(select([1])).scalar(), 1) @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_with_transaction(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin() eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed - engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + self.engine.test_shutdown() + _assert_invalidated(conn.execute, select([1])) assert not conn.closed assert conn.invalidated assert trans.is_active @@ -558,13 +560,11 @@ class RealReconnectTest(fixtures.TestBase): conn.execute, select([1]) ) assert trans.is_active - try: - trans.commit() - assert False - except tsa.exc.InvalidRequestError as e: - assert str(e) \ - == "Can't reconnect until invalid transaction is "\ - "rolled back" + assert_raises_message( + tsa.exc.InvalidRequestError, + "Can't reconnect until invalid transaction is rolled back", + trans.commit + ) assert trans.is_active trans.rollback() assert not trans.is_active @@ -602,23 +602,21 @@ class RecycleTest(fixtures.TestBase): eq_(conn.execute(select([1])).scalar(), 1) conn.close() -meta, table, engine = None, None, None class InvalidateDuringResultTest(fixtures.TestBase): def setup(self): - global meta, table, engine - engine = engines.reconnecting_engine() - meta = MetaData(engine) - table = Table('sometable', meta, + self.engine = engines.reconnecting_engine() + self.meta = MetaData(self.engine) + table = Table('sometable', self.meta, Column('id', Integer, primary_key=True), Column('name', String(50))) - meta.create_all() + self.meta.create_all() table.insert().execute( - [{'id':i, 'name':'row %d' % i} for i in range(1, 100)] + [{'id': i, 'name': 'row %d' % i} for i in range(1, 100)] ) def teardown(self): - meta.drop_all() - engine.dispose() + self.meta.drop_all() + self.engine.dispose() @testing.fails_if([ '+mysqlconnector', '+mysqldb', @@ -628,16 +626,11 @@ class InvalidateDuringResultTest(fixtures.TestBase): @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_invalidate_on_results(self): - conn = engine.connect() + conn = self.engine.connect() result = conn.execute('select * from sometable') for x in range(20): result.fetchone() - engine.test_shutdown() - try: - print('ghost result: %r' % result.fetchone()) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + self.engine.test_shutdown() + _assert_invalidated(result.fetchone) assert conn.invalidated |