summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/databases/oracle.py6
-rw-r--r--lib/sqlalchemy/databases/postgres.py8
-rw-r--r--lib/sqlalchemy/engine/base.py56
-rw-r--r--test/sql/constraints.py4
4 files changed, 41 insertions, 33 deletions
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index fb5b512e2..5b852c185 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -672,12 +672,8 @@ class OracleSchemaDropper(compiler.SchemaDropper):
self.execute()
class OracleDefaultRunner(base.DefaultRunner):
- def exec_default_sql(self, default):
- c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection)
- return self.connection.execute(c).scalar()
-
def visit_sequence(self, seq):
- return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar()
+ return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL")
class OracleIdentifierPreparer(compiler.IdentifierPreparer):
def format_savepoint(self, savepoint):
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index 701114b17..f5eaf47f2 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -611,9 +611,9 @@ class PGSchemaDropper(compiler.SchemaDropper):
class PGDefaultRunner(base.DefaultRunner):
def get_column_default(self, column, isinsert=True):
if column.primary_key:
- # passive defaults on primary keys have to be overridden
+ # pre-execute passive defaults on primary keys
if isinstance(column.default, schema.PassiveDefault):
- return self.connection.execute("select %s" % column.default.arg).scalar()
+ return self.execute_string("select %s" % column.default.arg)
elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
sch = column.table.schema
# TODO: this has to build into the Sequence object so we can get the quoting
@@ -622,13 +622,13 @@ class PGDefaultRunner(base.DefaultRunner):
exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
else:
exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
- return self.connection.execute(exc).scalar()
+ return self.execute_string(exc.encode(self.dialect.encoding))
return super(PGDefaultRunner, self).get_column_default(column)
def visit_sequence(self, seq):
if not seq.optional:
- return self.connection.execute("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).scalar()
+ return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).encode(self.dialect.encoding))
else:
return None
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 1ab05fe03..45d84f90d 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -837,45 +837,50 @@ class Connection(Connectable):
return self.__engine.dialect.create_execution_context(connection=self, **kwargs)
def __execute_raw(self, context):
- if self.__engine._should_log_info:
- self.__engine.logger.info(context.statement)
- self.__engine.logger.info(repr(context.parameters))
if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and isinstance(context.parameters[0], (list, tuple, dict)):
- self.__executemany(context)
+ self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context)
else:
- self.__execute(context)
+ if context.parameters is None:
+ if context.dialect.positional:
+ parameters = ()
+ else:
+ parameters = {}
+ else:
+ parameters = context.parameters
+ self._cursor_execute(context.cursor, context.statement, parameters, context=context)
self._autocommit(context)
- def __execute(self, context):
- if context.parameters is None:
- if context.dialect.positional:
- context.parameters = ()
- else:
- context.parameters = {}
+ def _cursor_execute(self, cursor, statement, parameters, context=None):
+ if self.__engine._should_log_info:
+ self.__engine.logger.info(statement)
+ self.__engine.logger.info(repr(parameters))
try:
- context.dialect.do_execute(context.cursor, context.statement, context.parameters, context=context)
+ self.dialect.do_execute(cursor, statement, parameters, context=context)
except Exception, e:
if self.dialect.is_disconnect(e):
self.__connection.invalidate(e=e)
self.engine.dispose()
- context.cursor.close()
+ cursor.close()
self._autorollback()
if self.__close_with_result:
self.close()
- raise exceptions.DBAPIError.instance(context.statement, context.parameters, e)
+ raise exceptions.DBAPIError.instance(statement, parameters, e)
- def __executemany(self, context):
+ def _cursor_executemany(self, cursor, statement, parameters, context=None):
+ if self.__engine._should_log_info:
+ self.__engine.logger.info(statement)
+ self.__engine.logger.info(repr(parameters))
try:
- context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context)
+ self.dialect.do_executemany(cursor, statement, parameters, context=context)
except Exception, e:
if self.dialect.is_disconnect(e):
self.__connection.invalidate(e=e)
self.engine.dispose()
- context.cursor.close()
+ cursor.close()
self._autorollback()
if self.__close_with_result:
self.close()
- raise exceptions.DBAPIError.instance(context.statement, context.parameters, e)
+ raise exceptions.DBAPIError.instance(statement, parameters, e)
# poor man's multimethod/generic function thingy
executors = {
@@ -1632,7 +1637,6 @@ class DefaultRunner(schema.SchemaVisitor):
def __init__(self, context):
self.context = context
- self.connection = context._connection._branch()
self.dialect = context.dialect
def get_column_default(self, column):
@@ -1665,9 +1669,17 @@ class DefaultRunner(schema.SchemaVisitor):
return None
def exec_default_sql(self, default):
- c = expression.select([default.arg]).compile(bind=self.connection)
- return self.connection._execute_compiled(c).scalar()
-
+ conn = self.context.connection
+ c = expression.select([default.arg]).compile(bind=conn)
+ return conn._execute_compiled(c).scalar()
+
+ def execute_string(self, stmt, params=None):
+ """execute a string statement, using the raw cursor,
+ and return a scalar result."""
+ conn = self.context._connection
+ conn._cursor_execute(self.context.cursor, stmt, params)
+ return self.context.cursor.fetchone()[0]
+
def visit_column_onupdate(self, onupdate):
if isinstance(onupdate.arg, expression.ClauseElement):
return self.exec_default_sql(onupdate)
diff --git a/test/sql/constraints.py b/test/sql/constraints.py
index 93ba231ab..c5320ada3 100644
--- a/test/sql/constraints.py
+++ b/test/sql/constraints.py
@@ -174,12 +174,12 @@ class ConstraintTest(AssertMixin):
capt = []
connection = testbase.db.connect()
# TODO: hacky, put a real connection proxy in
- ex = connection._Connection__execute
+ ex = connection._Connection__execute_raw
def proxy(context):
capt.append(context.statement)
capt.append(repr(context.parameters))
ex(context)
- connection._Connection__execute = proxy
+ connection._Connection__execute_raw = proxy
schemagen = testbase.db.dialect.schemagenerator(testbase.db.dialect, connection)
schemagen.traverse(events)