summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2013-06-30 18:35:12 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2013-06-30 18:35:12 -0400
commitb38a76cd1d47cd6b8f1abef30ad7c3aeaa27d537 (patch)
tree7af1dba9e242c77a248cb2194434aa9bf3ca49b7
parent715d6cf3d10a71acd7726b7e00c3ff40b4559bc7 (diff)
downloadsqlalchemy-b38a76cd1d47cd6b8f1abef30ad7c3aeaa27d537.tar.gz
- replace most explicitly-named test objects called "Mock..." with
actual mock objects from the mock library. I'd like to use mock for new tests so we might as well use it in obvious places. - use unittest.mock in py3.3 - changelog - add a note to README.unittests - add tests_require in setup.py - have tests import from sqlalchemy.testing.mock - apply usage of mock to one of the event tests. we can be using this approach all over the place.
-rw-r--r--README.unittests.rst8
-rw-r--r--doc/build/changelog/changelog_08.rst9
-rw-r--r--doc/build/changelog/changelog_09.rst10
-rw-r--r--lib/sqlalchemy/testing/__init__.py2
-rw-r--r--lib/sqlalchemy/testing/mock.py15
-rw-r--r--lib/sqlalchemy/util/__init__.py2
-rw-r--r--lib/sqlalchemy/util/compat.py1
-rw-r--r--setup.py3
-rw-r--r--test/aaa_profiling/test_resultset.py5
-rw-r--r--test/base/test_events.py136
-rw-r--r--test/dialect/postgresql/test_dialect.py21
-rw-r--r--test/dialect/test_mxodbc.py77
-rw-r--r--test/engine/test_ddlemit.py30
-rw-r--r--test/engine/test_execute.py24
-rw-r--r--test/engine/test_parseconnect.py64
-rw-r--r--test/engine/test_pool.py149
-rw-r--r--test/engine/test_reconnect.py443
17 files changed, 502 insertions, 497 deletions
diff --git a/README.unittests.rst b/README.unittests.rst
index ae7189854..7d052cfd7 100644
--- a/README.unittests.rst
+++ b/README.unittests.rst
@@ -7,12 +7,18 @@ module. If running on Python 2.4, pysqlite must be installed.
Unit tests are run using nose. Nose is available at::
- http://pypi.python.org/pypi/nose/
+ https://pypi.python.org/pypi/nose/
SQLAlchemy implements a nose plugin that must be present when tests are run.
This plugin is invoked when the test runner script provided with
SQLAlchemy is used.
+The test suite as of version 0.8.2 also requires the mock library. While
+mock is part of the Python standard library as of 3.3, previous versions
+will need to have it installed, and is available at::
+
+ https://pypi.python.org/pypi/mock
+
**NOTE:** - the nose plugin is no longer installed by setuptools as of
version 0.7 ! Use "python setup.py test" or "./sqla_nose.py".
diff --git a/doc/build/changelog/changelog_08.rst b/doc/build/changelog/changelog_08.rst
index c0e430ad6..d83d0618e 100644
--- a/doc/build/changelog/changelog_08.rst
+++ b/doc/build/changelog/changelog_08.rst
@@ -7,6 +7,15 @@
:version: 0.8.2
.. change::
+ :tags: requirements
+
+ The Python `mock <https://pypi.python.org/pypi/mock>`_ library
+ is now required in order to run the unit test suite. While part
+ of the standard library as of Python 3.3, previous Python installations
+ will need to install this in order to run unit tests or to
+ use the ``sqlalchemy.testing`` package for external dialects.
+
+ .. change::
:tags: bug, orm
:tickets: 2750
diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst
index af3081687..466444278 100644
--- a/doc/build/changelog/changelog_09.rst
+++ b/doc/build/changelog/changelog_09.rst
@@ -7,6 +7,16 @@
:version: 0.9.0
.. change::
+ :tags: requirements
+
+ The Python `mock <https://pypi.python.org/pypi/mock>`_ library
+ is now required in order to run the unit test suite. While part
+ of the standard library as of Python 3.3, previous Python installations
+ will need to install this in order to run unit tests or to
+ use the ``sqlalchemy.testing`` package for external dialects.
+ This applies to 0.8.2 as well.
+
+ .. change::
:tags: bug, orm
:tickets: 2750
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index d5522213d..a87829499 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -18,3 +18,5 @@ from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict
crashes = skip
from .config import db, requirements as requires
+
+from . import mock \ No newline at end of file
diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py
new file mode 100644
index 000000000..650962384
--- /dev/null
+++ b/lib/sqlalchemy/testing/mock.py
@@ -0,0 +1,15 @@
+"""Import stub for mock library.
+"""
+from __future__ import absolute_import
+from ..util import py33
+
+if py33:
+ from unittest.mock import MagicMock, Mock, call
+else:
+ try:
+ from mock import MagicMock, Mock, call
+ except ImportError:
+ raise ImportError(
+ "SQLAlchemy's test suite requires the "
+ "'mock' library as of 0.8.2.")
+
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 687abb39a..739caefe0 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -5,7 +5,7 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .compat import callable, cmp, reduce, \
- threading, py3k, py2k, jython, pypy, cpython, win32, \
+ threading, py3k, py33, py2k, jython, pypy, cpython, win32, \
pickle, dottedgetter, parse_qsl, namedtuple, next, WeakSet, reraise, \
raise_from_cause, text_type, string_types, int_types, binary_type, \
quote_plus, with_metaclass, print_, itertools_filterfalse, u, ue, b,\
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
index fea22873c..d866534ab 100644
--- a/lib/sqlalchemy/util/compat.py
+++ b/lib/sqlalchemy/util/compat.py
@@ -13,6 +13,7 @@ try:
except ImportError:
import dummy_threading as threading
+py33 = sys.version_info >= (3, 3)
py32 = sys.version_info >= (3, 2)
py3k = sys.version_info >= (3, 0)
py2k = sys.version_info < (3, 0)
diff --git a/setup.py b/setup.py
index 97212b55e..5b506f529 100644
--- a/setup.py
+++ b/setup.py
@@ -119,8 +119,7 @@ def run_setup(with_cext):
package_dir={'': 'lib'},
license="MIT License",
cmdclass=cmdclass,
-
- tests_require=['nose >= 0.11'],
+ tests_require=['nose >= 0.11', 'mock'],
test_suite="sqla_nose",
long_description=readme,
classifiers=[
diff --git a/test/aaa_profiling/test_resultset.py b/test/aaa_profiling/test_resultset.py
index 27e60410d..bbd8c4dba 100644
--- a/test/aaa_profiling/test_resultset.py
+++ b/test/aaa_profiling/test_resultset.py
@@ -3,6 +3,9 @@ from sqlalchemy.testing import fixtures, AssertsExecutionResults, profiling
from sqlalchemy import testing
from sqlalchemy.testing import eq_
from sqlalchemy.util import u
+from sqlalchemy.engine.result import RowProxy
+import sys
+
NUM_FIELDS = 10
NUM_RECORDS = 1000
@@ -80,7 +83,6 @@ class RowProxyTest(fixtures.TestBase):
__requires__ = 'cpython',
def _rowproxy_fixture(self, keys, processors, row):
- from sqlalchemy.engine.result import RowProxy
class MockMeta(object):
def __init__(self):
pass
@@ -96,7 +98,6 @@ class RowProxyTest(fixtures.TestBase):
return RowProxy(metadata, row, processors, keymap)
def _test_getitem_value_refcounts(self, seq_factory):
- import sys
col1, col2 = object(), object()
def proc1(value):
return value
diff --git a/test/base/test_events.py b/test/base/test_events.py
index 7cfb5fa7d..20bfa62ff 100644
--- a/test/base/test_events.py
+++ b/test/base/test_events.py
@@ -5,6 +5,8 @@ from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, \
from sqlalchemy import event, exc
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.util import gc_collect
+from sqlalchemy.testing.mock import Mock, call
+
class EventsTest(fixtures.TestBase):
"""Test class- and instance-level event registration."""
@@ -190,7 +192,7 @@ class ClsLevelListenTest(fixtures.TestBase):
def test_lis_subcalss_lis(self):
@event.listens_for(self.TargetOne, "event_one")
def handler1(x, y):
- print('handler1')
+ pass
class SubTarget(self.TargetOne):
pass
@@ -207,7 +209,7 @@ class ClsLevelListenTest(fixtures.TestBase):
def test_lis_multisub_lis(self):
@event.listens_for(self.TargetOne, "event_one")
def handler1(x, y):
- print('handler1')
+ pass
class SubTarget(self.TargetOne):
pass
@@ -411,12 +413,8 @@ class ListenOverrideTest(fixtures.TestBase):
event._remove_dispatcher(self.Target.__dict__['dispatch'].events)
def test_listen_override(self):
- result = []
- def listen_one(x):
- result.append(x)
-
- def listen_two(x, y):
- result.append((x, y))
+ listen_one = Mock()
+ listen_two = Mock()
event.listen(self.Target, "event_one", listen_one, add=True)
event.listen(self.Target, "event_one", listen_two)
@@ -425,10 +423,13 @@ class ListenOverrideTest(fixtures.TestBase):
t1.dispatch.event_one(5, 7)
t1.dispatch.event_one(10, 5)
- eq_(result,
- [
- 12, (5, 7), 15, (10, 5)
- ]
+ eq_(
+ listen_one.mock_calls,
+ [call(12), call(15)]
+ )
+ eq_(
+ listen_two.mock_calls,
+ [call(5, 7), call(10, 5)]
)
class PropagateTest(fixtures.TestBase):
@@ -446,12 +447,8 @@ class PropagateTest(fixtures.TestBase):
def test_propagate(self):
- result = []
- def listen_one(target, arg):
- result.append((target, arg))
-
- def listen_two(target, arg):
- result.append((target, arg))
+ listen_one = Mock()
+ listen_two = Mock()
t1 = self.Target()
@@ -464,7 +461,15 @@ class PropagateTest(fixtures.TestBase):
t2.dispatch.event_one(t2, 1)
t2.dispatch.event_two(t2, 2)
- eq_(result, [(t2, 1)])
+
+ eq_(
+ listen_one.mock_calls,
+ [call(t2, 1)]
+ )
+ eq_(
+ listen_two.mock_calls,
+ []
+ )
class JoinTest(fixtures.TestBase):
def setUp(self):
@@ -497,12 +502,6 @@ class JoinTest(fixtures.TestBase):
if 'dispatch' in cls.__dict__:
event._remove_dispatcher(cls.__dict__['dispatch'].events)
- def _listener(self):
- canary = []
- def listen(target, arg):
- canary.append((target, arg))
- return listen, canary
-
def test_neither(self):
element = self.TargetFactory().create()
element.run_event(1)
@@ -510,22 +509,22 @@ class JoinTest(fixtures.TestBase):
element.run_event(3)
def test_parent_class_only(self):
- _listener, canary = self._listener()
+ l1 = Mock()
- event.listen(self.TargetFactory, "event_one", _listener)
+ event.listen(self.TargetFactory, "event_one", l1)
element = self.TargetFactory().create()
element.run_event(1)
element.run_event(2)
element.run_event(3)
eq_(
- canary,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
def test_parent_class_child_class(self):
- l1, c1 = self._listener()
- l2, c2 = self._listener()
+ l1 = Mock()
+ l2 = Mock()
event.listen(self.TargetFactory, "event_one", l1)
event.listen(self.TargetElement, "event_one", l2)
@@ -535,17 +534,17 @@ class JoinTest(fixtures.TestBase):
element.run_event(2)
element.run_event(3)
eq_(
- c1,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
eq_(
- c2,
- [(element, 1), (element, 2), (element, 3)]
+ l2.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
def test_parent_class_child_instance_apply_after(self):
- l1, c1 = self._listener()
- l2, c2 = self._listener()
+ l1 = Mock()
+ l2 = Mock()
event.listen(self.TargetFactory, "event_one", l1)
element = self.TargetFactory().create()
@@ -557,17 +556,17 @@ class JoinTest(fixtures.TestBase):
element.run_event(3)
eq_(
- c1,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
eq_(
- c2,
- [(element, 2), (element, 3)]
+ l2.mock_calls,
+ [call(element, 2), call(element, 3)]
)
def test_parent_class_child_instance_apply_before(self):
- l1, c1 = self._listener()
- l2, c2 = self._listener()
+ l1 = Mock()
+ l2 = Mock()
event.listen(self.TargetFactory, "event_one", l1)
element = self.TargetFactory().create()
@@ -579,17 +578,17 @@ class JoinTest(fixtures.TestBase):
element.run_event(3)
eq_(
- c1,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
eq_(
- c2,
- [(element, 1), (element, 2), (element, 3)]
+ l2.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
def test_parent_instance_child_class_apply_before(self):
- l1, c1 = self._listener()
- l2, c2 = self._listener()
+ l1 = Mock()
+ l2 = Mock()
event.listen(self.TargetElement, "event_one", l2)
@@ -603,17 +602,18 @@ class JoinTest(fixtures.TestBase):
element.run_event(3)
eq_(
- c1,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
eq_(
- c2,
- [(element, 1), (element, 2), (element, 3)]
+ l2.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
+
def test_parent_instance_child_class_apply_after(self):
- l1, c1 = self._listener()
- l2, c2 = self._listener()
+ l1 = Mock()
+ l2 = Mock()
event.listen(self.TargetElement, "event_one", l2)
@@ -632,18 +632,16 @@ class JoinTest(fixtures.TestBase):
# this can be changed to be "live" at the cost
# of performance.
eq_(
- c1,
- []
- #(element, 2), (element, 3)]
+ l1.mock_calls, []
)
eq_(
- c2,
- [(element, 1), (element, 2), (element, 3)]
+ l2.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
def test_parent_instance_child_instance_apply_before(self):
- l1, c1 = self._listener()
- l2, c2 = self._listener()
+ l1 = Mock()
+ l2 = Mock()
factory = self.TargetFactory()
event.listen(factory, "event_one", l1)
@@ -656,16 +654,16 @@ class JoinTest(fixtures.TestBase):
element.run_event(3)
eq_(
- c1,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
eq_(
- c2,
- [(element, 1), (element, 2), (element, 3)]
+ l2.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
def test_parent_events_child_no_events(self):
- l1, c1 = self._listener()
+ l1 = Mock()
factory = self.TargetFactory()
event.listen(self.TargetElement, "event_one", l1)
@@ -676,6 +674,6 @@ class JoinTest(fixtures.TestBase):
element.run_event(3)
eq_(
- c1,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py
index 86ce91dc9..1fc239cb7 100644
--- a/test/dialect/postgresql/test_dialect.py
+++ b/test/dialect/postgresql/test_dialect.py
@@ -16,6 +16,7 @@ from sqlalchemy import exc, schema
from sqlalchemy.dialects.postgresql import base as postgresql
import logging
import logging.handlers
+from sqlalchemy.testing.mock import Mock
class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
@@ -37,18 +38,12 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
'The JDBC driver handles the version parsing')
def test_version_parsing(self):
-
- class MockConn(object):
-
- def __init__(self, res):
- self.res = res
-
- def execute(self, str):
- return self
-
- def scalar(self):
- return self.res
-
+ def mock_conn(res):
+ return Mock(
+ execute=Mock(
+ return_value=Mock(scalar=Mock(return_value=res))
+ )
+ )
for string, version in \
[('PostgreSQL 8.3.8 on i686-redhat-linux-gnu, compiled by '
@@ -59,7 +54,7 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
('EnterpriseDB 9.1.2.2 on x86_64-unknown-linux-gnu, '
'compiled by gcc (GCC) 4.1.2 20080704 (Red Hat 4.1.2-50), '
'64-bit', (9, 1, 2))]:
- eq_(testing.db.dialect._get_server_version_info(MockConn(string)),
+ eq_(testing.db.dialect._get_server_version_info(mock_conn(string)),
version)
@testing.only_on('postgresql+psycopg2', 'psycopg2-specific feature')
diff --git a/test/dialect/test_mxodbc.py b/test/dialect/test_mxodbc.py
index 32cad4168..e46de9149 100644
--- a/test/dialect/test_mxodbc.py
+++ b/test/dialect/test_mxodbc.py
@@ -2,75 +2,48 @@ from sqlalchemy import *
from sqlalchemy.testing import eq_
from sqlalchemy.testing import engines
from sqlalchemy.testing import fixtures
-
-# TODO: we should probably build mock bases for
-# these to share with test_reconnect, test_parseconnect
-class MockDBAPI(object):
- paramstyle = 'qmark'
- def __init__(self):
- self.log = []
- def connect(self, *args, **kwargs):
- return MockConnection(self)
-
-class MockConnection(object):
- def __init__(self, parent):
- self.parent = parent
- def cursor(self):
- return MockCursor(self)
- def close(self):
- pass
- def rollback(self):
- pass
- def commit(self):
- pass
-
-class MockCursor(object):
- description = None
- rowcount = None
- def __init__(self, parent):
- self.parent = parent
- def execute(self, *args, **kwargs):
- if kwargs.get('direct', False):
- self.executedirect()
- else:
- self.parent.parent.log.append('execute')
- def executedirect(self, *args, **kwargs):
- self.parent.parent.log.append('executedirect')
- def close(self):
- pass
+from sqlalchemy.testing.mock import Mock
+
+def mock_dbapi():
+ return Mock(paramstyle='qmark',
+ connect=Mock(
+ return_value=Mock(
+ cursor=Mock(
+ return_value=Mock(
+ description=None,
+ rowcount=None)
+ )
+ )
+ )
+ )
class MxODBCTest(fixtures.TestBase):
def test_native_odbc_execute(self):
t1 = Table('t1', MetaData(), Column('c1', Integer))
- dbapi = MockDBAPI()
+ dbapi = mock_dbapi()
+
engine = engines.testing_engine('mssql+mxodbc://localhost',
options={'module': dbapi, '_initialize': False})
conn = engine.connect()
# crud: uses execute
-
conn.execute(t1.insert().values(c1='foo'))
conn.execute(t1.delete().where(t1.c.c1 == 'foo'))
- conn.execute(t1.update().where(t1.c.c1 == 'foo').values(c1='bar'
- ))
+ conn.execute(t1.update().where(t1.c.c1 == 'foo').values(c1='bar'))
# select: uses executedirect
-
conn.execute(t1.select())
# manual flagging
-
conn.execution_options(native_odbc_execute=True).\
execute(t1.select())
conn.execution_options(native_odbc_execute=False).\
- execute(t1.insert().values(c1='foo'
- ))
- eq_(dbapi.log, [
- 'executedirect',
- 'executedirect',
- 'executedirect',
- 'executedirect',
- 'execute',
- 'executedirect',
- ])
+ execute(t1.insert().values(c1='foo'))
+
+ eq_(
+ [c[2] for c in
+ dbapi.connect.return_value.cursor.return_value.execute.mock_calls],
+ [{'direct': True}, {'direct': True}, {'direct': True},
+ {'direct': True}, {'direct': False}, {'direct': True}]
+ )
diff --git a/test/engine/test_ddlemit.py b/test/engine/test_ddlemit.py
index deaf09cf7..e773d0ced 100644
--- a/test/engine/test_ddlemit.py
+++ b/test/engine/test_ddlemit.py
@@ -3,28 +3,19 @@ from sqlalchemy.engine.ddl import SchemaGenerator, SchemaDropper
from sqlalchemy.engine import default
from sqlalchemy import MetaData, Table, Column, Integer, Sequence
from sqlalchemy import schema
+from sqlalchemy.testing.mock import Mock
class EmitDDLTest(fixtures.TestBase):
def _mock_connection(self, item_exists):
- _canary = []
+ def has_item(connection, name, schema):
+ return item_exists(name)
- class MockDialect(default.DefaultDialect):
- supports_sequences = True
-
- def has_table(self, connection, name, schema):
- return item_exists(name)
-
- def has_sequence(self, connection, name, schema):
- return item_exists(name)
-
- class MockConnection(object):
- dialect = MockDialect()
- canary = _canary
-
- def execute(self, item):
- _canary.append(item)
-
- return MockConnection()
+ return Mock(dialect=Mock(
+ supports_sequences=True,
+ has_table=Mock(side_effect=has_item),
+ has_sequence=Mock(side_effect=has_item)
+ )
+ )
def _mock_create_fixture(self, checkfirst, tables,
item_exists=lambda item: False):
@@ -176,7 +167,8 @@ class EmitDDLTest(fixtures.TestBase):
def _assert_ddl(self, ddl_cls, elements, generator, argument):
generator.traverse_single(argument)
- for c in generator.connection.canary:
+ for call_ in generator.connection.execute.mock_calls:
+ c = call_[1][0]
assert isinstance(c, ddl_cls)
assert c.element in elements, "element %r was not expected"\
% c.element
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index 1c577730b..9795e4c10 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -19,6 +19,8 @@ from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam
from sqlalchemy.engine import result as _result, default
from sqlalchemy.engine.base import Connection, Engine
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.mock import Mock, call
+
users, metadata, users_autoinc = None, None, None
class ExecuteTest(fixtures.TestBase):
@@ -455,20 +457,22 @@ class ConvenienceExecuteTest(fixtures.TablesTest):
def test_transaction_engine_ctx_begin_fails(self):
engine = engines.testing_engine()
- class MockConnection(Connection):
- closed = False
- def begin(self):
- raise Exception("boom")
-
- def close(self):
- MockConnection.closed = True
- engine._connection_cls = MockConnection
- fn = self._trans_fn()
+
+ mock_connection = Mock(
+ return_value=Mock(
+ begin=Mock(side_effect=Exception("boom"))
+ )
+ )
+ engine._connection_cls = mock_connection
assert_raises(
Exception,
engine.begin
)
- assert MockConnection.closed
+
+ eq_(
+ mock_connection.return_value.close.mock_calls,
+ [call()]
+ )
def test_transaction_engine_ctx_rollback(self):
fn = self._trans_rollback_fn()
diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py
index 73bdc76c4..106bd0782 100644
--- a/test/engine/test_parseconnect.py
+++ b/test/engine/test_parseconnect.py
@@ -7,6 +7,8 @@ from sqlalchemy.engine.default import DefaultDialect
import sqlalchemy as tsa
from sqlalchemy.testing import fixtures
from sqlalchemy import testing
+from sqlalchemy.testing.mock import Mock
+
class ParseConnectTest(fixtures.TestBase):
def test_rfc1738(self):
@@ -250,20 +252,17 @@ pool_timeout=10
every backend.
"""
- # pretend pysqlite throws the
- # "Cannot operate on a closed database." error
- # on connect. IRL we'd be getting Oracle's "shutdown in progress"
e = create_engine('sqlite://')
sqlite3 = e.dialect.dbapi
- class ThrowOnConnect(MockDBAPI):
- dbapi = sqlite3
- Error = sqlite3.Error
- ProgrammingError = sqlite3.ProgrammingError
- def connect(self, *args, **kw):
- raise sqlite3.ProgrammingError("Cannot operate on a closed database.")
+
+ dbapi = MockDBAPI()
+ dbapi.Error = sqlite3.Error,
+ dbapi.ProgrammingError = sqlite3.ProgrammingError
+ dbapi.connect = Mock(side_effect=sqlite3.ProgrammingError(
+ "Cannot operate on a closed database."))
try:
- create_engine('sqlite://', module=ThrowOnConnect()).connect()
+ create_engine('sqlite://', module=dbapi).connect()
assert False
except tsa.exc.DBAPIError as de:
assert de.connection_invalidated
@@ -354,36 +353,23 @@ class MockDialect(DefaultDialect):
def dbapi(cls, **kw):
return MockDBAPI()
-class MockDBAPI(object):
- version_info = sqlite_version_info = 99, 9, 9
- sqlite_version = '99.9.9'
-
- def __init__(self, **kwargs):
- self.kwargs = kwargs
- self.paramstyle = 'named'
-
- def connect(self, *args, **kwargs):
- for k in self.kwargs:
+def MockDBAPI(**assert_kwargs):
+ connection = Mock(get_server_version_info=Mock(return_value='5.0'))
+ def connect(*args, **kwargs):
+ for k in assert_kwargs:
assert k in kwargs, 'key %s not present in dictionary' % k
- assert kwargs[k] == self.kwargs[k], \
- 'value %s does not match %s' % (kwargs[k],
- self.kwargs[k])
- return MockConnection()
-
-
-class MockConnection(object):
- def get_server_info(self):
- return '5.0'
-
- def close(self):
- pass
-
- def cursor(self):
- return MockCursor()
-
-class MockCursor(object):
- def close(self):
- pass
+ eq_(
+ kwargs[k], assert_kwargs[k]
+ )
+ return connection
+
+ return Mock(
+ sqlite_version_info=(99, 9, 9,),
+ version_info=(99, 9, 9,),
+ sqlite_version='99.9.9',
+ paramstyle='named',
+ connect=Mock(side_effect=connect)
+ )
mock_dbapi = MockDBAPI()
mock_sqlite_dbapi = msd = MockDBAPI()
diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py
index 583978465..981df6dd0 100644
--- a/test/engine/test_pool.py
+++ b/test/engine/test_pool.py
@@ -4,37 +4,21 @@ from sqlalchemy import pool, select, event
import sqlalchemy as tsa
from sqlalchemy import testing
from sqlalchemy.testing.util import gc_collect, lazy_gc
-from sqlalchemy.testing import eq_, assert_raises
+from sqlalchemy.testing import eq_, assert_raises, is_not_
from sqlalchemy.testing.engines import testing_engine
from sqlalchemy.testing import fixtures
-mcid = 1
-class MockDBAPI(object):
- throw_error = False
- def connect(self, *args, **kwargs):
- if self.throw_error:
- raise Exception("couldnt connect !")
- delay = kwargs.pop('delay', 0)
- if delay:
- time.sleep(delay)
- return MockConnection()
-class MockConnection(object):
- closed = False
- def __init__(self):
- global mcid
- self.id = mcid
- mcid += 1
- def close(self):
- self.closed = True
- def rollback(self):
- pass
- def cursor(self):
- return MockCursor()
-class MockCursor(object):
- def execute(self, *args, **kw):
- pass
- def close(self):
- pass
+from sqlalchemy.testing.mock import Mock, call
+
+def MockDBAPI():
+ def cursor():
+ while True:
+ yield Mock()
+ def connect():
+ while True:
+ yield Mock(cursor=Mock(side_effect=cursor()))
+
+ return Mock(connect=Mock(side_effect=connect()))
class PoolTestBase(fixtures.TestBase):
def setup(self):
@@ -71,11 +55,9 @@ class PoolTest(PoolTestBase):
assert c4 is not c5
def test_manager_with_key(self):
- class NoKws(object):
- def connect(self, arg):
- return MockConnection()
- manager = pool.manage(NoKws(), use_threadlocal=True)
+ dbapi = MockDBAPI()
+ manager = pool.manage(dbapi, use_threadlocal=True)
c1 = manager.connect('foo.db', sa_pool_key="a")
c2 = manager.connect('foo.db', sa_pool_key="b")
@@ -83,9 +65,14 @@ class PoolTest(PoolTestBase):
assert c1.cursor() is not None
assert c1 is not c2
- assert c1 is c3
-
+ assert c1 is c3
+ eq_(dbapi.connect.mock_calls,
+ [
+ call("foo.db"),
+ call("foo.db"),
+ ]
+ )
def test_bad_args(self):
@@ -127,7 +114,7 @@ class PoolTest(PoolTestBase):
p = cls(creator=mock_dbapi.connect)
conn = p.connect()
conn.close()
- mock_dbapi.throw_error = True
+ mock_dbapi.connect.side_effect = Exception("error!")
p.dispose()
p.recreate()
@@ -211,9 +198,9 @@ class PoolTest(PoolTestBase):
self.assert_('foo2' in c.info)
c2 = p.connect()
- self.assert_(c.connection is not c2.connection)
- self.assert_(not c2.info)
- self.assert_('foo2' in c.info)
+ is_not_(c.connection, c2.connection)
+ assert not c2.info
+ assert 'foo2' in c.info
class PoolDialectTest(PoolTestBase):
@@ -945,19 +932,24 @@ class QueuePoolTest(PoolTestBase):
def test_dispose_closes_pooled(self):
dbapi = MockDBAPI()
- def creator():
- return dbapi.connect()
- p = pool.QueuePool(creator=creator,
+ p = pool.QueuePool(creator=dbapi.connect,
pool_size=2, timeout=None,
max_overflow=0)
c1 = p.connect()
c2 = p.connect()
- conns = [c1.connection, c2.connection]
+ c1_con = c1.connection
+ c2_con = c2.connection
+
c1.close()
- eq_([c.closed for c in conns], [False, False])
+
+ eq_(c1_con.close.call_count, 0)
+ eq_(c2_con.close.call_count, 0)
+
p.dispose()
- eq_([c.closed for c in conns], [True, False])
+
+ eq_(c1_con.close.call_count, 1)
+ eq_(c2_con.close.call_count, 0)
# currently, if a ConnectionFairy is closed
# after the pool has been disposed, there's no
@@ -965,11 +957,12 @@ class QueuePoolTest(PoolTestBase):
# immediately - it just gets returned to the
# pool normally...
c2.close()
- eq_([c.closed for c in conns], [True, False])
+ eq_(c1_con.close.call_count, 1)
+ eq_(c2_con.close.call_count, 0)
# ...and that's the one we'll get back next.
c3 = p.connect()
- assert c3.connection is conns[1]
+ assert c3.connection is c2_con
def test_no_overflow(self):
self._test_overflow(40, 0)
@@ -1010,12 +1003,21 @@ class QueuePoolTest(PoolTestBase):
return c
for j in range(5):
+ # open 4 conns at a time. each time this
+ # will yield two pooled connections + two
+ # overflow connections.
conns = [_conn() for i in range(4)]
for c in conns:
c.close()
- still_opened = len([c for c in strong_refs if not c.closed])
- eq_(still_opened, 2)
+ # doing that for a total of 5 times yields
+ # ten overflow connections closed plus the
+ # two pooled connections unclosed.
+
+ eq_(
+ set([c.close.call_count for c in strong_refs]),
+ set([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0])
+ )
@testing.requires.predictable_gc
def test_weakref_kaboom(self):
@@ -1108,18 +1110,30 @@ class QueuePoolTest(PoolTestBase):
dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
c1 = p.connect()
c1.detach()
- c_id = c1.connection.id
- c2 = p.connect()
- assert c2.connection.id != c1.connection.id
- dbapi.raise_error = True
- c2.invalidate()
- c2 = None
c2 = p.connect()
- assert c2.connection.id != c1.connection.id
- con = c1.connection
- assert not con.closed
+ eq_(dbapi.connect.mock_calls, [call("foo.db"), call("foo.db")])
+
+ c1_con = c1.connection
+ assert c1_con is not None
+ eq_(c1_con.close.call_count, 0)
c1.close()
- assert con.closed
+ eq_(c1_con.close.call_count, 1)
+
+ def test_detach_via_invalidate(self):
+ dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
+
+ c1 = p.connect()
+ c1_con = c1.connection
+ c1.invalidate()
+ assert c1.connection is None
+ eq_(c1_con.close.call_count, 1)
+
+ c2 = p.connect()
+ assert c2.connection is not c1_con
+ c2_con = c2.connection
+
+ c2.close()
+ eq_(c2_con.close.call_count, 0)
def test_threadfairy(self):
p = self._queuepool_fixture(pool_size=3, max_overflow=-1, use_threadlocal=True)
@@ -1141,8 +1155,13 @@ class SingletonThreadPoolTest(PoolTestBase):
been called."""
dbapi = MockDBAPI()
- p = pool.SingletonThreadPool(creator=dbapi.connect,
- pool_size=3)
+
+ lock = threading.Lock()
+ def creator():
+ # the mock iterator isn't threadsafe...
+ with lock:
+ return dbapi.connect()
+ p = pool.SingletonThreadPool(creator=creator, pool_size=3)
if strong_refs:
sr = set()
@@ -1172,7 +1191,7 @@ class SingletonThreadPoolTest(PoolTestBase):
assert len(p._all_conns) == 3
if strong_refs:
- still_opened = len([c for c in sr if not c.closed])
+ still_opened = len([c for c in sr if not c.close.call_count])
eq_(still_opened, 3)
class AssertionPoolTest(PoolTestBase):
@@ -1198,17 +1217,19 @@ class NullPoolTest(PoolTestBase):
dbapi = MockDBAPI()
p = pool.NullPool(creator=lambda: dbapi.connect('foo.db'))
c1 = p.connect()
- c_id = c1.connection.id
+
c1.close()
c1 = None
c1 = p.connect()
- dbapi.raise_error = True
c1.invalidate()
c1 = None
c1 = p.connect()
- assert c1.connection.id != c_id
+ dbapi.connect.assert_has_calls([
+ call('foo.db'),
+ call('foo.db')],
+ any_order=True)
class StaticPoolTest(PoolTestBase):
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py
index ee3ff1459..86003bec6 100644
--- a/test/engine/test_reconnect.py
+++ b/test/engine/test_reconnect.py
@@ -1,7 +1,6 @@
from sqlalchemy.testing import eq_, assert_raises, assert_raises_message
import time
-import weakref
-from sqlalchemy import select, MetaData, Integer, String, pool, create_engine
+from sqlalchemy import select, MetaData, Integer, String, create_engine, pool
from sqlalchemy.testing.schema import Table, Column
import sqlalchemy as tsa
from sqlalchemy import testing
@@ -10,6 +9,8 @@ from sqlalchemy.testing.util import gc_collect
from sqlalchemy import exc, util
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.engines import testing_engine
+from sqlalchemy.testing import is_not_
+from sqlalchemy.testing.mock import Mock, call
class MockError(Exception):
pass
@@ -17,93 +18,103 @@ class MockError(Exception):
class MockDisconnect(MockError):
pass
-class MockDBAPI(object):
- def __init__(self):
- self.paramstyle = 'named'
- self.connections = weakref.WeakKeyDictionary()
- def connect(self, *args, **kwargs):
- return MockConnection(self)
- def shutdown(self, explode='execute'):
- for c in self.connections:
- c.explode = explode
- Error = MockError
-
-class MockConnection(object):
- def __init__(self, dbapi):
- dbapi.connections[self] = True
- self.explode = ""
- def rollback(self):
- if self.explode == 'rollback':
+def mock_connection():
+ def mock_cursor():
+ def execute(*args, **kwargs):
+ if conn.explode == 'execute':
+ raise MockDisconnect("Lost the DB connection on execute")
+ elif conn.explode in ('execute_no_disconnect', ):
+ raise MockError(
+ "something broke on execute but we didn't lose the connection")
+ elif conn.explode in ('rollback', 'rollback_no_disconnect'):
+ raise MockError(
+ "something broke on execute but we didn't lose the connection")
+ elif args and "SELECT" in args[0]:
+ cursor.description = [('foo', None, None, None, None, None)]
+ else:
+ return
+
+ def close():
+ cursor.fetchall = cursor.fetchone = \
+ Mock(side_effect=MockError("cursor closed"))
+ cursor = Mock(
+ execute=Mock(side_effect=execute),
+ close=Mock(side_effect=close)
+ )
+ return cursor
+
+ def cursor():
+ while True:
+ yield mock_cursor()
+
+ def rollback():
+ if conn.explode == 'rollback':
raise MockDisconnect("Lost the DB connection on rollback")
- if self.explode == 'rollback_no_disconnect':
+ if conn.explode == 'rollback_no_disconnect':
raise MockError(
"something broke on rollback but we didn't lose the connection")
else:
return
- def commit(self):
- pass
- def cursor(self):
- return MockCursor(self)
- def close(self):
- pass
-
-class MockCursor(object):
- def __init__(self, parent):
- self.explode = parent.explode
- self.description = ()
- self.closed = False
- def execute(self, *args, **kwargs):
- if self.explode == 'execute':
- raise MockDisconnect("Lost the DB connection on execute")
- elif self.explode in ('execute_no_disconnect', ):
- raise MockError(
- "something broke on execute but we didn't lose the connection")
- elif self.explode in ('rollback', 'rollback_no_disconnect'):
- raise MockError(
- "something broke on execute but we didn't lose the connection")
- elif args and "select" in args[0]:
- self.description = [('foo', None, None, None, None, None)]
- else:
- return
- def fetchall(self):
- if self.closed:
- raise MockError("cursor closed")
- return []
- def fetchone(self):
- if self.closed:
- raise MockError("cursor closed")
- return None
- def close(self):
- self.closed = True
-
-db, dbapi = None, None
+
+ conn = Mock(
+ rollback=Mock(side_effect=rollback),
+ cursor=Mock(side_effect=cursor())
+ )
+ return conn
+
+def MockDBAPI():
+ connections = []
+ def connect():
+ while True:
+ conn = mock_connection()
+ connections.append(conn)
+ yield conn
+
+ def shutdown(explode='execute'):
+ for c in connections:
+ c.explode = explode
+
+ def dispose():
+ for c in connections:
+ c.explode = None
+ connections[:] = []
+
+ return Mock(
+ connect=Mock(side_effect=connect()),
+ shutdown=Mock(side_effect=shutdown),
+ dispose=Mock(side_effect=dispose),
+ paramstyle='named',
+ connections=connections,
+ Error=MockError
+ )
+
+
class MockReconnectTest(fixtures.TestBase):
def setup(self):
- global db, dbapi
- dbapi = MockDBAPI()
+ self.dbapi = MockDBAPI()
- # note - using straight create_engine here
- # since we are testing gc
- db = create_engine(
+ self.db = testing_engine(
'postgresql://foo:bar@localhost/test',
- module=dbapi, _initialize=False)
+ options=dict(module=self.dbapi, _initialize=False))
+ self.mock_connect = call(host='localhost', password='bar',
+ user='foo', database='test')
# monkeypatch disconnect checker
- db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect)
+ self.db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect)
def teardown(self):
- db.dispose()
+ self.dbapi.dispose()
def test_reconnect(self):
"""test that an 'is_disconnect' condition will invalidate the
connection, and additionally dispose the previous connection
pool and recreate."""
- pid = id(db.pool)
+ db_pool = self.db.pool
# make a connection
- conn = db.connect()
+ conn = self.db.connect()
# connection works
@@ -112,21 +123,20 @@ class MockReconnectTest(fixtures.TestBase):
# create a second connection within the pool, which we'll ensure
# also goes away
- conn2 = db.connect()
+ conn2 = self.db.connect()
conn2.close()
# two connections opened total now
- assert len(dbapi.connections) == 2
+ assert len(self.dbapi.connections) == 2
# set it to fail
- dbapi.shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError:
- pass
+ self.dbapi.shutdown()
+ assert_raises(
+ tsa.exc.DBAPIError,
+ conn.execute, select([1])
+ )
# assert was invalidated
@@ -136,31 +146,38 @@ class MockReconnectTest(fixtures.TestBase):
# close shouldnt break
conn.close()
- assert id(db.pool) != pid
+ is_not_(self.db.pool, db_pool)
# ensure all connections closed (pool was recycled)
- gc_collect()
- assert len(dbapi.connections) == 0
- conn = db.connect()
+ eq_(
+ [c.close.mock_calls for c in self.dbapi.connections],
+ [[call()], [call()]]
+ )
+
+ conn = self.db.connect()
conn.execute(select([1]))
conn.close()
- assert len(dbapi.connections) == 1
+
+ eq_(
+ [c.close.mock_calls for c in self.dbapi.connections],
+ [[call()], [call()], []]
+ )
def test_invalidate_trans(self):
- conn = db.connect()
+ conn = self.db.connect()
trans = conn.begin()
- dbapi.shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError:
- pass
+ self.dbapi.shutdown()
- # assert was invalidated
+ assert_raises(
+ tsa.exc.DBAPIError,
+ conn.execute, select([1])
+ )
- gc_collect()
- assert len(dbapi.connections) == 0
+ eq_(
+ [c.close.mock_calls for c in self.dbapi.connections],
+ [[call()]]
+ )
assert not conn.closed
assert conn.invalidated
assert trans.is_active
@@ -170,28 +187,35 @@ class MockReconnectTest(fixtures.TestBase):
conn.execute, select([1])
)
assert trans.is_active
- try:
- trans.commit()
- assert False
- except tsa.exc.InvalidRequestError as e:
- assert str(e) \
- == "Can't reconnect until invalid transaction is "\
- "rolled back"
+
+ assert_raises_message(
+ tsa.exc.InvalidRequestError,
+ "Can't reconnect until invalid transaction is "
+ "rolled back",
+ trans.commit
+ )
+
assert trans.is_active
trans.rollback()
assert not trans.is_active
conn.execute(select([1]))
assert not conn.invalidated
- assert len(dbapi.connections) == 1
+ eq_(
+ [c.close.mock_calls for c in self.dbapi.connections],
+ [[call()], []]
+ )
def test_conn_reusable(self):
- conn = db.connect()
+ conn = self.db.connect()
conn.execute(select([1]))
- assert len(dbapi.connections) == 1
+ eq_(
+ self.dbapi.connect.mock_calls,
+ [self.mock_connect]
+ )
- dbapi.shutdown()
+ self.dbapi.shutdown()
assert_raises(
tsa.exc.DBAPIError,
@@ -201,19 +225,24 @@ class MockReconnectTest(fixtures.TestBase):
assert not conn.closed
assert conn.invalidated
- # ensure all connections closed (pool was recycled)
- gc_collect()
- assert len(dbapi.connections) == 0
+ eq_(
+ [c.close.mock_calls for c in self.dbapi.connections],
+ [[call()]]
+ )
# test reconnects
conn.execute(select([1]))
assert not conn.invalidated
- assert len(dbapi.connections) == 1
+
+ eq_(
+ [c.close.mock_calls for c in self.dbapi.connections],
+ [[call()], []]
+ )
def test_invalidated_close(self):
- conn = db.connect()
+ conn = self.db.connect()
- dbapi.shutdown()
+ self.dbapi.shutdown()
assert_raises(
tsa.exc.DBAPIError,
@@ -230,9 +259,9 @@ class MockReconnectTest(fixtures.TestBase):
)
def test_noreconnect_execute_plus_closewresult(self):
- conn = db.connect(close_with_result=True)
+ conn = self.db.connect(close_with_result=True)
- dbapi.shutdown("execute_no_disconnect")
+ self.dbapi.shutdown("execute_no_disconnect")
# raises error
assert_raises_message(
@@ -245,9 +274,9 @@ class MockReconnectTest(fixtures.TestBase):
assert not conn.invalidated
def test_noreconnect_rollback_plus_closewresult(self):
- conn = db.connect(close_with_result=True)
+ conn = self.db.connect(close_with_result=True)
- dbapi.shutdown("rollback_no_disconnect")
+ self.dbapi.shutdown("rollback_no_disconnect")
# raises error
assert_raises_message(
@@ -266,13 +295,13 @@ class MockReconnectTest(fixtures.TestBase):
)
def test_reconnect_on_reentrant(self):
- conn = db.connect()
+ conn = self.db.connect()
conn.execute(select([1]))
- assert len(dbapi.connections) == 1
+ assert len(self.dbapi.connections) == 1
- dbapi.shutdown("rollback")
+ self.dbapi.shutdown("rollback")
# raises error
assert_raises_message(
@@ -285,9 +314,9 @@ class MockReconnectTest(fixtures.TestBase):
assert conn.invalidated
def test_reconnect_on_reentrant_plus_closewresult(self):
- conn = db.connect(close_with_result=True)
+ conn = self.db.connect(close_with_result=True)
- dbapi.shutdown("rollback")
+ self.dbapi.shutdown("rollback")
# raises error
assert_raises_message(
@@ -306,10 +335,11 @@ class MockReconnectTest(fixtures.TestBase):
)
def test_check_disconnect_no_cursor(self):
- conn = db.connect()
- result = conn.execute("select 1")
+ conn = self.db.connect()
+ result = conn.execute(select([1]))
result.cursor.close()
conn.close()
+
assert_raises_message(
tsa.exc.DBAPIError,
"cursor closed",
@@ -319,60 +349,59 @@ class MockReconnectTest(fixtures.TestBase):
class CursorErrTest(fixtures.TestBase):
def setup(self):
- global db, dbapi
-
- class MDBAPI(MockDBAPI):
- def connect(self, *args, **kwargs):
- return MConn(self)
-
- class MConn(MockConnection):
- def cursor(self):
- return MCursor(self)
+ def MockDBAPI():
+ def cursor():
+ while True:
+ yield Mock(
+ description=[],
+ close=Mock(side_effect=Exception("explode")))
+ def connect():
+ while True:
+ yield Mock(cursor=Mock(side_effect=cursor()))
+
+ return Mock(connect=Mock(side_effect=connect()))
- class MCursor(MockCursor):
- def close(self):
- raise Exception("explode")
-
- dbapi = MDBAPI()
-
- db = testing_engine(
+ dbapi = MockDBAPI()
+ self.db = testing_engine(
'postgresql://foo:bar@localhost/test',
options=dict(module=dbapi, _initialize=False))
def test_cursor_explode(self):
- conn = db.connect()
+ conn = self.db.connect()
result = conn.execute("select foo")
result.close()
conn.close()
def teardown(self):
- db.dispose()
+ self.db.dispose()
+
+
+def _assert_invalidated(fn, *args):
+ try:
+ fn(*args)
+ assert False
+ except tsa.exc.DBAPIError as e:
+ if not e.connection_invalidated:
+ raise
-engine = None
class RealReconnectTest(fixtures.TestBase):
def setup(self):
- global engine
- engine = engines.reconnecting_engine()
+ self.engine = engines.reconnecting_engine()
def teardown(self):
- engine.dispose()
+ self.engine.dispose()
@testing.fails_on('+informixdb',
"Wrong error thrown, fix in informixdb?")
def test_reconnect(self):
- conn = engine.connect()
+ conn = self.engine.connect()
eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.closed
- engine.test_shutdown()
+ self.engine.test_shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError as e:
- if not e.connection_invalidated:
- raise
+ _assert_invalidated(conn.execute, select([1]))
assert not conn.closed
assert conn.invalidated
@@ -382,13 +411,9 @@ class RealReconnectTest(fixtures.TestBase):
assert not conn.invalidated
# one more time
- engine.test_shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError as e:
- if not e.connection_invalidated:
- raise
+ self.engine.test_shutdown()
+ _assert_invalidated(conn.execute, select([1]))
+
assert conn.invalidated
eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.invalidated
@@ -396,30 +421,22 @@ class RealReconnectTest(fixtures.TestBase):
conn.close()
def test_multiple_invalidate(self):
- c1 = engine.connect()
- c2 = engine.connect()
+ c1 = self.engine.connect()
+ c2 = self.engine.connect()
eq_(c1.execute(select([1])).scalar(), 1)
- p1 = engine.pool
- engine.test_shutdown()
+ p1 = self.engine.pool
+ self.engine.test_shutdown()
- try:
- c1.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError as e:
- assert e.connection_invalidated
+ _assert_invalidated(c1.execute, select([1]))
- p2 = engine.pool
+ p2 = self.engine.pool
- try:
- c2.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError as e:
- assert e.connection_invalidated
+ _assert_invalidated(c2.execute, select([1]))
# pool isn't replaced
- assert engine.pool is p2
+ assert self.engine.pool is p2
def test_ensure_is_disconnect_gets_connection(self):
@@ -430,37 +447,37 @@ class RealReconnectTest(fixtures.TestBase):
# though MySQLdb we get a non-working cursor.
# assert cursor is None
- engine.dialect.is_disconnect = is_disconnect
- conn = engine.connect()
- engine.test_shutdown()
+ self.engine.dialect.is_disconnect = is_disconnect
+ conn = self.engine.connect()
+ self.engine.test_shutdown()
assert_raises(
tsa.exc.DBAPIError,
conn.execute, select([1])
)
def test_rollback_on_invalid_plain(self):
- conn = engine.connect()
+ conn = self.engine.connect()
trans = conn.begin()
conn.invalidate()
trans.rollback()
@testing.requires.two_phase_transactions
def test_rollback_on_invalid_twophase(self):
- conn = engine.connect()
+ conn = self.engine.connect()
trans = conn.begin_twophase()
conn.invalidate()
trans.rollback()
@testing.requires.savepoints
def test_rollback_on_invalid_savepoint(self):
- conn = engine.connect()
+ conn = self.engine.connect()
trans = conn.begin()
trans2 = conn.begin_nested()
conn.invalidate()
trans2.rollback()
def test_invalidate_twice(self):
- conn = engine.connect()
+ conn = self.engine.connect()
conn.invalidate()
conn.invalidate()
@@ -503,12 +520,7 @@ class RealReconnectTest(fixtures.TestBase):
eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.closed
engine.test_shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError as e:
- if not e.connection_invalidated:
- raise
+ _assert_invalidated(conn.execute, select([1]))
assert not conn.closed
assert conn.invalidated
eq_(conn.execute(select([1])).scalar(), 1)
@@ -517,37 +529,27 @@ class RealReconnectTest(fixtures.TestBase):
@testing.fails_on('+informixdb',
"Wrong error thrown, fix in informixdb?")
def test_close(self):
- conn = engine.connect()
+ conn = self.engine.connect()
eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.closed
- engine.test_shutdown()
+ self.engine.test_shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError as e:
- if not e.connection_invalidated:
- raise
+ _assert_invalidated(conn.execute, select([1]))
conn.close()
- conn = engine.connect()
+ conn = self.engine.connect()
eq_(conn.execute(select([1])).scalar(), 1)
@testing.fails_on('+informixdb',
"Wrong error thrown, fix in informixdb?")
def test_with_transaction(self):
- conn = engine.connect()
+ conn = self.engine.connect()
trans = conn.begin()
eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.closed
- engine.test_shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError as e:
- if not e.connection_invalidated:
- raise
+ self.engine.test_shutdown()
+ _assert_invalidated(conn.execute, select([1]))
assert not conn.closed
assert conn.invalidated
assert trans.is_active
@@ -558,13 +560,11 @@ class RealReconnectTest(fixtures.TestBase):
conn.execute, select([1])
)
assert trans.is_active
- try:
- trans.commit()
- assert False
- except tsa.exc.InvalidRequestError as e:
- assert str(e) \
- == "Can't reconnect until invalid transaction is "\
- "rolled back"
+ assert_raises_message(
+ tsa.exc.InvalidRequestError,
+ "Can't reconnect until invalid transaction is rolled back",
+ trans.commit
+ )
assert trans.is_active
trans.rollback()
assert not trans.is_active
@@ -602,23 +602,21 @@ class RecycleTest(fixtures.TestBase):
eq_(conn.execute(select([1])).scalar(), 1)
conn.close()
-meta, table, engine = None, None, None
class InvalidateDuringResultTest(fixtures.TestBase):
def setup(self):
- global meta, table, engine
- engine = engines.reconnecting_engine()
- meta = MetaData(engine)
- table = Table('sometable', meta,
+ self.engine = engines.reconnecting_engine()
+ self.meta = MetaData(self.engine)
+ table = Table('sometable', self.meta,
Column('id', Integer, primary_key=True),
Column('name', String(50)))
- meta.create_all()
+ self.meta.create_all()
table.insert().execute(
- [{'id':i, 'name':'row %d' % i} for i in range(1, 100)]
+ [{'id': i, 'name': 'row %d' % i} for i in range(1, 100)]
)
def teardown(self):
- meta.drop_all()
- engine.dispose()
+ self.meta.drop_all()
+ self.engine.dispose()
@testing.fails_if([
'+mysqlconnector', '+mysqldb',
@@ -628,16 +626,11 @@ class InvalidateDuringResultTest(fixtures.TestBase):
@testing.fails_on('+informixdb',
"Wrong error thrown, fix in informixdb?")
def test_invalidate_on_results(self):
- conn = engine.connect()
+ conn = self.engine.connect()
result = conn.execute('select * from sometable')
for x in range(20):
result.fetchone()
- engine.test_shutdown()
- try:
- print('ghost result: %r' % result.fetchone())
- assert False
- except tsa.exc.DBAPIError as e:
- if not e.connection_invalidated:
- raise
+ self.engine.test_shutdown()
+ _assert_invalidated(result.fetchone)
assert conn.invalidated