diff options
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 224 |
1 files changed, 131 insertions, 93 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 4e7453d84..6b5c4a072 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -4,12 +4,12 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import weakref + from sqlalchemy import util, exceptions, sql, engine -from sqlalchemy.orm import unitofwork, query +from sqlalchemy.orm import unitofwork, query, util as mapperutil from sqlalchemy.orm.mapper import object_mapper as _object_mapper from sqlalchemy.orm.mapper import class_mapper as _class_mapper -import weakref -import sqlalchemy class SessionTransaction(object): """Represents a Session-level Transaction. @@ -21,70 +21,95 @@ class SessionTransaction(object): The SessionTransaction object is **not** threadsafe. """ - def __init__(self, session, parent=None, autoflush=True): + def __init__(self, session, parent=None, autoflush=True, nested=False): self.session = session - self.connections = {} - self.parent = parent + self.__connections = {} + self.__parent = parent self.autoflush = autoflush + self.nested = nested - def connection(self, mapper_or_class, entity_name=None): + def connection(self, mapper_or_class, entity_name=None, **kwargs): if isinstance(mapper_or_class, type): mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name) - engine = self.session.get_bind(mapper_or_class) + engine = self.session.get_bind(mapper_or_class, **kwargs) return self.get_or_add(engine) - def _begin(self): - return SessionTransaction(self.session, self) + def _begin(self, **kwargs): + return SessionTransaction(self.session, self, **kwargs) def add(self, bind): - if self.parent is not None: - return self.parent.add(bind) - - if self.connections.has_key(bind.engine): - raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or "")) + if self.__parent is not None: + return self.__parent.add(bind) + if self.__connections.has_key(bind.engine): + raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or "")) return self.get_or_add(bind) + def _connection_dict(self): + if self.__parent is not None and not self.nested: + return self.__parent._connection_dict() + else: + return self.__connections + def get_or_add(self, bind): - if self.parent is not None: - return self.parent.get_or_add(bind) + if self.__parent is not None: + if not self.nested: + return self.__parent.get_or_add(bind) + + if self.__connections.has_key(bind): + return self.__connections[bind][0] + + if bind in self.__parent._connection_dict(): + (conn, trans, autoclose) = self.__parent.__connections[bind] + self.__connections[conn] = self.__connections[bind.engine] = (conn, conn.begin_nested(), autoclose) + return conn + elif self.__connections.has_key(bind): + return self.__connections[bind][0] - if self.connections.has_key(bind): - return self.connections[bind][0] - if not isinstance(bind, engine.Connection): e = bind c = bind.contextual_connect() else: e = bind.engine c = bind - if e in self.connections: + if e in self.__connections: raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine") - - self.connections[bind] = self.connections[e] = (c, c.begin(), c is not bind) - return self.connections[bind][0] + if self.nested: + trans = c.begin_nested() + elif self.session.twophase: + trans = c.begin_twophase() + else: + trans = c.begin() + self.__connections[c] = self.__connections[e] = (c, trans, c is not bind) + return self.__connections[c][0] def commit(self): - if self.parent is not None: - return + if self.__parent is not None and not self.nested: + return self.__parent if self.autoflush: self.session.flush() - for t in util.Set(self.connections.values()): + + if self.session.twophase: + for t in util.Set(self.__connections.values()): + t[1].prepare() + + for t in util.Set(self.__connections.values()): t[1].commit() self.close() + return self.__parent def rollback(self): - if self.parent is not None: - self.parent.rollback() - return - for k, t in self.connections.iteritems(): + if self.__parent is not None and not self.nested: + return self.__parent.rollback() + for t in util.Set(self.__connections.values()): t[1].rollback() self.close() - + return self.__parent + def close(self): - if self.parent is not None: + if self.__parent is not None: return - for t in self.connections.values(): + for t in util.Set(self.__connections.values()): if t[2]: t[0].close() self.session.transaction = None @@ -108,23 +133,24 @@ class Session(object): of Sessions, see the ``sqlalchemy.ext.sessioncontext`` module. """ - def __init__(self, bind=None, bind_to=None, hash_key=None, import_session=None, echo_uow=False, weak_identity_map=False): - if import_session is not None: - self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map, weak_identity_map=weak_identity_map) - else: - self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map) + def __init__(self, bind=None, autoflush=False, transactional=False, twophase=False, echo_uow=False, weak_identity_map=False): + self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map) - self.bind = bind or bind_to - self.binds = {} + self.bind = bind + self.__binds = {} self.echo_uow = echo_uow self.weak_identity_map = weak_identity_map self.transaction = None - if hash_key is None: - self.hash_key = id(self) - else: - self.hash_key = hash_key + self.hash_key = id(self) + self.autoflush = autoflush + self.transactional = transactional or autoflush + self.twophase = twophase + self._query_cls = query.Query + self._mapper_flush_opts = {} + if self.transactional: + self.begin() _sessions[self.hash_key] = self - + def _get_echo_uow(self): return self.uow.echo @@ -132,37 +158,39 @@ class Session(object): self.uow.echo = value echo_uow = property(_get_echo_uow,_set_echo_uow) - bind_to = property(lambda self:self.bind) - - def create_transaction(self, **kwargs): - """Return a new ``SessionTransaction`` corresponding to an - existing or new transaction. - - If the transaction is new, the returned ``SessionTransaction`` - will have commit control over the underlying transaction, else - will have rollback control only. - """ + def begin(self, **kwargs): + """Begin a transaction on this Session.""" if self.transaction is not None: - return self.transaction._begin() + self.transaction = self.transaction._begin(**kwargs) else: self.transaction = SessionTransaction(self, **kwargs) - return self.transaction - - def connect(self, mapper=None, **kwargs): - """Return a unique connection corresponding to the given mapper. - - This connection will not be part of any pre-existing - transactional context. - """ - - return self.get_bind(mapper).connect(**kwargs) - - def connection(self, mapper, **kwargs): - """Return a ``Connection`` corresponding to the given mapper. + return self.transaction + + create_transaction = begin - Used by the ``execute()`` method which performs select - operations for ``Mapper`` and ``Query``. + def begin_nested(self): + return self.begin(nested=True) + + def rollback(self): + if self.transaction is None: + raise exceptions.InvalidRequestError("No transaction is begun.") + else: + self.transaction = self.transaction.rollback() + if self.transaction is None and self.transactional: + self.begin() + + def commit(self): + if self.transaction is None: + raise exceptions.InvalidRequestError("No transaction is begun.") + else: + self.transaction = self.transaction.commit() + if self.transaction is None and self.transactional: + self.begin() + + def connection(self, mapper=None, **kwargs): + """Return a ``Connection`` corresponding to this session's + transactional context, if any. If this ``Session`` is transactional, the connection will be in the context of this session's transaction. Otherwise, the @@ -173,6 +201,9 @@ class Session(object): The given `**kwargs` will be sent to the engine's ``contextual_connect()`` method, if no transaction is in progress. + + the "mapper" argument is a class or mapper to which a bound engine + will be located; use this when the Session itself is unbound. """ if self.transaction is not None: @@ -180,7 +211,7 @@ class Session(object): else: return self.get_bind(mapper).contextual_connect(**kwargs) - def execute(self, mapper, clause, params, **kwargs): + def execute(self, clause, params=None, mapper=None, **kwargs): """Using the given mapper to identify the appropriate ``Engine`` or ``Connection`` to be used for statement execution, execute the given ``ClauseElement`` using the provided parameter dictionary. @@ -191,12 +222,12 @@ class Session(object): then the ``ResultProxy`` 's ``close()`` method will release the resources of the underlying ``Connection``, otherwise its a no-op. """ - return self.connection(mapper, close_with_result=True).execute(clause, params, **kwargs) + return self.connection(mapper, close_with_result=True).execute(clause, params or {}, **kwargs) - def scalar(self, mapper, clause, params, **kwargs): + def scalar(self, clause, params=None, mapper=None, **kwargs): """Like execute() but return a scalar result.""" - return self.connection(mapper, close_with_result=True).scalar(clause, params, **kwargs) + return self.connection(mapper, close_with_result=True).scalar(clause, params or {}, **kwargs) def close(self): """Close this Session.""" @@ -224,14 +255,17 @@ class Session(object): return _class_mapper(class_, entity_name = entity_name) - def bind_mapper(self, mapper, bind): - """Bind the given `mapper` to the given ``Engine`` or ``Connection``. + def bind_mapper(self, mapper, bind, entity_name=None): + """Bind the given `mapper` or `class` to the given ``Engine`` or ``Connection``. All subsequent operations involving this ``Mapper`` will use the given `bind`. """ + + if isinstance(mapper, type): + mapper = _class_mapper(mapper, entity_name=entity_name) - self.binds[mapper] = bind + self.__binds[mapper] = bind def bind_table(self, table, bind): """Bind the given `table` to the given ``Engine`` or ``Connection``. @@ -240,7 +274,7 @@ class Session(object): given `bind`. """ - self.binds[table] = bind + self.__binds[table] = bind def get_bind(self, mapper): """Return the ``Engine`` or ``Connection`` which is used to execute @@ -270,15 +304,18 @@ class Session(object): """ if mapper is None: - return self.bind - elif self.binds.has_key(mapper): - return self.binds[mapper] - elif self.binds.has_key(mapper.mapped_table): - return self.binds[mapper.mapped_table] + if self.bind is not None: + return self.bind + else: + raise exceptions.InvalidRequestError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()") + elif self.__binds.has_key(mapper): + return self.__binds[mapper] + elif self.__binds.has_key(mapper.mapped_table): + return self.__binds[mapper.mapped_table] elif self.bind is not None: return self.bind else: - e = mapper.mapped_table.engine + e = mapper.mapped_table.bind if e is None: raise exceptions.InvalidRequestError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper)) return e @@ -291,9 +328,9 @@ class Session(object): entity_name = kwargs.pop('entity_name', None) if isinstance(mapper_or_class, type): - q = query.Query(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs) + q = self._query_cls(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs) else: - q = query.Query(mapper_or_class, self, **kwargs) + q = self._query_cls(mapper_or_class, self, **kwargs) for ent in addtl_entities: q = q.add_entity(ent) @@ -499,7 +536,7 @@ class Session(object): merged = self.get(mapper.class_, key[1]) if merged is None: raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(object)) - for prop in mapper.props.values(): + for prop in mapper.iterate_properties: prop.merge(self, object, merged, _recursive) if key is None: self.save(merged, entity_name=mapper.entity_name) @@ -611,12 +648,12 @@ class Session(object): def _attach(self, obj): """Attach the given object to this ``Session``.""" - if getattr(obj, '_sa_session_id', None) != self.hash_key: - old = getattr(obj, '_sa_session_id', None) - if old is not None and _sessions.has_key(old): + old_id = getattr(obj, '_sa_session_id', None) + if old_id != self.hash_key: + if old_id is not None and _sessions.has_key(old_id): raise exceptions.InvalidRequestError("Object '%s' is already attached " "to session '%s' (this is '%s')" % - (repr(obj), old, id(self))) + (repr(obj), old_id, id(self))) # auto-removal from the old session is disabled. but if we decide to # turn it back on, do it as below: gingerly since _sessions is a WeakValueDict @@ -695,6 +732,7 @@ def object_session(obj): return _sessions.get(hashkey) return None +# Lazy initialization to avoid circular imports unitofwork.object_session = object_session from sqlalchemy.orm import mapper mapper.attribute_manager = attribute_manager |