summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/session.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r--lib/sqlalchemy/orm/session.py224
1 files changed, 131 insertions, 93 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 4e7453d84..6b5c4a072 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -4,12 +4,12 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+import weakref
+
from sqlalchemy import util, exceptions, sql, engine
-from sqlalchemy.orm import unitofwork, query
+from sqlalchemy.orm import unitofwork, query, util as mapperutil
from sqlalchemy.orm.mapper import object_mapper as _object_mapper
from sqlalchemy.orm.mapper import class_mapper as _class_mapper
-import weakref
-import sqlalchemy
class SessionTransaction(object):
"""Represents a Session-level Transaction.
@@ -21,70 +21,95 @@ class SessionTransaction(object):
The SessionTransaction object is **not** threadsafe.
"""
- def __init__(self, session, parent=None, autoflush=True):
+ def __init__(self, session, parent=None, autoflush=True, nested=False):
self.session = session
- self.connections = {}
- self.parent = parent
+ self.__connections = {}
+ self.__parent = parent
self.autoflush = autoflush
+ self.nested = nested
- def connection(self, mapper_or_class, entity_name=None):
+ def connection(self, mapper_or_class, entity_name=None, **kwargs):
if isinstance(mapper_or_class, type):
mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name)
- engine = self.session.get_bind(mapper_or_class)
+ engine = self.session.get_bind(mapper_or_class, **kwargs)
return self.get_or_add(engine)
- def _begin(self):
- return SessionTransaction(self.session, self)
+ def _begin(self, **kwargs):
+ return SessionTransaction(self.session, self, **kwargs)
def add(self, bind):
- if self.parent is not None:
- return self.parent.add(bind)
-
- if self.connections.has_key(bind.engine):
- raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or ""))
+ if self.__parent is not None:
+ return self.__parent.add(bind)
+ if self.__connections.has_key(bind.engine):
+ raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or ""))
return self.get_or_add(bind)
+ def _connection_dict(self):
+ if self.__parent is not None and not self.nested:
+ return self.__parent._connection_dict()
+ else:
+ return self.__connections
+
def get_or_add(self, bind):
- if self.parent is not None:
- return self.parent.get_or_add(bind)
+ if self.__parent is not None:
+ if not self.nested:
+ return self.__parent.get_or_add(bind)
+
+ if self.__connections.has_key(bind):
+ return self.__connections[bind][0]
+
+ if bind in self.__parent._connection_dict():
+ (conn, trans, autoclose) = self.__parent.__connections[bind]
+ self.__connections[conn] = self.__connections[bind.engine] = (conn, conn.begin_nested(), autoclose)
+ return conn
+ elif self.__connections.has_key(bind):
+ return self.__connections[bind][0]
- if self.connections.has_key(bind):
- return self.connections[bind][0]
-
if not isinstance(bind, engine.Connection):
e = bind
c = bind.contextual_connect()
else:
e = bind.engine
c = bind
- if e in self.connections:
+ if e in self.__connections:
raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine")
-
- self.connections[bind] = self.connections[e] = (c, c.begin(), c is not bind)
- return self.connections[bind][0]
+ if self.nested:
+ trans = c.begin_nested()
+ elif self.session.twophase:
+ trans = c.begin_twophase()
+ else:
+ trans = c.begin()
+ self.__connections[c] = self.__connections[e] = (c, trans, c is not bind)
+ return self.__connections[c][0]
def commit(self):
- if self.parent is not None:
- return
+ if self.__parent is not None and not self.nested:
+ return self.__parent
if self.autoflush:
self.session.flush()
- for t in util.Set(self.connections.values()):
+
+ if self.session.twophase:
+ for t in util.Set(self.__connections.values()):
+ t[1].prepare()
+
+ for t in util.Set(self.__connections.values()):
t[1].commit()
self.close()
+ return self.__parent
def rollback(self):
- if self.parent is not None:
- self.parent.rollback()
- return
- for k, t in self.connections.iteritems():
+ if self.__parent is not None and not self.nested:
+ return self.__parent.rollback()
+ for t in util.Set(self.__connections.values()):
t[1].rollback()
self.close()
-
+ return self.__parent
+
def close(self):
- if self.parent is not None:
+ if self.__parent is not None:
return
- for t in self.connections.values():
+ for t in util.Set(self.__connections.values()):
if t[2]:
t[0].close()
self.session.transaction = None
@@ -108,23 +133,24 @@ class Session(object):
of Sessions, see the ``sqlalchemy.ext.sessioncontext`` module.
"""
- def __init__(self, bind=None, bind_to=None, hash_key=None, import_session=None, echo_uow=False, weak_identity_map=False):
- if import_session is not None:
- self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map, weak_identity_map=weak_identity_map)
- else:
- self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map)
+ def __init__(self, bind=None, autoflush=False, transactional=False, twophase=False, echo_uow=False, weak_identity_map=False):
+ self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map)
- self.bind = bind or bind_to
- self.binds = {}
+ self.bind = bind
+ self.__binds = {}
self.echo_uow = echo_uow
self.weak_identity_map = weak_identity_map
self.transaction = None
- if hash_key is None:
- self.hash_key = id(self)
- else:
- self.hash_key = hash_key
+ self.hash_key = id(self)
+ self.autoflush = autoflush
+ self.transactional = transactional or autoflush
+ self.twophase = twophase
+ self._query_cls = query.Query
+ self._mapper_flush_opts = {}
+ if self.transactional:
+ self.begin()
_sessions[self.hash_key] = self
-
+
def _get_echo_uow(self):
return self.uow.echo
@@ -132,37 +158,39 @@ class Session(object):
self.uow.echo = value
echo_uow = property(_get_echo_uow,_set_echo_uow)
- bind_to = property(lambda self:self.bind)
-
- def create_transaction(self, **kwargs):
- """Return a new ``SessionTransaction`` corresponding to an
- existing or new transaction.
-
- If the transaction is new, the returned ``SessionTransaction``
- will have commit control over the underlying transaction, else
- will have rollback control only.
- """
+ def begin(self, **kwargs):
+ """Begin a transaction on this Session."""
if self.transaction is not None:
- return self.transaction._begin()
+ self.transaction = self.transaction._begin(**kwargs)
else:
self.transaction = SessionTransaction(self, **kwargs)
- return self.transaction
-
- def connect(self, mapper=None, **kwargs):
- """Return a unique connection corresponding to the given mapper.
-
- This connection will not be part of any pre-existing
- transactional context.
- """
-
- return self.get_bind(mapper).connect(**kwargs)
-
- def connection(self, mapper, **kwargs):
- """Return a ``Connection`` corresponding to the given mapper.
+ return self.transaction
+
+ create_transaction = begin
- Used by the ``execute()`` method which performs select
- operations for ``Mapper`` and ``Query``.
+ def begin_nested(self):
+ return self.begin(nested=True)
+
+ def rollback(self):
+ if self.transaction is None:
+ raise exceptions.InvalidRequestError("No transaction is begun.")
+ else:
+ self.transaction = self.transaction.rollback()
+ if self.transaction is None and self.transactional:
+ self.begin()
+
+ def commit(self):
+ if self.transaction is None:
+ raise exceptions.InvalidRequestError("No transaction is begun.")
+ else:
+ self.transaction = self.transaction.commit()
+ if self.transaction is None and self.transactional:
+ self.begin()
+
+ def connection(self, mapper=None, **kwargs):
+ """Return a ``Connection`` corresponding to this session's
+ transactional context, if any.
If this ``Session`` is transactional, the connection will be in
the context of this session's transaction. Otherwise, the
@@ -173,6 +201,9 @@ class Session(object):
The given `**kwargs` will be sent to the engine's
``contextual_connect()`` method, if no transaction is in
progress.
+
+ the "mapper" argument is a class or mapper to which a bound engine
+ will be located; use this when the Session itself is unbound.
"""
if self.transaction is not None:
@@ -180,7 +211,7 @@ class Session(object):
else:
return self.get_bind(mapper).contextual_connect(**kwargs)
- def execute(self, mapper, clause, params, **kwargs):
+ def execute(self, clause, params=None, mapper=None, **kwargs):
"""Using the given mapper to identify the appropriate ``Engine``
or ``Connection`` to be used for statement execution, execute the
given ``ClauseElement`` using the provided parameter dictionary.
@@ -191,12 +222,12 @@ class Session(object):
then the ``ResultProxy`` 's ``close()`` method will release the
resources of the underlying ``Connection``, otherwise its a no-op.
"""
- return self.connection(mapper, close_with_result=True).execute(clause, params, **kwargs)
+ return self.connection(mapper, close_with_result=True).execute(clause, params or {}, **kwargs)
- def scalar(self, mapper, clause, params, **kwargs):
+ def scalar(self, clause, params=None, mapper=None, **kwargs):
"""Like execute() but return a scalar result."""
- return self.connection(mapper, close_with_result=True).scalar(clause, params, **kwargs)
+ return self.connection(mapper, close_with_result=True).scalar(clause, params or {}, **kwargs)
def close(self):
"""Close this Session."""
@@ -224,14 +255,17 @@ class Session(object):
return _class_mapper(class_, entity_name = entity_name)
- def bind_mapper(self, mapper, bind):
- """Bind the given `mapper` to the given ``Engine`` or ``Connection``.
+ def bind_mapper(self, mapper, bind, entity_name=None):
+ """Bind the given `mapper` or `class` to the given ``Engine`` or ``Connection``.
All subsequent operations involving this ``Mapper`` will use the
given `bind`.
"""
+
+ if isinstance(mapper, type):
+ mapper = _class_mapper(mapper, entity_name=entity_name)
- self.binds[mapper] = bind
+ self.__binds[mapper] = bind
def bind_table(self, table, bind):
"""Bind the given `table` to the given ``Engine`` or ``Connection``.
@@ -240,7 +274,7 @@ class Session(object):
given `bind`.
"""
- self.binds[table] = bind
+ self.__binds[table] = bind
def get_bind(self, mapper):
"""Return the ``Engine`` or ``Connection`` which is used to execute
@@ -270,15 +304,18 @@ class Session(object):
"""
if mapper is None:
- return self.bind
- elif self.binds.has_key(mapper):
- return self.binds[mapper]
- elif self.binds.has_key(mapper.mapped_table):
- return self.binds[mapper.mapped_table]
+ if self.bind is not None:
+ return self.bind
+ else:
+ raise exceptions.InvalidRequestError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()")
+ elif self.__binds.has_key(mapper):
+ return self.__binds[mapper]
+ elif self.__binds.has_key(mapper.mapped_table):
+ return self.__binds[mapper.mapped_table]
elif self.bind is not None:
return self.bind
else:
- e = mapper.mapped_table.engine
+ e = mapper.mapped_table.bind
if e is None:
raise exceptions.InvalidRequestError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper))
return e
@@ -291,9 +328,9 @@ class Session(object):
entity_name = kwargs.pop('entity_name', None)
if isinstance(mapper_or_class, type):
- q = query.Query(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs)
+ q = self._query_cls(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs)
else:
- q = query.Query(mapper_or_class, self, **kwargs)
+ q = self._query_cls(mapper_or_class, self, **kwargs)
for ent in addtl_entities:
q = q.add_entity(ent)
@@ -499,7 +536,7 @@ class Session(object):
merged = self.get(mapper.class_, key[1])
if merged is None:
raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(object))
- for prop in mapper.props.values():
+ for prop in mapper.iterate_properties:
prop.merge(self, object, merged, _recursive)
if key is None:
self.save(merged, entity_name=mapper.entity_name)
@@ -611,12 +648,12 @@ class Session(object):
def _attach(self, obj):
"""Attach the given object to this ``Session``."""
- if getattr(obj, '_sa_session_id', None) != self.hash_key:
- old = getattr(obj, '_sa_session_id', None)
- if old is not None and _sessions.has_key(old):
+ old_id = getattr(obj, '_sa_session_id', None)
+ if old_id != self.hash_key:
+ if old_id is not None and _sessions.has_key(old_id):
raise exceptions.InvalidRequestError("Object '%s' is already attached "
"to session '%s' (this is '%s')" %
- (repr(obj), old, id(self)))
+ (repr(obj), old_id, id(self)))
# auto-removal from the old session is disabled. but if we decide to
# turn it back on, do it as below: gingerly since _sessions is a WeakValueDict
@@ -695,6 +732,7 @@ def object_session(obj):
return _sessions.get(hashkey)
return None
+# Lazy initialization to avoid circular imports
unitofwork.object_session = object_session
from sqlalchemy.orm import mapper
mapper.attribute_manager = attribute_manager