summaryrefslogtreecommitdiff
path: root/test/engine/test_execute.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/engine/test_execute.py')
-rw-r--r--test/engine/test_execute.py69
1 files changed, 69 insertions, 0 deletions
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index ec255ba04..89d5c6348 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -26,6 +26,7 @@ from sqlalchemy import VARCHAR
from sqlalchemy.engine import default
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.base import Engine
+from sqlalchemy.pool import QueuePool
from sqlalchemy.sql import column
from sqlalchemy.sql import literal
from sqlalchemy.testing import assert_raises
@@ -2748,6 +2749,74 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase):
except tsa.exc.DBAPIError as de:
assert de.connection_invalidated
+ @testing.only_on("sqlite+pysqlite")
+ def test_initialize_connect_calls(self):
+ """test for :ticket:`5497`, on_connect not called twice"""
+
+ m1 = Mock()
+ cls_ = testing.db.dialect.__class__
+
+ class SomeDialect(cls_):
+ def initialize(self, connection):
+ super(SomeDialect, self).initialize(connection)
+ m1.initialize(connection)
+
+ def on_connect(self):
+ oc = super(SomeDialect, self).on_connect()
+
+ def my_on_connect(conn):
+ if oc:
+ oc(conn)
+ m1.on_connect(conn)
+
+ return my_on_connect
+
+ u1 = Mock(
+ username=None,
+ password=None,
+ host=None,
+ port=None,
+ query={},
+ database=None,
+ _instantiate_plugins=lambda kw: [],
+ _get_entrypoint=Mock(
+ return_value=Mock(get_dialect_cls=lambda u: SomeDialect)
+ ),
+ )
+ eng = create_engine(u1, poolclass=QueuePool)
+ eq_(
+ eng.name, "sqlite"
+ ) # make sure other dialects aren't getting pulled in here
+ c = eng.connect()
+ dbapi_conn_one = c.connection.connection
+ c.close()
+
+ eq_(
+ m1.mock_calls,
+ [call.on_connect(dbapi_conn_one), call.initialize(mock.ANY)],
+ )
+
+ c = eng.connect()
+
+ eq_(
+ m1.mock_calls,
+ [call.on_connect(dbapi_conn_one), call.initialize(mock.ANY)],
+ )
+
+ c2 = eng.connect()
+ dbapi_conn_two = c2.connection.connection
+
+ is_not_(dbapi_conn_one, dbapi_conn_two)
+
+ eq_(
+ m1.mock_calls,
+ [
+ call.on_connect(dbapi_conn_one),
+ call.initialize(mock.ANY),
+ call.on_connect(dbapi_conn_two),
+ ],
+ )
+
class DialectEventTest(fixtures.TestBase):
@contextmanager