diff options
-rw-r--r-- | osprofiler/sqlalchemy.py | 11 | ||||
-rw-r--r-- | osprofiler/tests/test_sqlalchemy.py | 31 |
2 files changed, 42 insertions, 0 deletions
diff --git a/osprofiler/sqlalchemy.py b/osprofiler/sqlalchemy.py index e98c59c..c593684 100644 --- a/osprofiler/sqlalchemy.py +++ b/osprofiler/sqlalchemy.py @@ -13,6 +13,8 @@ # License for the specific language governing permissions and limitations # under the License. +import contextlib + from osprofiler import profiler @@ -42,6 +44,15 @@ def add_tracing(sqlalchemy, engine, name): _after_cursor_execute()) +@contextlib.contextmanager +def wrap_session(sqlalchemy, sess): + with sess as s: + if not getattr(s.bind, "traced", False): + add_tracing(sqlalchemy, s.bind, "db") + s.bind.traced = True + yield s + + def _before_cursor_execute(name): """Add listener that will send trace info before query is executed.""" diff --git a/osprofiler/tests/test_sqlalchemy.py b/osprofiler/tests/test_sqlalchemy.py index 1494da6..83ada93 100644 --- a/osprofiler/tests/test_sqlalchemy.py +++ b/osprofiler/tests/test_sqlalchemy.py @@ -13,6 +13,7 @@ # License for the specific language governing permissions and limitations # under the License. +import contextlib import mock from osprofiler import sqlalchemy @@ -56,6 +57,36 @@ class SqlalchemyTracingTestCase(test.TestCase): @mock.patch("osprofiler.sqlalchemy._before_cursor_execute") @mock.patch("osprofiler.sqlalchemy._after_cursor_execute") + def test_wrap_session(self, mock_after_exc, mock_before_exc): + sa = mock.MagicMock() + + @contextlib.contextmanager + def _session(): + session = mock.MagicMock() + # current engine object stored within the session + session.bind = mock.MagicMock() + session.bind.traced = None + yield session + + mock_before_exc.return_value = "before" + mock_after_exc.return_value = "after" + + session = sqlalchemy.wrap_session(sa, _session()) + + with session as sess: + pass + + mock_before_exc.assert_called_once_with("db") + mock_after_exc.assert_called_once_with() + expected_calls = [ + mock.call(sess.bind, "before_cursor_execute", "before"), + mock.call(sess.bind, "after_cursor_execute", "after") + ] + + self.assertEqual(sa.event.listen.call_args_list, expected_calls) + + @mock.patch("osprofiler.sqlalchemy._before_cursor_execute") + @mock.patch("osprofiler.sqlalchemy._after_cursor_execute") def test_disable_and_enable(self, mock_after_exc, mock_before_exc): sqlalchemy.disable() |