diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-06-30 18:35:12 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-06-30 18:35:12 -0400 |
commit | b38a76cd1d47cd6b8f1abef30ad7c3aeaa27d537 (patch) | |
tree | 7af1dba9e242c77a248cb2194434aa9bf3ca49b7 | |
parent | 715d6cf3d10a71acd7726b7e00c3ff40b4559bc7 (diff) | |
download | sqlalchemy-b38a76cd1d47cd6b8f1abef30ad7c3aeaa27d537.tar.gz |
- replace most explicitly-named test objects called "Mock..." with
actual mock objects from the mock library. I'd like to use mock
for new tests so we might as well use it in obvious places.
- use unittest.mock in py3.3
- changelog
- add a note to README.unittests
- add tests_require in setup.py
- have tests import from sqlalchemy.testing.mock
- apply usage of mock to one of the event tests. we can be using
this approach all over the place.
-rw-r--r-- | README.unittests.rst | 8 | ||||
-rw-r--r-- | doc/build/changelog/changelog_08.rst | 9 | ||||
-rw-r--r-- | doc/build/changelog/changelog_09.rst | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/__init__.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/mock.py | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/util/__init__.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/util/compat.py | 1 | ||||
-rw-r--r-- | setup.py | 3 | ||||
-rw-r--r-- | test/aaa_profiling/test_resultset.py | 5 | ||||
-rw-r--r-- | test/base/test_events.py | 136 | ||||
-rw-r--r-- | test/dialect/postgresql/test_dialect.py | 21 | ||||
-rw-r--r-- | test/dialect/test_mxodbc.py | 77 | ||||
-rw-r--r-- | test/engine/test_ddlemit.py | 30 | ||||
-rw-r--r-- | test/engine/test_execute.py | 24 | ||||
-rw-r--r-- | test/engine/test_parseconnect.py | 64 | ||||
-rw-r--r-- | test/engine/test_pool.py | 149 | ||||
-rw-r--r-- | test/engine/test_reconnect.py | 443 |
17 files changed, 502 insertions, 497 deletions
diff --git a/README.unittests.rst b/README.unittests.rst index ae7189854..7d052cfd7 100644 --- a/README.unittests.rst +++ b/README.unittests.rst @@ -7,12 +7,18 @@ module. If running on Python 2.4, pysqlite must be installed. Unit tests are run using nose. Nose is available at:: - http://pypi.python.org/pypi/nose/ + https://pypi.python.org/pypi/nose/ SQLAlchemy implements a nose plugin that must be present when tests are run. This plugin is invoked when the test runner script provided with SQLAlchemy is used. +The test suite as of version 0.8.2 also requires the mock library. While +mock is part of the Python standard library as of 3.3, previous versions +will need to have it installed, and is available at:: + + https://pypi.python.org/pypi/mock + **NOTE:** - the nose plugin is no longer installed by setuptools as of version 0.7 ! Use "python setup.py test" or "./sqla_nose.py". diff --git a/doc/build/changelog/changelog_08.rst b/doc/build/changelog/changelog_08.rst index c0e430ad6..d83d0618e 100644 --- a/doc/build/changelog/changelog_08.rst +++ b/doc/build/changelog/changelog_08.rst @@ -7,6 +7,15 @@ :version: 0.8.2 .. change:: + :tags: requirements + + The Python `mock <https://pypi.python.org/pypi/mock>`_ library + is now required in order to run the unit test suite. While part + of the standard library as of Python 3.3, previous Python installations + will need to install this in order to run unit tests or to + use the ``sqlalchemy.testing`` package for external dialects. + + .. change:: :tags: bug, orm :tickets: 2750 diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst index af3081687..466444278 100644 --- a/doc/build/changelog/changelog_09.rst +++ b/doc/build/changelog/changelog_09.rst @@ -7,6 +7,16 @@ :version: 0.9.0 .. change:: + :tags: requirements + + The Python `mock <https://pypi.python.org/pypi/mock>`_ library + is now required in order to run the unit test suite. While part + of the standard library as of Python 3.3, previous Python installations + will need to install this in order to run unit tests or to + use the ``sqlalchemy.testing`` package for external dialects. + This applies to 0.8.2 as well. + + .. change:: :tags: bug, orm :tickets: 2750 diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index d5522213d..a87829499 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -18,3 +18,5 @@ from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict crashes = skip from .config import db, requirements as requires + +from . import mock
\ No newline at end of file diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py new file mode 100644 index 000000000..650962384 --- /dev/null +++ b/lib/sqlalchemy/testing/mock.py @@ -0,0 +1,15 @@ +"""Import stub for mock library. +""" +from __future__ import absolute_import +from ..util import py33 + +if py33: + from unittest.mock import MagicMock, Mock, call +else: + try: + from mock import MagicMock, Mock, call + except ImportError: + raise ImportError( + "SQLAlchemy's test suite requires the " + "'mock' library as of 0.8.2.") + diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 687abb39a..739caefe0 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -5,7 +5,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from .compat import callable, cmp, reduce, \ - threading, py3k, py2k, jython, pypy, cpython, win32, \ + threading, py3k, py33, py2k, jython, pypy, cpython, win32, \ pickle, dottedgetter, parse_qsl, namedtuple, next, WeakSet, reraise, \ raise_from_cause, text_type, string_types, int_types, binary_type, \ quote_plus, with_metaclass, print_, itertools_filterfalse, u, ue, b,\ diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index fea22873c..d866534ab 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -13,6 +13,7 @@ try: except ImportError: import dummy_threading as threading +py33 = sys.version_info >= (3, 3) py32 = sys.version_info >= (3, 2) py3k = sys.version_info >= (3, 0) py2k = sys.version_info < (3, 0) @@ -119,8 +119,7 @@ def run_setup(with_cext): package_dir={'': 'lib'}, license="MIT License", cmdclass=cmdclass, - - tests_require=['nose >= 0.11'], + tests_require=['nose >= 0.11', 'mock'], test_suite="sqla_nose", long_description=readme, classifiers=[ diff --git a/test/aaa_profiling/test_resultset.py b/test/aaa_profiling/test_resultset.py index 27e60410d..bbd8c4dba 100644 --- a/test/aaa_profiling/test_resultset.py +++ b/test/aaa_profiling/test_resultset.py @@ -3,6 +3,9 @@ from sqlalchemy.testing import fixtures, AssertsExecutionResults, profiling from sqlalchemy import testing from sqlalchemy.testing import eq_ from sqlalchemy.util import u +from sqlalchemy.engine.result import RowProxy +import sys + NUM_FIELDS = 10 NUM_RECORDS = 1000 @@ -80,7 +83,6 @@ class RowProxyTest(fixtures.TestBase): __requires__ = 'cpython', def _rowproxy_fixture(self, keys, processors, row): - from sqlalchemy.engine.result import RowProxy class MockMeta(object): def __init__(self): pass @@ -96,7 +98,6 @@ class RowProxyTest(fixtures.TestBase): return RowProxy(metadata, row, processors, keymap) def _test_getitem_value_refcounts(self, seq_factory): - import sys col1, col2 = object(), object() def proc1(value): return value diff --git a/test/base/test_events.py b/test/base/test_events.py index 7cfb5fa7d..20bfa62ff 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -5,6 +5,8 @@ from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, \ from sqlalchemy import event, exc from sqlalchemy.testing import fixtures from sqlalchemy.testing.util import gc_collect +from sqlalchemy.testing.mock import Mock, call + class EventsTest(fixtures.TestBase): """Test class- and instance-level event registration.""" @@ -190,7 +192,7 @@ class ClsLevelListenTest(fixtures.TestBase): def test_lis_subcalss_lis(self): @event.listens_for(self.TargetOne, "event_one") def handler1(x, y): - print('handler1') + pass class SubTarget(self.TargetOne): pass @@ -207,7 +209,7 @@ class ClsLevelListenTest(fixtures.TestBase): def test_lis_multisub_lis(self): @event.listens_for(self.TargetOne, "event_one") def handler1(x, y): - print('handler1') + pass class SubTarget(self.TargetOne): pass @@ -411,12 +413,8 @@ class ListenOverrideTest(fixtures.TestBase): event._remove_dispatcher(self.Target.__dict__['dispatch'].events) def test_listen_override(self): - result = [] - def listen_one(x): - result.append(x) - - def listen_two(x, y): - result.append((x, y)) + listen_one = Mock() + listen_two = Mock() event.listen(self.Target, "event_one", listen_one, add=True) event.listen(self.Target, "event_one", listen_two) @@ -425,10 +423,13 @@ class ListenOverrideTest(fixtures.TestBase): t1.dispatch.event_one(5, 7) t1.dispatch.event_one(10, 5) - eq_(result, - [ - 12, (5, 7), 15, (10, 5) - ] + eq_( + listen_one.mock_calls, + [call(12), call(15)] + ) + eq_( + listen_two.mock_calls, + [call(5, 7), call(10, 5)] ) class PropagateTest(fixtures.TestBase): @@ -446,12 +447,8 @@ class PropagateTest(fixtures.TestBase): def test_propagate(self): - result = [] - def listen_one(target, arg): - result.append((target, arg)) - - def listen_two(target, arg): - result.append((target, arg)) + listen_one = Mock() + listen_two = Mock() t1 = self.Target() @@ -464,7 +461,15 @@ class PropagateTest(fixtures.TestBase): t2.dispatch.event_one(t2, 1) t2.dispatch.event_two(t2, 2) - eq_(result, [(t2, 1)]) + + eq_( + listen_one.mock_calls, + [call(t2, 1)] + ) + eq_( + listen_two.mock_calls, + [] + ) class JoinTest(fixtures.TestBase): def setUp(self): @@ -497,12 +502,6 @@ class JoinTest(fixtures.TestBase): if 'dispatch' in cls.__dict__: event._remove_dispatcher(cls.__dict__['dispatch'].events) - def _listener(self): - canary = [] - def listen(target, arg): - canary.append((target, arg)) - return listen, canary - def test_neither(self): element = self.TargetFactory().create() element.run_event(1) @@ -510,22 +509,22 @@ class JoinTest(fixtures.TestBase): element.run_event(3) def test_parent_class_only(self): - _listener, canary = self._listener() + l1 = Mock() - event.listen(self.TargetFactory, "event_one", _listener) + event.listen(self.TargetFactory, "event_one", l1) element = self.TargetFactory().create() element.run_event(1) element.run_event(2) element.run_event(3) eq_( - canary, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) def test_parent_class_child_class(self): - l1, c1 = self._listener() - l2, c2 = self._listener() + l1 = Mock() + l2 = Mock() event.listen(self.TargetFactory, "event_one", l1) event.listen(self.TargetElement, "event_one", l2) @@ -535,17 +534,17 @@ class JoinTest(fixtures.TestBase): element.run_event(2) element.run_event(3) eq_( - c1, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) eq_( - c2, - [(element, 1), (element, 2), (element, 3)] + l2.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) def test_parent_class_child_instance_apply_after(self): - l1, c1 = self._listener() - l2, c2 = self._listener() + l1 = Mock() + l2 = Mock() event.listen(self.TargetFactory, "event_one", l1) element = self.TargetFactory().create() @@ -557,17 +556,17 @@ class JoinTest(fixtures.TestBase): element.run_event(3) eq_( - c1, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) eq_( - c2, - [(element, 2), (element, 3)] + l2.mock_calls, + [call(element, 2), call(element, 3)] ) def test_parent_class_child_instance_apply_before(self): - l1, c1 = self._listener() - l2, c2 = self._listener() + l1 = Mock() + l2 = Mock() event.listen(self.TargetFactory, "event_one", l1) element = self.TargetFactory().create() @@ -579,17 +578,17 @@ class JoinTest(fixtures.TestBase): element.run_event(3) eq_( - c1, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) eq_( - c2, - [(element, 1), (element, 2), (element, 3)] + l2.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) def test_parent_instance_child_class_apply_before(self): - l1, c1 = self._listener() - l2, c2 = self._listener() + l1 = Mock() + l2 = Mock() event.listen(self.TargetElement, "event_one", l2) @@ -603,17 +602,18 @@ class JoinTest(fixtures.TestBase): element.run_event(3) eq_( - c1, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) eq_( - c2, - [(element, 1), (element, 2), (element, 3)] + l2.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) + def test_parent_instance_child_class_apply_after(self): - l1, c1 = self._listener() - l2, c2 = self._listener() + l1 = Mock() + l2 = Mock() event.listen(self.TargetElement, "event_one", l2) @@ -632,18 +632,16 @@ class JoinTest(fixtures.TestBase): # this can be changed to be "live" at the cost # of performance. eq_( - c1, - [] - #(element, 2), (element, 3)] + l1.mock_calls, [] ) eq_( - c2, - [(element, 1), (element, 2), (element, 3)] + l2.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) def test_parent_instance_child_instance_apply_before(self): - l1, c1 = self._listener() - l2, c2 = self._listener() + l1 = Mock() + l2 = Mock() factory = self.TargetFactory() event.listen(factory, "event_one", l1) @@ -656,16 +654,16 @@ class JoinTest(fixtures.TestBase): element.run_event(3) eq_( - c1, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) eq_( - c2, - [(element, 1), (element, 2), (element, 3)] + l2.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) def test_parent_events_child_no_events(self): - l1, c1 = self._listener() + l1 = Mock() factory = self.TargetFactory() event.listen(self.TargetElement, "event_one", l1) @@ -676,6 +674,6 @@ class JoinTest(fixtures.TestBase): element.run_event(3) eq_( - c1, - [(element, 1), (element, 2), (element, 3)] + l1.mock_calls, + [call(element, 1), call(element, 2), call(element, 3)] ) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 86ce91dc9..1fc239cb7 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -16,6 +16,7 @@ from sqlalchemy import exc, schema from sqlalchemy.dialects.postgresql import base as postgresql import logging import logging.handlers +from sqlalchemy.testing.mock import Mock class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): @@ -37,18 +38,12 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): 'The JDBC driver handles the version parsing') def test_version_parsing(self): - - class MockConn(object): - - def __init__(self, res): - self.res = res - - def execute(self, str): - return self - - def scalar(self): - return self.res - + def mock_conn(res): + return Mock( + execute=Mock( + return_value=Mock(scalar=Mock(return_value=res)) + ) + ) for string, version in \ [('PostgreSQL 8.3.8 on i686-redhat-linux-gnu, compiled by ' @@ -59,7 +54,7 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): ('EnterpriseDB 9.1.2.2 on x86_64-unknown-linux-gnu, ' 'compiled by gcc (GCC) 4.1.2 20080704 (Red Hat 4.1.2-50), ' '64-bit', (9, 1, 2))]: - eq_(testing.db.dialect._get_server_version_info(MockConn(string)), + eq_(testing.db.dialect._get_server_version_info(mock_conn(string)), version) @testing.only_on('postgresql+psycopg2', 'psycopg2-specific feature') diff --git a/test/dialect/test_mxodbc.py b/test/dialect/test_mxodbc.py index 32cad4168..e46de9149 100644 --- a/test/dialect/test_mxodbc.py +++ b/test/dialect/test_mxodbc.py @@ -2,75 +2,48 @@ from sqlalchemy import * from sqlalchemy.testing import eq_ from sqlalchemy.testing import engines from sqlalchemy.testing import fixtures - -# TODO: we should probably build mock bases for -# these to share with test_reconnect, test_parseconnect -class MockDBAPI(object): - paramstyle = 'qmark' - def __init__(self): - self.log = [] - def connect(self, *args, **kwargs): - return MockConnection(self) - -class MockConnection(object): - def __init__(self, parent): - self.parent = parent - def cursor(self): - return MockCursor(self) - def close(self): - pass - def rollback(self): - pass - def commit(self): - pass - -class MockCursor(object): - description = None - rowcount = None - def __init__(self, parent): - self.parent = parent - def execute(self, *args, **kwargs): - if kwargs.get('direct', False): - self.executedirect() - else: - self.parent.parent.log.append('execute') - def executedirect(self, *args, **kwargs): - self.parent.parent.log.append('executedirect') - def close(self): - pass +from sqlalchemy.testing.mock import Mock + +def mock_dbapi(): + return Mock(paramstyle='qmark', + connect=Mock( + return_value=Mock( + cursor=Mock( + return_value=Mock( + description=None, + rowcount=None) + ) + ) + ) + ) class MxODBCTest(fixtures.TestBase): def test_native_odbc_execute(self): t1 = Table('t1', MetaData(), Column('c1', Integer)) - dbapi = MockDBAPI() + dbapi = mock_dbapi() + engine = engines.testing_engine('mssql+mxodbc://localhost', options={'module': dbapi, '_initialize': False}) conn = engine.connect() # crud: uses execute - conn.execute(t1.insert().values(c1='foo')) conn.execute(t1.delete().where(t1.c.c1 == 'foo')) - conn.execute(t1.update().where(t1.c.c1 == 'foo').values(c1='bar' - )) + conn.execute(t1.update().where(t1.c.c1 == 'foo').values(c1='bar')) # select: uses executedirect - conn.execute(t1.select()) # manual flagging - conn.execution_options(native_odbc_execute=True).\ execute(t1.select()) conn.execution_options(native_odbc_execute=False).\ - execute(t1.insert().values(c1='foo' - )) - eq_(dbapi.log, [ - 'executedirect', - 'executedirect', - 'executedirect', - 'executedirect', - 'execute', - 'executedirect', - ]) + execute(t1.insert().values(c1='foo')) + + eq_( + [c[2] for c in + dbapi.connect.return_value.cursor.return_value.execute.mock_calls], + [{'direct': True}, {'direct': True}, {'direct': True}, + {'direct': True}, {'direct': False}, {'direct': True}] + ) diff --git a/test/engine/test_ddlemit.py b/test/engine/test_ddlemit.py index deaf09cf7..e773d0ced 100644 --- a/test/engine/test_ddlemit.py +++ b/test/engine/test_ddlemit.py @@ -3,28 +3,19 @@ from sqlalchemy.engine.ddl import SchemaGenerator, SchemaDropper from sqlalchemy.engine import default from sqlalchemy import MetaData, Table, Column, Integer, Sequence from sqlalchemy import schema +from sqlalchemy.testing.mock import Mock class EmitDDLTest(fixtures.TestBase): def _mock_connection(self, item_exists): - _canary = [] + def has_item(connection, name, schema): + return item_exists(name) - class MockDialect(default.DefaultDialect): - supports_sequences = True - - def has_table(self, connection, name, schema): - return item_exists(name) - - def has_sequence(self, connection, name, schema): - return item_exists(name) - - class MockConnection(object): - dialect = MockDialect() - canary = _canary - - def execute(self, item): - _canary.append(item) - - return MockConnection() + return Mock(dialect=Mock( + supports_sequences=True, + has_table=Mock(side_effect=has_item), + has_sequence=Mock(side_effect=has_item) + ) + ) def _mock_create_fixture(self, checkfirst, tables, item_exists=lambda item: False): @@ -176,7 +167,8 @@ class EmitDDLTest(fixtures.TestBase): def _assert_ddl(self, ddl_cls, elements, generator, argument): generator.traverse_single(argument) - for c in generator.connection.canary: + for call_ in generator.connection.execute.mock_calls: + c = call_[1][0] assert isinstance(c, ddl_cls) assert c.element in elements, "element %r was not expected"\ % c.element diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 1c577730b..9795e4c10 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -19,6 +19,8 @@ from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam from sqlalchemy.engine import result as _result, default from sqlalchemy.engine.base import Connection, Engine from sqlalchemy.testing import fixtures +from sqlalchemy.testing.mock import Mock, call + users, metadata, users_autoinc = None, None, None class ExecuteTest(fixtures.TestBase): @@ -455,20 +457,22 @@ class ConvenienceExecuteTest(fixtures.TablesTest): def test_transaction_engine_ctx_begin_fails(self): engine = engines.testing_engine() - class MockConnection(Connection): - closed = False - def begin(self): - raise Exception("boom") - - def close(self): - MockConnection.closed = True - engine._connection_cls = MockConnection - fn = self._trans_fn() + + mock_connection = Mock( + return_value=Mock( + begin=Mock(side_effect=Exception("boom")) + ) + ) + engine._connection_cls = mock_connection assert_raises( Exception, engine.begin ) - assert MockConnection.closed + + eq_( + mock_connection.return_value.close.mock_calls, + [call()] + ) def test_transaction_engine_ctx_rollback(self): fn = self._trans_rollback_fn() diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py index 73bdc76c4..106bd0782 100644 --- a/test/engine/test_parseconnect.py +++ b/test/engine/test_parseconnect.py @@ -7,6 +7,8 @@ from sqlalchemy.engine.default import DefaultDialect import sqlalchemy as tsa from sqlalchemy.testing import fixtures from sqlalchemy import testing +from sqlalchemy.testing.mock import Mock + class ParseConnectTest(fixtures.TestBase): def test_rfc1738(self): @@ -250,20 +252,17 @@ pool_timeout=10 every backend. """ - # pretend pysqlite throws the - # "Cannot operate on a closed database." error - # on connect. IRL we'd be getting Oracle's "shutdown in progress" e = create_engine('sqlite://') sqlite3 = e.dialect.dbapi - class ThrowOnConnect(MockDBAPI): - dbapi = sqlite3 - Error = sqlite3.Error - ProgrammingError = sqlite3.ProgrammingError - def connect(self, *args, **kw): - raise sqlite3.ProgrammingError("Cannot operate on a closed database.") + + dbapi = MockDBAPI() + dbapi.Error = sqlite3.Error, + dbapi.ProgrammingError = sqlite3.ProgrammingError + dbapi.connect = Mock(side_effect=sqlite3.ProgrammingError( + "Cannot operate on a closed database.")) try: - create_engine('sqlite://', module=ThrowOnConnect()).connect() + create_engine('sqlite://', module=dbapi).connect() assert False except tsa.exc.DBAPIError as de: assert de.connection_invalidated @@ -354,36 +353,23 @@ class MockDialect(DefaultDialect): def dbapi(cls, **kw): return MockDBAPI() -class MockDBAPI(object): - version_info = sqlite_version_info = 99, 9, 9 - sqlite_version = '99.9.9' - - def __init__(self, **kwargs): - self.kwargs = kwargs - self.paramstyle = 'named' - - def connect(self, *args, **kwargs): - for k in self.kwargs: +def MockDBAPI(**assert_kwargs): + connection = Mock(get_server_version_info=Mock(return_value='5.0')) + def connect(*args, **kwargs): + for k in assert_kwargs: assert k in kwargs, 'key %s not present in dictionary' % k - assert kwargs[k] == self.kwargs[k], \ - 'value %s does not match %s' % (kwargs[k], - self.kwargs[k]) - return MockConnection() - - -class MockConnection(object): - def get_server_info(self): - return '5.0' - - def close(self): - pass - - def cursor(self): - return MockCursor() - -class MockCursor(object): - def close(self): - pass + eq_( + kwargs[k], assert_kwargs[k] + ) + return connection + + return Mock( + sqlite_version_info=(99, 9, 9,), + version_info=(99, 9, 9,), + sqlite_version='99.9.9', + paramstyle='named', + connect=Mock(side_effect=connect) + ) mock_dbapi = MockDBAPI() mock_sqlite_dbapi = msd = MockDBAPI() diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 583978465..981df6dd0 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -4,37 +4,21 @@ from sqlalchemy import pool, select, event import sqlalchemy as tsa from sqlalchemy import testing from sqlalchemy.testing.util import gc_collect, lazy_gc -from sqlalchemy.testing import eq_, assert_raises +from sqlalchemy.testing import eq_, assert_raises, is_not_ from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing import fixtures -mcid = 1 -class MockDBAPI(object): - throw_error = False - def connect(self, *args, **kwargs): - if self.throw_error: - raise Exception("couldnt connect !") - delay = kwargs.pop('delay', 0) - if delay: - time.sleep(delay) - return MockConnection() -class MockConnection(object): - closed = False - def __init__(self): - global mcid - self.id = mcid - mcid += 1 - def close(self): - self.closed = True - def rollback(self): - pass - def cursor(self): - return MockCursor() -class MockCursor(object): - def execute(self, *args, **kw): - pass - def close(self): - pass +from sqlalchemy.testing.mock import Mock, call + +def MockDBAPI(): + def cursor(): + while True: + yield Mock() + def connect(): + while True: + yield Mock(cursor=Mock(side_effect=cursor())) + + return Mock(connect=Mock(side_effect=connect())) class PoolTestBase(fixtures.TestBase): def setup(self): @@ -71,11 +55,9 @@ class PoolTest(PoolTestBase): assert c4 is not c5 def test_manager_with_key(self): - class NoKws(object): - def connect(self, arg): - return MockConnection() - manager = pool.manage(NoKws(), use_threadlocal=True) + dbapi = MockDBAPI() + manager = pool.manage(dbapi, use_threadlocal=True) c1 = manager.connect('foo.db', sa_pool_key="a") c2 = manager.connect('foo.db', sa_pool_key="b") @@ -83,9 +65,14 @@ class PoolTest(PoolTestBase): assert c1.cursor() is not None assert c1 is not c2 - assert c1 is c3 - + assert c1 is c3 + eq_(dbapi.connect.mock_calls, + [ + call("foo.db"), + call("foo.db"), + ] + ) def test_bad_args(self): @@ -127,7 +114,7 @@ class PoolTest(PoolTestBase): p = cls(creator=mock_dbapi.connect) conn = p.connect() conn.close() - mock_dbapi.throw_error = True + mock_dbapi.connect.side_effect = Exception("error!") p.dispose() p.recreate() @@ -211,9 +198,9 @@ class PoolTest(PoolTestBase): self.assert_('foo2' in c.info) c2 = p.connect() - self.assert_(c.connection is not c2.connection) - self.assert_(not c2.info) - self.assert_('foo2' in c.info) + is_not_(c.connection, c2.connection) + assert not c2.info + assert 'foo2' in c.info class PoolDialectTest(PoolTestBase): @@ -945,19 +932,24 @@ class QueuePoolTest(PoolTestBase): def test_dispose_closes_pooled(self): dbapi = MockDBAPI() - def creator(): - return dbapi.connect() - p = pool.QueuePool(creator=creator, + p = pool.QueuePool(creator=dbapi.connect, pool_size=2, timeout=None, max_overflow=0) c1 = p.connect() c2 = p.connect() - conns = [c1.connection, c2.connection] + c1_con = c1.connection + c2_con = c2.connection + c1.close() - eq_([c.closed for c in conns], [False, False]) + + eq_(c1_con.close.call_count, 0) + eq_(c2_con.close.call_count, 0) + p.dispose() - eq_([c.closed for c in conns], [True, False]) + + eq_(c1_con.close.call_count, 1) + eq_(c2_con.close.call_count, 0) # currently, if a ConnectionFairy is closed # after the pool has been disposed, there's no @@ -965,11 +957,12 @@ class QueuePoolTest(PoolTestBase): # immediately - it just gets returned to the # pool normally... c2.close() - eq_([c.closed for c in conns], [True, False]) + eq_(c1_con.close.call_count, 1) + eq_(c2_con.close.call_count, 0) # ...and that's the one we'll get back next. c3 = p.connect() - assert c3.connection is conns[1] + assert c3.connection is c2_con def test_no_overflow(self): self._test_overflow(40, 0) @@ -1010,12 +1003,21 @@ class QueuePoolTest(PoolTestBase): return c for j in range(5): + # open 4 conns at a time. each time this + # will yield two pooled connections + two + # overflow connections. conns = [_conn() for i in range(4)] for c in conns: c.close() - still_opened = len([c for c in strong_refs if not c.closed]) - eq_(still_opened, 2) + # doing that for a total of 5 times yields + # ten overflow connections closed plus the + # two pooled connections unclosed. + + eq_( + set([c.close.call_count for c in strong_refs]), + set([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0]) + ) @testing.requires.predictable_gc def test_weakref_kaboom(self): @@ -1108,18 +1110,30 @@ class QueuePoolTest(PoolTestBase): dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0) c1 = p.connect() c1.detach() - c_id = c1.connection.id - c2 = p.connect() - assert c2.connection.id != c1.connection.id - dbapi.raise_error = True - c2.invalidate() - c2 = None c2 = p.connect() - assert c2.connection.id != c1.connection.id - con = c1.connection - assert not con.closed + eq_(dbapi.connect.mock_calls, [call("foo.db"), call("foo.db")]) + + c1_con = c1.connection + assert c1_con is not None + eq_(c1_con.close.call_count, 0) c1.close() - assert con.closed + eq_(c1_con.close.call_count, 1) + + def test_detach_via_invalidate(self): + dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0) + + c1 = p.connect() + c1_con = c1.connection + c1.invalidate() + assert c1.connection is None + eq_(c1_con.close.call_count, 1) + + c2 = p.connect() + assert c2.connection is not c1_con + c2_con = c2.connection + + c2.close() + eq_(c2_con.close.call_count, 0) def test_threadfairy(self): p = self._queuepool_fixture(pool_size=3, max_overflow=-1, use_threadlocal=True) @@ -1141,8 +1155,13 @@ class SingletonThreadPoolTest(PoolTestBase): been called.""" dbapi = MockDBAPI() - p = pool.SingletonThreadPool(creator=dbapi.connect, - pool_size=3) + + lock = threading.Lock() + def creator(): + # the mock iterator isn't threadsafe... + with lock: + return dbapi.connect() + p = pool.SingletonThreadPool(creator=creator, pool_size=3) if strong_refs: sr = set() @@ -1172,7 +1191,7 @@ class SingletonThreadPoolTest(PoolTestBase): assert len(p._all_conns) == 3 if strong_refs: - still_opened = len([c for c in sr if not c.closed]) + still_opened = len([c for c in sr if not c.close.call_count]) eq_(still_opened, 3) class AssertionPoolTest(PoolTestBase): @@ -1198,17 +1217,19 @@ class NullPoolTest(PoolTestBase): dbapi = MockDBAPI() p = pool.NullPool(creator=lambda: dbapi.connect('foo.db')) c1 = p.connect() - c_id = c1.connection.id + c1.close() c1 = None c1 = p.connect() - dbapi.raise_error = True c1.invalidate() c1 = None c1 = p.connect() - assert c1.connection.id != c_id + dbapi.connect.assert_has_calls([ + call('foo.db'), + call('foo.db')], + any_order=True) class StaticPoolTest(PoolTestBase): diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index ee3ff1459..86003bec6 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1,7 +1,6 @@ from sqlalchemy.testing import eq_, assert_raises, assert_raises_message import time -import weakref -from sqlalchemy import select, MetaData, Integer, String, pool, create_engine +from sqlalchemy import select, MetaData, Integer, String, create_engine, pool from sqlalchemy.testing.schema import Table, Column import sqlalchemy as tsa from sqlalchemy import testing @@ -10,6 +9,8 @@ from sqlalchemy.testing.util import gc_collect from sqlalchemy import exc, util from sqlalchemy.testing import fixtures from sqlalchemy.testing.engines import testing_engine +from sqlalchemy.testing import is_not_ +from sqlalchemy.testing.mock import Mock, call class MockError(Exception): pass @@ -17,93 +18,103 @@ class MockError(Exception): class MockDisconnect(MockError): pass -class MockDBAPI(object): - def __init__(self): - self.paramstyle = 'named' - self.connections = weakref.WeakKeyDictionary() - def connect(self, *args, **kwargs): - return MockConnection(self) - def shutdown(self, explode='execute'): - for c in self.connections: - c.explode = explode - Error = MockError - -class MockConnection(object): - def __init__(self, dbapi): - dbapi.connections[self] = True - self.explode = "" - def rollback(self): - if self.explode == 'rollback': +def mock_connection(): + def mock_cursor(): + def execute(*args, **kwargs): + if conn.explode == 'execute': + raise MockDisconnect("Lost the DB connection on execute") + elif conn.explode in ('execute_no_disconnect', ): + raise MockError( + "something broke on execute but we didn't lose the connection") + elif conn.explode in ('rollback', 'rollback_no_disconnect'): + raise MockError( + "something broke on execute but we didn't lose the connection") + elif args and "SELECT" in args[0]: + cursor.description = [('foo', None, None, None, None, None)] + else: + return + + def close(): + cursor.fetchall = cursor.fetchone = \ + Mock(side_effect=MockError("cursor closed")) + cursor = Mock( + execute=Mock(side_effect=execute), + close=Mock(side_effect=close) + ) + return cursor + + def cursor(): + while True: + yield mock_cursor() + + def rollback(): + if conn.explode == 'rollback': raise MockDisconnect("Lost the DB connection on rollback") - if self.explode == 'rollback_no_disconnect': + if conn.explode == 'rollback_no_disconnect': raise MockError( "something broke on rollback but we didn't lose the connection") else: return - def commit(self): - pass - def cursor(self): - return MockCursor(self) - def close(self): - pass - -class MockCursor(object): - def __init__(self, parent): - self.explode = parent.explode - self.description = () - self.closed = False - def execute(self, *args, **kwargs): - if self.explode == 'execute': - raise MockDisconnect("Lost the DB connection on execute") - elif self.explode in ('execute_no_disconnect', ): - raise MockError( - "something broke on execute but we didn't lose the connection") - elif self.explode in ('rollback', 'rollback_no_disconnect'): - raise MockError( - "something broke on execute but we didn't lose the connection") - elif args and "select" in args[0]: - self.description = [('foo', None, None, None, None, None)] - else: - return - def fetchall(self): - if self.closed: - raise MockError("cursor closed") - return [] - def fetchone(self): - if self.closed: - raise MockError("cursor closed") - return None - def close(self): - self.closed = True - -db, dbapi = None, None + + conn = Mock( + rollback=Mock(side_effect=rollback), + cursor=Mock(side_effect=cursor()) + ) + return conn + +def MockDBAPI(): + connections = [] + def connect(): + while True: + conn = mock_connection() + connections.append(conn) + yield conn + + def shutdown(explode='execute'): + for c in connections: + c.explode = explode + + def dispose(): + for c in connections: + c.explode = None + connections[:] = [] + + return Mock( + connect=Mock(side_effect=connect()), + shutdown=Mock(side_effect=shutdown), + dispose=Mock(side_effect=dispose), + paramstyle='named', + connections=connections, + Error=MockError + ) + + class MockReconnectTest(fixtures.TestBase): def setup(self): - global db, dbapi - dbapi = MockDBAPI() + self.dbapi = MockDBAPI() - # note - using straight create_engine here - # since we are testing gc - db = create_engine( + self.db = testing_engine( 'postgresql://foo:bar@localhost/test', - module=dbapi, _initialize=False) + options=dict(module=self.dbapi, _initialize=False)) + self.mock_connect = call(host='localhost', password='bar', + user='foo', database='test') # monkeypatch disconnect checker - db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect) + self.db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect) def teardown(self): - db.dispose() + self.dbapi.dispose() def test_reconnect(self): """test that an 'is_disconnect' condition will invalidate the connection, and additionally dispose the previous connection pool and recreate.""" - pid = id(db.pool) + db_pool = self.db.pool # make a connection - conn = db.connect() + conn = self.db.connect() # connection works @@ -112,21 +123,20 @@ class MockReconnectTest(fixtures.TestBase): # create a second connection within the pool, which we'll ensure # also goes away - conn2 = db.connect() + conn2 = self.db.connect() conn2.close() # two connections opened total now - assert len(dbapi.connections) == 2 + assert len(self.dbapi.connections) == 2 # set it to fail - dbapi.shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError: - pass + self.dbapi.shutdown() + assert_raises( + tsa.exc.DBAPIError, + conn.execute, select([1]) + ) # assert was invalidated @@ -136,31 +146,38 @@ class MockReconnectTest(fixtures.TestBase): # close shouldnt break conn.close() - assert id(db.pool) != pid + is_not_(self.db.pool, db_pool) # ensure all connections closed (pool was recycled) - gc_collect() - assert len(dbapi.connections) == 0 - conn = db.connect() + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], [call()]] + ) + + conn = self.db.connect() conn.execute(select([1])) conn.close() - assert len(dbapi.connections) == 1 + + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], [call()], []] + ) def test_invalidate_trans(self): - conn = db.connect() + conn = self.db.connect() trans = conn.begin() - dbapi.shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError: - pass + self.dbapi.shutdown() - # assert was invalidated + assert_raises( + tsa.exc.DBAPIError, + conn.execute, select([1]) + ) - gc_collect() - assert len(dbapi.connections) == 0 + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()]] + ) assert not conn.closed assert conn.invalidated assert trans.is_active @@ -170,28 +187,35 @@ class MockReconnectTest(fixtures.TestBase): conn.execute, select([1]) ) assert trans.is_active - try: - trans.commit() - assert False - except tsa.exc.InvalidRequestError as e: - assert str(e) \ - == "Can't reconnect until invalid transaction is "\ - "rolled back" + + assert_raises_message( + tsa.exc.InvalidRequestError, + "Can't reconnect until invalid transaction is " + "rolled back", + trans.commit + ) + assert trans.is_active trans.rollback() assert not trans.is_active conn.execute(select([1])) assert not conn.invalidated - assert len(dbapi.connections) == 1 + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], []] + ) def test_conn_reusable(self): - conn = db.connect() + conn = self.db.connect() conn.execute(select([1])) - assert len(dbapi.connections) == 1 + eq_( + self.dbapi.connect.mock_calls, + [self.mock_connect] + ) - dbapi.shutdown() + self.dbapi.shutdown() assert_raises( tsa.exc.DBAPIError, @@ -201,19 +225,24 @@ class MockReconnectTest(fixtures.TestBase): assert not conn.closed assert conn.invalidated - # ensure all connections closed (pool was recycled) - gc_collect() - assert len(dbapi.connections) == 0 + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()]] + ) # test reconnects conn.execute(select([1])) assert not conn.invalidated - assert len(dbapi.connections) == 1 + + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], []] + ) def test_invalidated_close(self): - conn = db.connect() + conn = self.db.connect() - dbapi.shutdown() + self.dbapi.shutdown() assert_raises( tsa.exc.DBAPIError, @@ -230,9 +259,9 @@ class MockReconnectTest(fixtures.TestBase): ) def test_noreconnect_execute_plus_closewresult(self): - conn = db.connect(close_with_result=True) + conn = self.db.connect(close_with_result=True) - dbapi.shutdown("execute_no_disconnect") + self.dbapi.shutdown("execute_no_disconnect") # raises error assert_raises_message( @@ -245,9 +274,9 @@ class MockReconnectTest(fixtures.TestBase): assert not conn.invalidated def test_noreconnect_rollback_plus_closewresult(self): - conn = db.connect(close_with_result=True) + conn = self.db.connect(close_with_result=True) - dbapi.shutdown("rollback_no_disconnect") + self.dbapi.shutdown("rollback_no_disconnect") # raises error assert_raises_message( @@ -266,13 +295,13 @@ class MockReconnectTest(fixtures.TestBase): ) def test_reconnect_on_reentrant(self): - conn = db.connect() + conn = self.db.connect() conn.execute(select([1])) - assert len(dbapi.connections) == 1 + assert len(self.dbapi.connections) == 1 - dbapi.shutdown("rollback") + self.dbapi.shutdown("rollback") # raises error assert_raises_message( @@ -285,9 +314,9 @@ class MockReconnectTest(fixtures.TestBase): assert conn.invalidated def test_reconnect_on_reentrant_plus_closewresult(self): - conn = db.connect(close_with_result=True) + conn = self.db.connect(close_with_result=True) - dbapi.shutdown("rollback") + self.dbapi.shutdown("rollback") # raises error assert_raises_message( @@ -306,10 +335,11 @@ class MockReconnectTest(fixtures.TestBase): ) def test_check_disconnect_no_cursor(self): - conn = db.connect() - result = conn.execute("select 1") + conn = self.db.connect() + result = conn.execute(select([1])) result.cursor.close() conn.close() + assert_raises_message( tsa.exc.DBAPIError, "cursor closed", @@ -319,60 +349,59 @@ class MockReconnectTest(fixtures.TestBase): class CursorErrTest(fixtures.TestBase): def setup(self): - global db, dbapi - - class MDBAPI(MockDBAPI): - def connect(self, *args, **kwargs): - return MConn(self) - - class MConn(MockConnection): - def cursor(self): - return MCursor(self) + def MockDBAPI(): + def cursor(): + while True: + yield Mock( + description=[], + close=Mock(side_effect=Exception("explode"))) + def connect(): + while True: + yield Mock(cursor=Mock(side_effect=cursor())) + + return Mock(connect=Mock(side_effect=connect())) - class MCursor(MockCursor): - def close(self): - raise Exception("explode") - - dbapi = MDBAPI() - - db = testing_engine( + dbapi = MockDBAPI() + self.db = testing_engine( 'postgresql://foo:bar@localhost/test', options=dict(module=dbapi, _initialize=False)) def test_cursor_explode(self): - conn = db.connect() + conn = self.db.connect() result = conn.execute("select foo") result.close() conn.close() def teardown(self): - db.dispose() + self.db.dispose() + + +def _assert_invalidated(fn, *args): + try: + fn(*args) + assert False + except tsa.exc.DBAPIError as e: + if not e.connection_invalidated: + raise -engine = None class RealReconnectTest(fixtures.TestBase): def setup(self): - global engine - engine = engines.reconnecting_engine() + self.engine = engines.reconnecting_engine() def teardown(self): - engine.dispose() + self.engine.dispose() @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_reconnect(self): - conn = engine.connect() + conn = self.engine.connect() eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed - engine.test_shutdown() + self.engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + _assert_invalidated(conn.execute, select([1])) assert not conn.closed assert conn.invalidated @@ -382,13 +411,9 @@ class RealReconnectTest(fixtures.TestBase): assert not conn.invalidated # one more time - engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + self.engine.test_shutdown() + _assert_invalidated(conn.execute, select([1])) + assert conn.invalidated eq_(conn.execute(select([1])).scalar(), 1) assert not conn.invalidated @@ -396,30 +421,22 @@ class RealReconnectTest(fixtures.TestBase): conn.close() def test_multiple_invalidate(self): - c1 = engine.connect() - c2 = engine.connect() + c1 = self.engine.connect() + c2 = self.engine.connect() eq_(c1.execute(select([1])).scalar(), 1) - p1 = engine.pool - engine.test_shutdown() + p1 = self.engine.pool + self.engine.test_shutdown() - try: - c1.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - assert e.connection_invalidated + _assert_invalidated(c1.execute, select([1])) - p2 = engine.pool + p2 = self.engine.pool - try: - c2.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - assert e.connection_invalidated + _assert_invalidated(c2.execute, select([1])) # pool isn't replaced - assert engine.pool is p2 + assert self.engine.pool is p2 def test_ensure_is_disconnect_gets_connection(self): @@ -430,37 +447,37 @@ class RealReconnectTest(fixtures.TestBase): # though MySQLdb we get a non-working cursor. # assert cursor is None - engine.dialect.is_disconnect = is_disconnect - conn = engine.connect() - engine.test_shutdown() + self.engine.dialect.is_disconnect = is_disconnect + conn = self.engine.connect() + self.engine.test_shutdown() assert_raises( tsa.exc.DBAPIError, conn.execute, select([1]) ) def test_rollback_on_invalid_plain(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin() conn.invalidate() trans.rollback() @testing.requires.two_phase_transactions def test_rollback_on_invalid_twophase(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin_twophase() conn.invalidate() trans.rollback() @testing.requires.savepoints def test_rollback_on_invalid_savepoint(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin() trans2 = conn.begin_nested() conn.invalidate() trans2.rollback() def test_invalidate_twice(self): - conn = engine.connect() + conn = self.engine.connect() conn.invalidate() conn.invalidate() @@ -503,12 +520,7 @@ class RealReconnectTest(fixtures.TestBase): eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + _assert_invalidated(conn.execute, select([1])) assert not conn.closed assert conn.invalidated eq_(conn.execute(select([1])).scalar(), 1) @@ -517,37 +529,27 @@ class RealReconnectTest(fixtures.TestBase): @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_close(self): - conn = engine.connect() + conn = self.engine.connect() eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed - engine.test_shutdown() + self.engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + _assert_invalidated(conn.execute, select([1])) conn.close() - conn = engine.connect() + conn = self.engine.connect() eq_(conn.execute(select([1])).scalar(), 1) @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_with_transaction(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin() eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed - engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + self.engine.test_shutdown() + _assert_invalidated(conn.execute, select([1])) assert not conn.closed assert conn.invalidated assert trans.is_active @@ -558,13 +560,11 @@ class RealReconnectTest(fixtures.TestBase): conn.execute, select([1]) ) assert trans.is_active - try: - trans.commit() - assert False - except tsa.exc.InvalidRequestError as e: - assert str(e) \ - == "Can't reconnect until invalid transaction is "\ - "rolled back" + assert_raises_message( + tsa.exc.InvalidRequestError, + "Can't reconnect until invalid transaction is rolled back", + trans.commit + ) assert trans.is_active trans.rollback() assert not trans.is_active @@ -602,23 +602,21 @@ class RecycleTest(fixtures.TestBase): eq_(conn.execute(select([1])).scalar(), 1) conn.close() -meta, table, engine = None, None, None class InvalidateDuringResultTest(fixtures.TestBase): def setup(self): - global meta, table, engine - engine = engines.reconnecting_engine() - meta = MetaData(engine) - table = Table('sometable', meta, + self.engine = engines.reconnecting_engine() + self.meta = MetaData(self.engine) + table = Table('sometable', self.meta, Column('id', Integer, primary_key=True), Column('name', String(50))) - meta.create_all() + self.meta.create_all() table.insert().execute( - [{'id':i, 'name':'row %d' % i} for i in range(1, 100)] + [{'id': i, 'name': 'row %d' % i} for i in range(1, 100)] ) def teardown(self): - meta.drop_all() - engine.dispose() + self.meta.drop_all() + self.engine.dispose() @testing.fails_if([ '+mysqlconnector', '+mysqldb', @@ -628,16 +626,11 @@ class InvalidateDuringResultTest(fixtures.TestBase): @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_invalidate_on_results(self): - conn = engine.connect() + conn = self.engine.connect() result = conn.execute('select * from sometable') for x in range(20): result.fetchone() - engine.test_shutdown() - try: - print('ghost result: %r' % result.fetchone()) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + self.engine.test_shutdown() + _assert_invalidated(result.fetchone) assert conn.invalidated |