summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--osprofiler/sqlalchemy.py11
-rw-r--r--osprofiler/tests/test_sqlalchemy.py31
2 files changed, 42 insertions, 0 deletions
diff --git a/osprofiler/sqlalchemy.py b/osprofiler/sqlalchemy.py
index e98c59c..c593684 100644
--- a/osprofiler/sqlalchemy.py
+++ b/osprofiler/sqlalchemy.py
@@ -13,6 +13,8 @@
# License for the specific language governing permissions and limitations
# under the License.
+import contextlib
+
from osprofiler import profiler
@@ -42,6 +44,15 @@ def add_tracing(sqlalchemy, engine, name):
_after_cursor_execute())
+@contextlib.contextmanager
+def wrap_session(sqlalchemy, sess):
+ with sess as s:
+ if not getattr(s.bind, "traced", False):
+ add_tracing(sqlalchemy, s.bind, "db")
+ s.bind.traced = True
+ yield s
+
+
def _before_cursor_execute(name):
"""Add listener that will send trace info before query is executed."""
diff --git a/osprofiler/tests/test_sqlalchemy.py b/osprofiler/tests/test_sqlalchemy.py
index 1494da6..83ada93 100644
--- a/osprofiler/tests/test_sqlalchemy.py
+++ b/osprofiler/tests/test_sqlalchemy.py
@@ -13,6 +13,7 @@
# License for the specific language governing permissions and limitations
# under the License.
+import contextlib
import mock
from osprofiler import sqlalchemy
@@ -56,6 +57,36 @@ class SqlalchemyTracingTestCase(test.TestCase):
@mock.patch("osprofiler.sqlalchemy._before_cursor_execute")
@mock.patch("osprofiler.sqlalchemy._after_cursor_execute")
+ def test_wrap_session(self, mock_after_exc, mock_before_exc):
+ sa = mock.MagicMock()
+
+ @contextlib.contextmanager
+ def _session():
+ session = mock.MagicMock()
+ # current engine object stored within the session
+ session.bind = mock.MagicMock()
+ session.bind.traced = None
+ yield session
+
+ mock_before_exc.return_value = "before"
+ mock_after_exc.return_value = "after"
+
+ session = sqlalchemy.wrap_session(sa, _session())
+
+ with session as sess:
+ pass
+
+ mock_before_exc.assert_called_once_with("db")
+ mock_after_exc.assert_called_once_with()
+ expected_calls = [
+ mock.call(sess.bind, "before_cursor_execute", "before"),
+ mock.call(sess.bind, "after_cursor_execute", "after")
+ ]
+
+ self.assertEqual(sa.event.listen.call_args_list, expected_calls)
+
+ @mock.patch("osprofiler.sqlalchemy._before_cursor_execute")
+ @mock.patch("osprofiler.sqlalchemy._after_cursor_execute")
def test_disable_and_enable(self, mock_after_exc, mock_before_exc):
sqlalchemy.disable()