summaryrefslogtreecommitdiff
path: root/test/orm/sessioncontext.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/orm/sessioncontext.py')
-rw-r--r--test/orm/sessioncontext.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/test/orm/sessioncontext.py b/test/orm/sessioncontext.py
new file mode 100644
index 000000000..83bc2f2bf
--- /dev/null
+++ b/test/orm/sessioncontext.py
@@ -0,0 +1,47 @@
+from testbase import PersistTest, AssertMixin
+import unittest, sys, os
+from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy.orm.session import object_session, Session
+from sqlalchemy import *
+import testbase
+
+metadata = MetaData()
+users = Table('users', metadata,
+ Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
+ Column('user_name', String(40)),
+ mysql_engine='innodb'
+)
+
+class SessionContextTest(AssertMixin):
+ def setUp(self):
+ clear_mappers()
+
+ def do_test(self, class_, context):
+ """test session assignment on object creation"""
+ obj = class_()
+ assert context.current == object_session(obj)
+
+ # keep a reference so the old session doesn't get gc'd
+ old_session = context.current
+
+ context.current = Session()
+ assert context.current != object_session(obj)
+ assert old_session == object_session(obj)
+
+ new_session = context.current
+ del context.current
+ assert context.current != new_session
+ assert old_session == object_session(obj)
+
+ obj2 = class_()
+ assert context.current == object_session(obj2)
+
+ def test_mapper_extension(self):
+ context = SessionContext(Session)
+ class User(object): pass
+ User.mapper = mapper(User, users, extension=context.mapper_extension)
+ self.do_test(User, context)
+
+
+if __name__ == "__main__":
+ testbase.main()