summaryrefslogtreecommitdiff
path: root/test/engine/test_parseconnect.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/engine/test_parseconnect.py')
-rw-r--r--test/engine/test_parseconnect.py64
1 files changed, 25 insertions, 39 deletions
diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py
index 73bdc76c4..106bd0782 100644
--- a/test/engine/test_parseconnect.py
+++ b/test/engine/test_parseconnect.py
@@ -7,6 +7,8 @@ from sqlalchemy.engine.default import DefaultDialect
import sqlalchemy as tsa
from sqlalchemy.testing import fixtures
from sqlalchemy import testing
+from sqlalchemy.testing.mock import Mock
+
class ParseConnectTest(fixtures.TestBase):
def test_rfc1738(self):
@@ -250,20 +252,17 @@ pool_timeout=10
every backend.
"""
- # pretend pysqlite throws the
- # "Cannot operate on a closed database." error
- # on connect. IRL we'd be getting Oracle's "shutdown in progress"
e = create_engine('sqlite://')
sqlite3 = e.dialect.dbapi
- class ThrowOnConnect(MockDBAPI):
- dbapi = sqlite3
- Error = sqlite3.Error
- ProgrammingError = sqlite3.ProgrammingError
- def connect(self, *args, **kw):
- raise sqlite3.ProgrammingError("Cannot operate on a closed database.")
+
+ dbapi = MockDBAPI()
+ dbapi.Error = sqlite3.Error,
+ dbapi.ProgrammingError = sqlite3.ProgrammingError
+ dbapi.connect = Mock(side_effect=sqlite3.ProgrammingError(
+ "Cannot operate on a closed database."))
try:
- create_engine('sqlite://', module=ThrowOnConnect()).connect()
+ create_engine('sqlite://', module=dbapi).connect()
assert False
except tsa.exc.DBAPIError as de:
assert de.connection_invalidated
@@ -354,36 +353,23 @@ class MockDialect(DefaultDialect):
def dbapi(cls, **kw):
return MockDBAPI()
-class MockDBAPI(object):
- version_info = sqlite_version_info = 99, 9, 9
- sqlite_version = '99.9.9'
-
- def __init__(self, **kwargs):
- self.kwargs = kwargs
- self.paramstyle = 'named'
-
- def connect(self, *args, **kwargs):
- for k in self.kwargs:
+def MockDBAPI(**assert_kwargs):
+ connection = Mock(get_server_version_info=Mock(return_value='5.0'))
+ def connect(*args, **kwargs):
+ for k in assert_kwargs:
assert k in kwargs, 'key %s not present in dictionary' % k
- assert kwargs[k] == self.kwargs[k], \
- 'value %s does not match %s' % (kwargs[k],
- self.kwargs[k])
- return MockConnection()
-
-
-class MockConnection(object):
- def get_server_info(self):
- return '5.0'
-
- def close(self):
- pass
-
- def cursor(self):
- return MockCursor()
-
-class MockCursor(object):
- def close(self):
- pass
+ eq_(
+ kwargs[k], assert_kwargs[k]
+ )
+ return connection
+
+ return Mock(
+ sqlite_version_info=(99, 9, 9,),
+ version_info=(99, 9, 9,),
+ sqlite_version='99.9.9',
+ paramstyle='named',
+ connect=Mock(side_effect=connect)
+ )
mock_dbapi = MockDBAPI()
mock_sqlite_dbapi = msd = MockDBAPI()