summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/engine.py')
-rw-r--r--lib/sqlalchemy/engine.py84
1 files changed, 59 insertions, 25 deletions
diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py
index 977b4c427..22ccbd11c 100644
--- a/lib/sqlalchemy/engine.py
+++ b/lib/sqlalchemy/engine.py
@@ -77,6 +77,24 @@ class SchemaIterator(schema.SchemaVisitor):
finally:
self.buffer.truncate(0)
+class DefaultRunner(schema.SchemaVisitor):
+ def __init__(self, proxy):
+ self.proxy = proxy
+
+ def visit_sequence(self, seq):
+ """sequences are not supported by default"""
+ return None
+
+ def visit_column_default(self, default):
+ if isinstance(default.arg, ClauseElement):
+ c = default.arg.compile()
+ return proxy.execute(str(c), c.get_params())
+ elif callable(default.arg):
+ return default.arg()
+ else:
+ return default.arg
+
+
class SQLEngine(schema.SchemaEngine):
"""base class for a series of database-specific engines. serves as an abstract factory
for implementation objects as well as database connections, transactions, SQL generators,
@@ -112,6 +130,9 @@ class SQLEngine(schema.SchemaEngine):
def schemadropper(self, proxy, **params):
raise NotImplementedError()
+ def defaultrunner(self, proxy):
+ return DefaultRunner(proxy)
+
def compiler(self, statement, bindparams):
raise NotImplementedError()
@@ -242,11 +263,11 @@ class SQLEngine(schema.SchemaEngine):
self.context.transaction = None
self.context.tcount = None
-
- def _process_sequences(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs):
+ def _process_defaults(self, proxy, statement, parameters, compiled=None, **kwargs):
if compiled is None: return
if getattr(compiled, "isinsert", False):
- if isinstance(parameters, list):
+ # TODO: this sucks. we have to get "parameters" to be a more standardized object
+ if isinstance(parameters, list) and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
plist = parameters
else:
plist = [parameters]
@@ -254,13 +275,20 @@ class SQLEngine(schema.SchemaEngine):
# it will calculate last_inserted_ids for just the last row in the list.
# TODO: why not make last_inserted_ids a 2D array since we have to explicitly sequence
# it or post-select anyway
+ drunner = self.defaultrunner(proxy)
for param in plist:
+ # the parameters might be positional, so convert them
+ # to a dictionary
+ # TODO: this is stupid. or, is this stupid ?
+ # any way we can just have an OrderedDict so we have the
+ # dictionary + postional version each time ?
+ param = compiled.get_named_params(param)
last_inserted_ids = []
need_lastrowid=False
for c in compiled.statement.table.c:
if not param.has_key(c.key) or param[c.key] is None:
- if c.sequence is not None:
- newid = self.exec_sequence(c.sequence)
+ if c.default is not None:
+ newid = c.default.accept_visitor(drunner)
else:
newid = None
@@ -276,19 +304,20 @@ class SQLEngine(schema.SchemaEngine):
self.context.last_inserted_ids = None
else:
self.context.last_inserted_ids = last_inserted_ids
-
- def pre_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs):
+
+
+ def pre_exec(self, proxy, statement, parameters, **kwargs):
pass
- def post_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs):
+ def post_exec(self, proxy, statement, parameters, **kwargs):
pass
- def execute(self, statement, parameters, connection = None, echo = None, typemap = None, commit=False, **kwargs):
- """executes the given string-based SQL statement with the given parameters. This is
- a direct interface to a DBAPI connection object. The parameters may be a dictionary,
- or an array of dictionaries. If an array of dictionaries is sent, executemany will
- be performed on the cursor instead of execute.
+ def execute(self, statement, parameters, connection=None, cursor=None, echo = None, typemap = None, commit=False, **kwargs):
+ """executes the given string-based SQL statement with the given parameters.
+ The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending
+ on the paramstyle of the DBAPI.
+
If the current thread has specified a transaction begin() for this engine, the
statement will be executed in the context of the current transactional connection.
Otherwise, a commit() will be performed immediately after execution, since the local
@@ -316,29 +345,34 @@ class SQLEngine(schema.SchemaEngine):
if connection is None:
connection = self.connection()
- c = connection.cursor()
- else:
- c = connection.cursor()
- try:
- self.pre_exec(connection, c, statement, parameters, echo = echo, **kwargs)
- #self._process_sequences(connection, c, statement, parameters, echo = echo, **kwargs)
-
+ if cursor is None:
+ cursor = connection.cursor()
+
+ def proxy(statement=None, parameters=None):
+ if statement is None:
+ return cursor
if echo is True or self.echo is not False:
self.log(statement)
self.log(repr(parameters))
- if isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
- self._executemany(c, statement, parameters)
+ if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
+ self._executemany(cursor, statement, parameters)
else:
- self._execute(c, statement, parameters)
- self.post_exec(connection, c, statement, parameters, echo = echo, **kwargs)
+ self._execute(cursor, statement, parameters)
+ return cursor
+
+ try:
+ self.pre_exec(proxy, statement, parameters, **kwargs)
+ self._process_defaults(proxy, statement, parameters, **kwargs)
+ proxy(statement, parameters)
+ self.post_exec(proxy, statement, parameters, **kwargs)
if commit or self.context.transaction is None:
self.do_commit(connection)
except:
self.do_rollback(connection)
# TODO: wrap DB exceptions ?
raise
- return ResultProxy(c, self, typemap = typemap)
+ return ResultProxy(cursor, self, typemap = typemap)
def _execute(self, c, statement, parameters):
c.execute(statement, parameters)