summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/firebird/base.py23
-rw-r--r--lib/sqlalchemy/dialects/firebird/kinterbasdb.py1
-rw-r--r--lib/sqlalchemy/dialects/maxdb/base.py15
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py15
-rw-r--r--lib/sqlalchemy/dialects/oracle/cx_oracle.py5
-rw-r--r--lib/sqlalchemy/dialects/oracle/zxjdbc.py4
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py69
-rw-r--r--lib/sqlalchemy/dialects/postgresql/pg8000.py4
-rw-r--r--lib/sqlalchemy/dialects/postgresql/psycopg2.py4
-rw-r--r--lib/sqlalchemy/dialects/postgresql/pypostgresql.py10
-rw-r--r--lib/sqlalchemy/engine/__init__.py2
-rw-r--r--lib/sqlalchemy/engine/base.py73
-rw-r--r--lib/sqlalchemy/engine/default.py67
13 files changed, 125 insertions, 167 deletions
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py
index 5cc837848..4d081025e 100644
--- a/lib/sqlalchemy/dialects/firebird/base.py
+++ b/lib/sqlalchemy/dialects/firebird/base.py
@@ -72,6 +72,7 @@ from sqlalchemy.sql import expression
from sqlalchemy.engine import base, default, reflection
from sqlalchemy.sql import compiler
+
from sqlalchemy.types import (BIGINT, BLOB, BOOLEAN, CHAR, DATE,
FLOAT, INTEGER, NUMERIC, SMALLINT,
TEXT, TIME, TIMESTAMP, VARCHAR)
@@ -176,6 +177,8 @@ class FBTypeCompiler(compiler.GenericTypeCompiler):
return "BLOB SUB_TYPE 0"
+
+
class FBCompiler(sql.compiler.SQLCompiler):
"""Firebird specific idiosincrasies"""
@@ -280,16 +283,6 @@ class FBDDLCompiler(sql.compiler.DDLCompiler):
return "DROP GENERATOR %s" % self.preparer.format_sequence(drop.element)
-class FBDefaultRunner(base.DefaultRunner):
- """Firebird specific idiosincrasies"""
-
- def visit_sequence(self, seq):
- """Get the next value from the sequence using ``gen_id()``."""
-
- return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \
- self.dialect.identifier_preparer.format_sequence(seq))
-
-
class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
"""Install Firebird specific reserved words."""
@@ -298,7 +291,13 @@ class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
def __init__(self, dialect):
super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
+class FBExecutionContext(default.DefaultExecutionContext):
+ def fire_sequence(self, seq):
+ """Get the next value from the sequence using ``gen_id()``."""
+ return self._execute_scalar("SELECT gen_id(%s, 1) FROM rdb$database" % \
+ self.dialect.identifier_preparer.format_sequence(seq))
+
class FBDialect(default.DefaultDialect):
"""Firebird dialect"""
@@ -316,10 +315,10 @@ class FBDialect(default.DefaultDialect):
statement_compiler = FBCompiler
ddl_compiler = FBDDLCompiler
- defaultrunner = FBDefaultRunner
preparer = FBIdentifierPreparer
type_compiler = FBTypeCompiler
-
+ execution_ctx_cls = FBExecutionContext
+
colspecs = colspecs
ischema_names = ischema_names
diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
index 4bbd2aafc..c804af05f 100644
--- a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
+++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
@@ -28,7 +28,6 @@ __ http://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurr
"""
from sqlalchemy.dialects.firebird.base import FBDialect, FBCompiler
-from sqlalchemy.engine.default import DefaultExecutionContext
class Firebird_kinterbasdb(FBDialect):
diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py
index b116d6df6..c02a9e204 100644
--- a/lib/sqlalchemy/dialects/maxdb/base.py
+++ b/lib/sqlalchemy/dialects/maxdb/base.py
@@ -401,6 +401,12 @@ class MaxDBExecutionContext(default.DefaultExecutionContext):
else:
return self.cursor.rowcount
+ def fire_sequence(self, seq):
+ if seq.optional:
+ return None
+ return self._execute_scalar("SELECT %s.NEXTVAL FROM DUAL" % (
+ self.dialect.identifier_preparer.format_sequence(seq)))
+
class MaxDBCachedColumnRow(engine_base.RowProxy):
"""A RowProxy that only runs result_processors once per column."""
@@ -610,14 +616,6 @@ class MaxDBCompiler(compiler.SQLCompiler):
')'))
-class MaxDBDefaultRunner(engine_base.DefaultRunner):
- def visit_sequence(self, seq):
- if seq.optional:
- return None
- return self.execute_string("SELECT %s.NEXTVAL FROM DUAL" % (
- self.dialect.identifier_preparer.format_sequence(seq)))
-
-
class MaxDBIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = set([
'abs', 'absolute', 'acos', 'adddate', 'addtime', 'all', 'alpha',
@@ -805,7 +803,6 @@ class MaxDBDialect(default.DefaultDialect):
preparer = MaxDBIdentifierPreparer
statement_compiler = MaxDBCompiler
ddl_compiler = MaxDBDDLCompiler
- defaultrunner = MaxDBDefaultRunner
execution_ctx_cls = MaxDBExecutionContext
colspecs = colspecs
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index c14151515..4f37412eb 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -454,12 +454,6 @@ class OracleDDLCompiler(compiler.DDLCompiler):
return text
-class OracleDefaultRunner(base.DefaultRunner):
- def visit_sequence(self, seq):
- return self.execute_string("SELECT " +
- self.dialect.identifier_preparer.format_sequence(seq) +
- ".nextval FROM DUAL", ())
-
class OracleIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = set([x.lower() for x in RESERVED_WORDS])
@@ -477,6 +471,13 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer):
name = re.sub(r'^_+', '', savepoint.ident)
return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
+
+class OracleExecutionContext(default.DefaultExecutionContext):
+ def fire_sequence(self, seq):
+ return self._execute_scalar("SELECT " +
+ self.dialect.identifier_preparer.format_sequence(seq) +
+ ".nextval FROM DUAL")
+
class OracleDialect(default.DefaultDialect):
name = 'oracle'
supports_alter = True
@@ -502,7 +503,7 @@ class OracleDialect(default.DefaultDialect):
ddl_compiler = OracleDDLCompiler
type_compiler = OracleTypeCompiler
preparer = OracleIdentifierPreparer
- defaultrunner = OracleDefaultRunner
+ execution_ctx_cls = OracleExecutionContext
reflection_options = ('oracle_resolve_synonyms', )
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index 65f3cb928..6108d3d66 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -74,9 +74,8 @@ with this feature but it should be regarded as experimental.
"""
-from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, RESERVED_WORDS
+from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, RESERVED_WORDS, OracleExecutionContext
from sqlalchemy.dialects.oracle import base as oracle
-from sqlalchemy.engine.default import DefaultExecutionContext
from sqlalchemy.engine import base
from sqlalchemy import types as sqltypes, util
import datetime
@@ -200,7 +199,7 @@ class Oracle_cx_oracleCompiler(OracleCompiler):
else:
return OracleCompiler.bindparam_string(self, name)
-class Oracle_cx_oracleExecutionContext(DefaultExecutionContext):
+class Oracle_cx_oracleExecutionContext(OracleExecutionContext):
def pre_exec(self):
quoted_bind_names = getattr(self.compiled, '_quoted_bind_names', {})
if quoted_bind_names:
diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py
index 82aac080e..6edef301c 100644
--- a/lib/sqlalchemy/dialects/oracle/zxjdbc.py
+++ b/lib/sqlalchemy/dialects/oracle/zxjdbc.py
@@ -12,7 +12,7 @@ import re
from sqlalchemy import sql, types as sqltypes, util
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
-from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect
+from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, OracleExecutionContext
from sqlalchemy.engine import base, default
from sqlalchemy.sql import expression
@@ -71,7 +71,7 @@ class Oracle_zxjdbcCompiler(OracleCompiler):
return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
-class Oracle_zxjdbcExecutionContext(default.DefaultExecutionContext):
+class Oracle_zxjdbcExecutionContext(OracleExecutionContext):
def pre_exec(self):
if hasattr(self.compiled, 'returning_parameters'):
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 19d9224e2..0bc5f08b0 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -332,40 +332,6 @@ class PGDDLCompiler(compiler.DDLCompiler):
return text
-class PGDefaultRunner(base.DefaultRunner):
-
- def get_column_default(self, column, isinsert=True):
- if column.primary_key:
- if (isinstance(column.server_default, schema.DefaultClause) and
- column.server_default.arg is not None):
-
- # pre-execute passive defaults on primary key columns
- return self.execute_string("select %s" % column.server_default.arg)
-
- elif column is column.table._autoincrement_column \
- and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
-
- # execute the sequence associated with a SERIAL primary key column.
- # for non-primary-key SERIAL, the ID just generates server side.
- sch = column.table.schema
-
- if sch is not None:
- exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
- else:
- exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
-
- return self.execute_string(exc)
-
- return super(PGDefaultRunner, self).get_column_default(column)
-
- def visit_sequence(self, seq):
- if not seq.optional:
- return self.execute_string(("select nextval('%s')" % \
- self.dialect.identifier_preparer.format_sequence(seq)))
- else:
- return None
-
-
class PGTypeCompiler(compiler.GenericTypeCompiler):
def visit_INET(self, type_):
return "INET"
@@ -438,6 +404,39 @@ class PGInspector(reflection.Inspector):
info_cache=self.info_cache)
+
+class PGExecutionContext(default.DefaultExecutionContext):
+ def fire_sequence(self, seq):
+ if not seq.optional:
+ return self._execute_scalar(("select nextval('%s')" % \
+ self.dialect.identifier_preparer.format_sequence(seq)))
+ else:
+ return None
+
+ def get_insert_default(self, column):
+ if column.primary_key:
+ if (isinstance(column.server_default, schema.DefaultClause) and
+ column.server_default.arg is not None):
+
+ # pre-execute passive defaults on primary key columns
+ return self._execute_scalar("select %s" % column.server_default.arg)
+
+ elif column is column.table._autoincrement_column \
+ and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+
+ # execute the sequence associated with a SERIAL primary key column.
+ # for non-primary-key SERIAL, the ID just generates server side.
+ sch = column.table.schema
+
+ if sch is not None:
+ exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
+ else:
+ exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
+
+ return self._execute_scalar(exc)
+
+ return super(PGExecutionContext, self).get_insert_default(column)
+
class PGDialect(default.DefaultDialect):
name = 'postgresql'
supports_alter = True
@@ -459,7 +458,7 @@ class PGDialect(default.DefaultDialect):
ddl_compiler = PGDDLCompiler
type_compiler = PGTypeCompiler
preparer = PGIdentifierPreparer
- defaultrunner = PGDefaultRunner
+ execution_ctx_cls = PGExecutionContext
inspector = PGInspector
isolation_level = None
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
index 1c45b50f2..17fe86be6 100644
--- a/lib/sqlalchemy/dialects/postgresql/pg8000.py
+++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py
@@ -23,7 +23,7 @@ from sqlalchemy.engine import default
import decimal
from sqlalchemy import util
from sqlalchemy import types as sqltypes
-from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, PGIdentifierPreparer
+from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext
class _PGNumeric(sqltypes.Numeric):
def bind_processor(self, dialect):
@@ -41,7 +41,7 @@ class _PGNumeric(sqltypes.Numeric):
return process
-class PostgreSQL_pg8000ExecutionContext(default.DefaultExecutionContext):
+class PostgreSQL_pg8000ExecutionContext(PGExecutionContext):
pass
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index a09697e79..50e4bec3b 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -42,7 +42,7 @@ from sqlalchemy.engine import base, default
from sqlalchemy.sql import expression
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import types as sqltypes
-from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, PGIdentifierPreparer
+from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext
class _PGNumeric(sqltypes.Numeric):
def bind_processor(self, dialect):
@@ -65,7 +65,7 @@ SERVER_SIDE_CURSOR_RE = re.compile(
r'\s*SELECT',
re.I | re.UNICODE)
-class PostgreSQL_psycopg2ExecutionContext(default.DefaultExecutionContext):
+class PostgreSQL_psycopg2ExecutionContext(PGExecutionContext):
def create_cursor(self):
# TODO: coverage for server side cursors + select.for_update()
is_server_side = \
diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
index 975006d92..517d41aaf 100644
--- a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
+++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
@@ -11,7 +11,7 @@ from sqlalchemy.engine import default
import decimal
from sqlalchemy import util
from sqlalchemy import types as sqltypes
-from sqlalchemy.dialects.postgresql.base import PGDialect, PGDefaultRunner
+from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext
class PGNumeric(sqltypes.Numeric):
def bind_processor(self, dialect):
@@ -28,13 +28,9 @@ class PGNumeric(sqltypes.Numeric):
return value
return process
-class PostgreSQL_pypostgresqlExecutionContext(default.DefaultExecutionContext):
+class PostgreSQL_pypostgresqlExecutionContext(PGExecutionContext):
pass
-class PostgreSQL_pypostgresqlDefaultRunner(PGDefaultRunner):
- def execute_string(self, stmt, params=None):
- return PGDefaultRunner.execute_string(self, stmt, params or ())
-
class PostgreSQL_pypostgresql(PGDialect):
driver = 'pypostgresql'
@@ -43,8 +39,6 @@ class PostgreSQL_pypostgresql(PGDialect):
supports_unicode_binds = True
description_encoding = None
- defaultrunner = PostgreSQL_pypostgresqlDefaultRunner
-
default_paramstyle = 'format'
supports_sane_rowcount = False # alas....posting a bug now
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
index 694a2f71f..fed28c094 100644
--- a/lib/sqlalchemy/engine/__init__.py
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -60,7 +60,6 @@ from sqlalchemy.engine.base import (
Compiled,
Connectable,
Connection,
- DefaultRunner,
Dialect,
Engine,
ExecutionContext,
@@ -83,7 +82,6 @@ __all__ = (
'Compiled',
'Connectable',
'Connection',
- 'DefaultRunner',
'Dialect',
'Engine',
'ExecutionContext',
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 26e44dd6b..829f97558 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -14,7 +14,7 @@ and result contexts.
__all__ = [
'BufferedColumnResultProxy', 'BufferedColumnRow', 'BufferedRowResultProxy',
- 'Compiled', 'Connectable', 'Connection', 'DefaultRunner', 'Dialect', 'Engine',
+ 'Compiled', 'Connectable', 'Connection', 'Dialect', 'Engine',
'ExecutionContext', 'NestedTransaction', 'ResultProxy', 'RootTransaction',
'RowProxy', 'SchemaIterator', 'StringIO', 'Transaction', 'TwoPhaseTransaction',
'connection_memoize']
@@ -57,10 +57,6 @@ class Dialect(object):
type of encoding to use for unicode, usually defaults to
'utf-8'.
- defaultrunner
- a :class:`~sqlalchemy.schema.SchemaVisitor` class which executes
- defaults.
-
statement_compiler
a :class:`~Compiled` class used to compile SQL statements
@@ -1012,9 +1008,8 @@ class Connection(Connectable):
return self._execute_clauseelement(func.select(), multiparams, params)
def _execute_default(self, default, multiparams, params):
- ret = self.engine.dialect.\
- defaultrunner(self.__create_execution_context()).\
- traverse_single(default)
+ ctx = self.__create_execution_context()
+ ret = ctx._exec_default(default)
if self.__close_with_result:
self.close()
return ret
@@ -2154,68 +2149,6 @@ class BufferedColumnResultProxy(ResultProxy):
return l
-class DefaultRunner(schema.SchemaVisitor):
- """A visitor which accepts ColumnDefault objects, produces the
- dialect-specific SQL corresponding to their execution, and
- executes the SQL, returning the result value.
-
- DefaultRunners are used internally by Engines and Dialects.
- Specific database modules should provide their own subclasses of
- DefaultRunner to allow database-specific behavior.
- """
-
- def __init__(self, context):
- self.context = context
- self.dialect = context.dialect
- self.cursor = context.cursor
-
- def get_column_default(self, column):
- if column.default is not None:
- return self.traverse_single(column.default)
- else:
- return None
-
- def get_column_onupdate(self, column):
- if column.onupdate is not None:
- return self.traverse_single(column.onupdate)
- else:
- return None
-
- def visit_passive_default(self, default):
- return None
-
- def visit_sequence(self, seq):
- return None
-
- def exec_default_sql(self, default):
- conn = self.context.connection
- c = expression.select([default.arg]).compile(bind=conn)
- return conn._execute_compiled(c, (), {}).scalar()
-
- def execute_string(self, stmt, params=None):
- """execute a string statement, using the raw cursor, and return a scalar result."""
-
- conn = self.context._connection
- if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements:
- stmt = stmt.encode(self.dialect.encoding)
- conn._cursor_execute(self.cursor, stmt, params)
- return self.cursor.fetchone()[0]
-
- def visit_column_onupdate(self, onupdate):
- if isinstance(onupdate.arg, expression.ClauseElement):
- return self.exec_default_sql(onupdate)
- elif util.callable(onupdate.arg):
- return onupdate.arg(self.context)
- else:
- return onupdate.arg
-
- def visit_column_default(self, default):
- if isinstance(default.arg, expression.ClauseElement):
- return self.exec_default_sql(default)
- elif util.callable(default.arg):
- return default.arg(self.context)
- else:
- return default.arg
def connection_memoize(key):
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index e062abce6..12ab605e4 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -28,7 +28,6 @@ class DefaultDialect(base.Dialect):
ddl_compiler = compiler.DDLCompiler
type_compiler = compiler.GenericTypeCompiler
preparer = compiler.IdentifierPreparer
- defaultrunner = base.DefaultRunner
supports_alter = True
supports_sequences = False
@@ -198,10 +197,7 @@ class DefaultExecutionContext(base.ExecutionContext):
self.result_map = None
self.cursor = self.create_cursor()
self.compiled_parameters = []
- if self.dialect.positional:
- self.parameters = [()]
- else:
- self.parameters = [{}]
+ self.parameters = [self._default_params]
elif compiled_sql is not None:
self.compiled = compiled = compiled_sql
@@ -273,6 +269,27 @@ class DefaultExecutionContext(base.ExecutionContext):
bool(self.compiled.returning) and \
not self.compiled.statement._returning
+ @util.memoized_property
+ def _default_params(self):
+ if self.dialect.positional:
+ return ()
+ else:
+ return {}
+
+ def _execute_scalar(self, stmt):
+ """Execute a string statement on the current cursor, returning a scalar result.
+
+ Used to fire off sequences, default phrases, and "select lastrowid" types of statements individually
+ or in the context of a parent INSERT or UPDATE statement.
+
+ """
+
+ conn = self._connection
+ if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements:
+ stmt = stmt.encode(self.dialect.encoding)
+ conn._cursor_execute(self.cursor, stmt, self._default_params)
+ return self.cursor.fetchone()[0]
+
@property
def connection(self):
return self._connection._branch()
@@ -286,10 +303,8 @@ class DefaultExecutionContext(base.ExecutionContext):
if self.dialect.positional or self.dialect.supports_unicode_statements:
if params:
return params
- elif self.dialect.positional:
- return [()]
else:
- return [{}]
+ return [self._default_params]
else:
def proc(d):
# sigh, sometimes we get positional arguments with a dialect
@@ -460,6 +475,32 @@ class DefaultExecutionContext(base.ExecutionContext):
self._connection._handle_dbapi_exception(e, None, None, None, self)
raise
+ def _exec_default(self, default):
+ if default.is_sequence:
+ return self.fire_sequence(default)
+ elif default.is_callable:
+ return default.arg(self)
+ elif default.is_clause_element:
+ # TODO: expensive branching here should be
+ # pulled into _exec_scalar()
+ conn = self.connection
+ c = expression.select([default.arg]).compile(bind=conn)
+ return conn._execute_compiled(c, (), {}).scalar()
+ else:
+ return default.arg
+
+ def get_insert_default(self, column):
+ if column.default is None:
+ return None
+ else:
+ return self._exec_default(column.default)
+
+ def get_update_default(self, column):
+ if column.onupdate is None:
+ return None
+ else:
+ return self._exec_default(column.onupdate)
+
def __process_defaults(self):
"""Generate default values for compiled insert/update statements,
and generate inserted_primary_key collection.
@@ -467,28 +508,26 @@ class DefaultExecutionContext(base.ExecutionContext):
if self.executemany:
if len(self.compiled.prefetch):
- drunner = self.dialect.defaultrunner(self)
params = self.compiled_parameters
for param in params:
self.current_parameters = param
for c in self.compiled.prefetch:
if self.isinsert:
- val = drunner.get_column_default(c)
+ val = self.get_insert_default(c)
else:
- val = drunner.get_column_onupdate(c)
+ val = self.get_update_default(c)
if val is not None:
param[c.key] = val
del self.current_parameters
else:
self.current_parameters = compiled_parameters = self.compiled_parameters[0]
- drunner = self.dialect.defaultrunner(self)
for c in self.compiled.prefetch:
if self.isinsert:
- val = drunner.get_column_default(c)
+ val = self.get_insert_default(c)
else:
- val = drunner.get_column_onupdate(c)
+ val = self.get_update_default(c)
if val is not None:
compiled_parameters[c.key] = val