summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-07-16 21:29:27 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-07-16 21:29:27 +0000
commita3f25674fce8cd137dc4c1f113d5bbf25e28c7f2 (patch)
treef2eb1f773c20bf5adeb5db0456efa7b0727c477e /lib/sqlalchemy
parent40ded34819bb2a65ef02042bf60e075d65123592 (diff)
downloadsqlalchemy-a3f25674fce8cd137dc4c1f113d5bbf25e28c7f2.tar.gz
- fixes for connection bound sessions, connection-bound compiled objects via metadata
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/engine/base.py2
-rw-r--r--lib/sqlalchemy/orm/session.py34
2 files changed, 21 insertions, 15 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index dc197eb19..3a86001d5 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -509,7 +509,7 @@ class Connection(Connectable):
return self.execute(object, *multiparams, **params).scalar()
def compiler(self, statement, parameters, **kwargs):
- return self.dialect.compiler(statement, parameters, engine=self.engine, **kwargs)
+ return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
def execute(self, object, *multiparams, **params):
for c in type(object).__mro__:
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 9a0438fc9..9e504c104 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -4,7 +4,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import util, exceptions, sql
+from sqlalchemy import util, exceptions, sql, engine
from sqlalchemy.orm import unitofwork, query
from sqlalchemy.orm.mapper import object_mapper as _object_mapper
from sqlalchemy.orm.mapper import class_mapper as _class_mapper
@@ -30,8 +30,6 @@ class SessionTransaction(object):
def connection(self, mapper_or_class, entity_name=None):
if isinstance(mapper_or_class, type):
mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name)
- if self.parent is not None:
- return self.parent.connection(mapper_or_class)
engine = self.session.get_bind(mapper_or_class)
return self.get_or_add(engine)
@@ -39,28 +37,36 @@ class SessionTransaction(object):
return SessionTransaction(self.session, self)
def add(self, bind):
+ if self.parent is not None:
+ return self.parent.add(bind)
+
if self.connections.has_key(bind.engine):
raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine")
return self.get_or_add(bind)
def get_or_add(self, bind):
- # we reference the 'engine' attribute on the given object, which in the case of
- # Connection, ProxyEngine, Engine, whatever, should return the original
- # "Engine" object that is handling the connection.
- if self.connections.has_key(bind.engine):
- return self.connections[bind.engine][0]
- e = bind.engine
- c = bind.contextual_connect()
- if not self.connections.has_key(e):
- self.connections[e] = (c, c.begin(), c is not bind)
- return self.connections[e][0]
+ if self.parent is not None:
+ return self.parent.get_or_add(bind)
+
+ if self.connections.has_key(bind):
+ return self.connections[bind][0]
+
+ if not isinstance(bind, engine.Connection):
+ e = bind
+ c = bind.contextual_connect()
+ else:
+ e = bind.engine
+ c = bind
+
+ self.connections[bind] = self.connections[e] = (c, c.begin(), c is not bind)
+ return self.connections[bind][0]
def commit(self):
if self.parent is not None:
return
if self.autoflush:
self.session.flush()
- for t in self.connections.values():
+ for t in util.Set(self.connections.values()):
t[1].commit()
self.close()