diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-06-30 18:35:12 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-06-30 18:35:12 -0400 |
commit | b38a76cd1d47cd6b8f1abef30ad7c3aeaa27d537 (patch) | |
tree | 7af1dba9e242c77a248cb2194434aa9bf3ca49b7 /test/engine/test_reconnect.py | |
parent | 715d6cf3d10a71acd7726b7e00c3ff40b4559bc7 (diff) | |
download | sqlalchemy-b38a76cd1d47cd6b8f1abef30ad7c3aeaa27d537.tar.gz |
- replace most explicitly-named test objects called "Mock..." with
actual mock objects from the mock library. I'd like to use mock
for new tests so we might as well use it in obvious places.
- use unittest.mock in py3.3
- changelog
- add a note to README.unittests
- add tests_require in setup.py
- have tests import from sqlalchemy.testing.mock
- apply usage of mock to one of the event tests. we can be using
this approach all over the place.
Diffstat (limited to 'test/engine/test_reconnect.py')
-rw-r--r-- | test/engine/test_reconnect.py | 443 |
1 files changed, 218 insertions, 225 deletions
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index ee3ff1459..86003bec6 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1,7 +1,6 @@ from sqlalchemy.testing import eq_, assert_raises, assert_raises_message import time -import weakref -from sqlalchemy import select, MetaData, Integer, String, pool, create_engine +from sqlalchemy import select, MetaData, Integer, String, create_engine, pool from sqlalchemy.testing.schema import Table, Column import sqlalchemy as tsa from sqlalchemy import testing @@ -10,6 +9,8 @@ from sqlalchemy.testing.util import gc_collect from sqlalchemy import exc, util from sqlalchemy.testing import fixtures from sqlalchemy.testing.engines import testing_engine +from sqlalchemy.testing import is_not_ +from sqlalchemy.testing.mock import Mock, call class MockError(Exception): pass @@ -17,93 +18,103 @@ class MockError(Exception): class MockDisconnect(MockError): pass -class MockDBAPI(object): - def __init__(self): - self.paramstyle = 'named' - self.connections = weakref.WeakKeyDictionary() - def connect(self, *args, **kwargs): - return MockConnection(self) - def shutdown(self, explode='execute'): - for c in self.connections: - c.explode = explode - Error = MockError - -class MockConnection(object): - def __init__(self, dbapi): - dbapi.connections[self] = True - self.explode = "" - def rollback(self): - if self.explode == 'rollback': +def mock_connection(): + def mock_cursor(): + def execute(*args, **kwargs): + if conn.explode == 'execute': + raise MockDisconnect("Lost the DB connection on execute") + elif conn.explode in ('execute_no_disconnect', ): + raise MockError( + "something broke on execute but we didn't lose the connection") + elif conn.explode in ('rollback', 'rollback_no_disconnect'): + raise MockError( + "something broke on execute but we didn't lose the connection") + elif args and "SELECT" in args[0]: + cursor.description = [('foo', None, None, None, None, None)] + else: + return + + def close(): + cursor.fetchall = cursor.fetchone = \ + Mock(side_effect=MockError("cursor closed")) + cursor = Mock( + execute=Mock(side_effect=execute), + close=Mock(side_effect=close) + ) + return cursor + + def cursor(): + while True: + yield mock_cursor() + + def rollback(): + if conn.explode == 'rollback': raise MockDisconnect("Lost the DB connection on rollback") - if self.explode == 'rollback_no_disconnect': + if conn.explode == 'rollback_no_disconnect': raise MockError( "something broke on rollback but we didn't lose the connection") else: return - def commit(self): - pass - def cursor(self): - return MockCursor(self) - def close(self): - pass - -class MockCursor(object): - def __init__(self, parent): - self.explode = parent.explode - self.description = () - self.closed = False - def execute(self, *args, **kwargs): - if self.explode == 'execute': - raise MockDisconnect("Lost the DB connection on execute") - elif self.explode in ('execute_no_disconnect', ): - raise MockError( - "something broke on execute but we didn't lose the connection") - elif self.explode in ('rollback', 'rollback_no_disconnect'): - raise MockError( - "something broke on execute but we didn't lose the connection") - elif args and "select" in args[0]: - self.description = [('foo', None, None, None, None, None)] - else: - return - def fetchall(self): - if self.closed: - raise MockError("cursor closed") - return [] - def fetchone(self): - if self.closed: - raise MockError("cursor closed") - return None - def close(self): - self.closed = True - -db, dbapi = None, None + + conn = Mock( + rollback=Mock(side_effect=rollback), + cursor=Mock(side_effect=cursor()) + ) + return conn + +def MockDBAPI(): + connections = [] + def connect(): + while True: + conn = mock_connection() + connections.append(conn) + yield conn + + def shutdown(explode='execute'): + for c in connections: + c.explode = explode + + def dispose(): + for c in connections: + c.explode = None + connections[:] = [] + + return Mock( + connect=Mock(side_effect=connect()), + shutdown=Mock(side_effect=shutdown), + dispose=Mock(side_effect=dispose), + paramstyle='named', + connections=connections, + Error=MockError + ) + + class MockReconnectTest(fixtures.TestBase): def setup(self): - global db, dbapi - dbapi = MockDBAPI() + self.dbapi = MockDBAPI() - # note - using straight create_engine here - # since we are testing gc - db = create_engine( + self.db = testing_engine( 'postgresql://foo:bar@localhost/test', - module=dbapi, _initialize=False) + options=dict(module=self.dbapi, _initialize=False)) + self.mock_connect = call(host='localhost', password='bar', + user='foo', database='test') # monkeypatch disconnect checker - db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect) + self.db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect) def teardown(self): - db.dispose() + self.dbapi.dispose() def test_reconnect(self): """test that an 'is_disconnect' condition will invalidate the connection, and additionally dispose the previous connection pool and recreate.""" - pid = id(db.pool) + db_pool = self.db.pool # make a connection - conn = db.connect() + conn = self.db.connect() # connection works @@ -112,21 +123,20 @@ class MockReconnectTest(fixtures.TestBase): # create a second connection within the pool, which we'll ensure # also goes away - conn2 = db.connect() + conn2 = self.db.connect() conn2.close() # two connections opened total now - assert len(dbapi.connections) == 2 + assert len(self.dbapi.connections) == 2 # set it to fail - dbapi.shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError: - pass + self.dbapi.shutdown() + assert_raises( + tsa.exc.DBAPIError, + conn.execute, select([1]) + ) # assert was invalidated @@ -136,31 +146,38 @@ class MockReconnectTest(fixtures.TestBase): # close shouldnt break conn.close() - assert id(db.pool) != pid + is_not_(self.db.pool, db_pool) # ensure all connections closed (pool was recycled) - gc_collect() - assert len(dbapi.connections) == 0 - conn = db.connect() + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], [call()]] + ) + + conn = self.db.connect() conn.execute(select([1])) conn.close() - assert len(dbapi.connections) == 1 + + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], [call()], []] + ) def test_invalidate_trans(self): - conn = db.connect() + conn = self.db.connect() trans = conn.begin() - dbapi.shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError: - pass + self.dbapi.shutdown() - # assert was invalidated + assert_raises( + tsa.exc.DBAPIError, + conn.execute, select([1]) + ) - gc_collect() - assert len(dbapi.connections) == 0 + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()]] + ) assert not conn.closed assert conn.invalidated assert trans.is_active @@ -170,28 +187,35 @@ class MockReconnectTest(fixtures.TestBase): conn.execute, select([1]) ) assert trans.is_active - try: - trans.commit() - assert False - except tsa.exc.InvalidRequestError as e: - assert str(e) \ - == "Can't reconnect until invalid transaction is "\ - "rolled back" + + assert_raises_message( + tsa.exc.InvalidRequestError, + "Can't reconnect until invalid transaction is " + "rolled back", + trans.commit + ) + assert trans.is_active trans.rollback() assert not trans.is_active conn.execute(select([1])) assert not conn.invalidated - assert len(dbapi.connections) == 1 + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], []] + ) def test_conn_reusable(self): - conn = db.connect() + conn = self.db.connect() conn.execute(select([1])) - assert len(dbapi.connections) == 1 + eq_( + self.dbapi.connect.mock_calls, + [self.mock_connect] + ) - dbapi.shutdown() + self.dbapi.shutdown() assert_raises( tsa.exc.DBAPIError, @@ -201,19 +225,24 @@ class MockReconnectTest(fixtures.TestBase): assert not conn.closed assert conn.invalidated - # ensure all connections closed (pool was recycled) - gc_collect() - assert len(dbapi.connections) == 0 + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()]] + ) # test reconnects conn.execute(select([1])) assert not conn.invalidated - assert len(dbapi.connections) == 1 + + eq_( + [c.close.mock_calls for c in self.dbapi.connections], + [[call()], []] + ) def test_invalidated_close(self): - conn = db.connect() + conn = self.db.connect() - dbapi.shutdown() + self.dbapi.shutdown() assert_raises( tsa.exc.DBAPIError, @@ -230,9 +259,9 @@ class MockReconnectTest(fixtures.TestBase): ) def test_noreconnect_execute_plus_closewresult(self): - conn = db.connect(close_with_result=True) + conn = self.db.connect(close_with_result=True) - dbapi.shutdown("execute_no_disconnect") + self.dbapi.shutdown("execute_no_disconnect") # raises error assert_raises_message( @@ -245,9 +274,9 @@ class MockReconnectTest(fixtures.TestBase): assert not conn.invalidated def test_noreconnect_rollback_plus_closewresult(self): - conn = db.connect(close_with_result=True) + conn = self.db.connect(close_with_result=True) - dbapi.shutdown("rollback_no_disconnect") + self.dbapi.shutdown("rollback_no_disconnect") # raises error assert_raises_message( @@ -266,13 +295,13 @@ class MockReconnectTest(fixtures.TestBase): ) def test_reconnect_on_reentrant(self): - conn = db.connect() + conn = self.db.connect() conn.execute(select([1])) - assert len(dbapi.connections) == 1 + assert len(self.dbapi.connections) == 1 - dbapi.shutdown("rollback") + self.dbapi.shutdown("rollback") # raises error assert_raises_message( @@ -285,9 +314,9 @@ class MockReconnectTest(fixtures.TestBase): assert conn.invalidated def test_reconnect_on_reentrant_plus_closewresult(self): - conn = db.connect(close_with_result=True) + conn = self.db.connect(close_with_result=True) - dbapi.shutdown("rollback") + self.dbapi.shutdown("rollback") # raises error assert_raises_message( @@ -306,10 +335,11 @@ class MockReconnectTest(fixtures.TestBase): ) def test_check_disconnect_no_cursor(self): - conn = db.connect() - result = conn.execute("select 1") + conn = self.db.connect() + result = conn.execute(select([1])) result.cursor.close() conn.close() + assert_raises_message( tsa.exc.DBAPIError, "cursor closed", @@ -319,60 +349,59 @@ class MockReconnectTest(fixtures.TestBase): class CursorErrTest(fixtures.TestBase): def setup(self): - global db, dbapi - - class MDBAPI(MockDBAPI): - def connect(self, *args, **kwargs): - return MConn(self) - - class MConn(MockConnection): - def cursor(self): - return MCursor(self) + def MockDBAPI(): + def cursor(): + while True: + yield Mock( + description=[], + close=Mock(side_effect=Exception("explode"))) + def connect(): + while True: + yield Mock(cursor=Mock(side_effect=cursor())) + + return Mock(connect=Mock(side_effect=connect())) - class MCursor(MockCursor): - def close(self): - raise Exception("explode") - - dbapi = MDBAPI() - - db = testing_engine( + dbapi = MockDBAPI() + self.db = testing_engine( 'postgresql://foo:bar@localhost/test', options=dict(module=dbapi, _initialize=False)) def test_cursor_explode(self): - conn = db.connect() + conn = self.db.connect() result = conn.execute("select foo") result.close() conn.close() def teardown(self): - db.dispose() + self.db.dispose() + + +def _assert_invalidated(fn, *args): + try: + fn(*args) + assert False + except tsa.exc.DBAPIError as e: + if not e.connection_invalidated: + raise -engine = None class RealReconnectTest(fixtures.TestBase): def setup(self): - global engine - engine = engines.reconnecting_engine() + self.engine = engines.reconnecting_engine() def teardown(self): - engine.dispose() + self.engine.dispose() @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_reconnect(self): - conn = engine.connect() + conn = self.engine.connect() eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed - engine.test_shutdown() + self.engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + _assert_invalidated(conn.execute, select([1])) assert not conn.closed assert conn.invalidated @@ -382,13 +411,9 @@ class RealReconnectTest(fixtures.TestBase): assert not conn.invalidated # one more time - engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + self.engine.test_shutdown() + _assert_invalidated(conn.execute, select([1])) + assert conn.invalidated eq_(conn.execute(select([1])).scalar(), 1) assert not conn.invalidated @@ -396,30 +421,22 @@ class RealReconnectTest(fixtures.TestBase): conn.close() def test_multiple_invalidate(self): - c1 = engine.connect() - c2 = engine.connect() + c1 = self.engine.connect() + c2 = self.engine.connect() eq_(c1.execute(select([1])).scalar(), 1) - p1 = engine.pool - engine.test_shutdown() + p1 = self.engine.pool + self.engine.test_shutdown() - try: - c1.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - assert e.connection_invalidated + _assert_invalidated(c1.execute, select([1])) - p2 = engine.pool + p2 = self.engine.pool - try: - c2.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - assert e.connection_invalidated + _assert_invalidated(c2.execute, select([1])) # pool isn't replaced - assert engine.pool is p2 + assert self.engine.pool is p2 def test_ensure_is_disconnect_gets_connection(self): @@ -430,37 +447,37 @@ class RealReconnectTest(fixtures.TestBase): # though MySQLdb we get a non-working cursor. # assert cursor is None - engine.dialect.is_disconnect = is_disconnect - conn = engine.connect() - engine.test_shutdown() + self.engine.dialect.is_disconnect = is_disconnect + conn = self.engine.connect() + self.engine.test_shutdown() assert_raises( tsa.exc.DBAPIError, conn.execute, select([1]) ) def test_rollback_on_invalid_plain(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin() conn.invalidate() trans.rollback() @testing.requires.two_phase_transactions def test_rollback_on_invalid_twophase(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin_twophase() conn.invalidate() trans.rollback() @testing.requires.savepoints def test_rollback_on_invalid_savepoint(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin() trans2 = conn.begin_nested() conn.invalidate() trans2.rollback() def test_invalidate_twice(self): - conn = engine.connect() + conn = self.engine.connect() conn.invalidate() conn.invalidate() @@ -503,12 +520,7 @@ class RealReconnectTest(fixtures.TestBase): eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + _assert_invalidated(conn.execute, select([1])) assert not conn.closed assert conn.invalidated eq_(conn.execute(select([1])).scalar(), 1) @@ -517,37 +529,27 @@ class RealReconnectTest(fixtures.TestBase): @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_close(self): - conn = engine.connect() + conn = self.engine.connect() eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed - engine.test_shutdown() + self.engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + _assert_invalidated(conn.execute, select([1])) conn.close() - conn = engine.connect() + conn = self.engine.connect() eq_(conn.execute(select([1])).scalar(), 1) @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_with_transaction(self): - conn = engine.connect() + conn = self.engine.connect() trans = conn.begin() eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed - engine.test_shutdown() - try: - conn.execute(select([1])) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + self.engine.test_shutdown() + _assert_invalidated(conn.execute, select([1])) assert not conn.closed assert conn.invalidated assert trans.is_active @@ -558,13 +560,11 @@ class RealReconnectTest(fixtures.TestBase): conn.execute, select([1]) ) assert trans.is_active - try: - trans.commit() - assert False - except tsa.exc.InvalidRequestError as e: - assert str(e) \ - == "Can't reconnect until invalid transaction is "\ - "rolled back" + assert_raises_message( + tsa.exc.InvalidRequestError, + "Can't reconnect until invalid transaction is rolled back", + trans.commit + ) assert trans.is_active trans.rollback() assert not trans.is_active @@ -602,23 +602,21 @@ class RecycleTest(fixtures.TestBase): eq_(conn.execute(select([1])).scalar(), 1) conn.close() -meta, table, engine = None, None, None class InvalidateDuringResultTest(fixtures.TestBase): def setup(self): - global meta, table, engine - engine = engines.reconnecting_engine() - meta = MetaData(engine) - table = Table('sometable', meta, + self.engine = engines.reconnecting_engine() + self.meta = MetaData(self.engine) + table = Table('sometable', self.meta, Column('id', Integer, primary_key=True), Column('name', String(50))) - meta.create_all() + self.meta.create_all() table.insert().execute( - [{'id':i, 'name':'row %d' % i} for i in range(1, 100)] + [{'id': i, 'name': 'row %d' % i} for i in range(1, 100)] ) def teardown(self): - meta.drop_all() - engine.dispose() + self.meta.drop_all() + self.engine.dispose() @testing.fails_if([ '+mysqlconnector', '+mysqldb', @@ -628,16 +626,11 @@ class InvalidateDuringResultTest(fixtures.TestBase): @testing.fails_on('+informixdb', "Wrong error thrown, fix in informixdb?") def test_invalidate_on_results(self): - conn = engine.connect() + conn = self.engine.connect() result = conn.execute('select * from sometable') for x in range(20): result.fetchone() - engine.test_shutdown() - try: - print('ghost result: %r' % result.fetchone()) - assert False - except tsa.exc.DBAPIError as e: - if not e.connection_invalidated: - raise + self.engine.test_shutdown() + _assert_invalidated(result.fetchone) assert conn.invalidated |