diff options
Diffstat (limited to 'lib/sqlalchemy/engine.py')
-rw-r--r-- | lib/sqlalchemy/engine.py | 129 |
1 files changed, 94 insertions, 35 deletions
diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 22ccbd11c..81d72b17b 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -78,17 +78,21 @@ class SchemaIterator(schema.SchemaVisitor): self.buffer.truncate(0) class DefaultRunner(schema.SchemaVisitor): - def __init__(self, proxy): + def __init__(self, engine, proxy): self.proxy = proxy + self.engine = engine def visit_sequence(self, seq): """sequences are not supported by default""" return None + def exec_default_sql(self, default): + c = sql.select([default.arg], engine=self.engine).compile() + return self.proxy(str(c), c.get_params()).fetchone()[0] + def visit_column_default(self, default): - if isinstance(default.arg, ClauseElement): - c = default.arg.compile() - return proxy.execute(str(c), c.get_params()) + if isinstance(default.arg, sql.ClauseElement): + return self.exec_default_sql(default) elif callable(default.arg): return default.arg() else: @@ -113,11 +117,29 @@ class SQLEngine(schema.SchemaEngine): self.context = util.ThreadLocal(raiseerror=False) self.tables = {} self.notes = {} + self._figure_paramstyle() if logger is None: self.logger = sys.stdout else: self.logger = logger - + + def _figure_paramstyle(self): + db = self.dbapi() + if db is not None: + self.paramstyle = db.paramstyle + else: + self.paramstyle = 'named' + + if self.paramstyle == 'named': + self.bindtemplate = ':%s' + self.positional=False + elif self.paramstyle =='pyformat': + self.bindtemplate = "%%(%s)s" + self.positional=False + else: + # for positional, use pyformat until the end + self.bindtemplate = "%%(%s)s" + self.positional=True def type_descriptor(self, typeobj): if type(typeobj) is type: @@ -131,9 +153,9 @@ class SQLEngine(schema.SchemaEngine): raise NotImplementedError() def defaultrunner(self, proxy): - return DefaultRunner(proxy) + return DefaultRunner(self, proxy) - def compiler(self, statement, bindparams): + def compiler(self, statement, parameters): raise NotImplementedError() def rowid_column_name(self): @@ -152,11 +174,11 @@ class SQLEngine(schema.SchemaEngine): """drops a table given a schema.Table object.""" table.accept_visitor(self.schemadropper(self.proxy(), **params)) - def compile(self, statement, bindparams, **kwargs): + def compile(self, statement, parameters, **kwargs): """given a sql.ClauseElement statement plus optional bind parameters, creates a new instance of this engine's SQLCompiler, compiles the ClauseElement, and returns the newly compiled object.""" - compiler = self.compiler(statement, bindparams, **kwargs) + compiler = self.compiler(statement, parameters, **kwargs) statement.accept_visitor(compiler) compiler.after_compile() return compiler @@ -263,26 +285,15 @@ class SQLEngine(schema.SchemaEngine): self.context.transaction = None self.context.tcount = None - def _process_defaults(self, proxy, statement, parameters, compiled=None, **kwargs): + def _process_defaults(self, proxy, compiled, parameters, **kwargs): if compiled is None: return if getattr(compiled, "isinsert", False): - # TODO: this sucks. we have to get "parameters" to be a more standardized object - if isinstance(parameters, list) and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)): + if isinstance(parameters, list): plist = parameters else: plist = [parameters] - # inserts are usually one at a time. but if we got a list of parameters, - # it will calculate last_inserted_ids for just the last row in the list. - # TODO: why not make last_inserted_ids a 2D array since we have to explicitly sequence - # it or post-select anyway drunner = self.defaultrunner(proxy) for param in plist: - # the parameters might be positional, so convert them - # to a dictionary - # TODO: this is stupid. or, is this stupid ? - # any way we can just have an OrderedDict so we have the - # dictionary + postional version each time ? - param = compiled.get_named_params(param) last_inserted_ids = [] need_lastrowid=False for c in compiled.statement.table.c: @@ -306,18 +317,18 @@ class SQLEngine(schema.SchemaEngine): self.context.last_inserted_ids = last_inserted_ids - def pre_exec(self, proxy, statement, parameters, **kwargs): + def pre_exec(self, proxy, compiled, parameters, **kwargs): pass - def post_exec(self, proxy, statement, parameters, **kwargs): + def post_exec(self, proxy, compiled, parameters, **kwargs): pass - def execute(self, statement, parameters, connection=None, cursor=None, echo = None, typemap = None, commit=False, **kwargs): + def execute_compiled(self, compiled, parameters, connection=None, cursor=None, echo=None, **kwargs): """executes the given string-based SQL statement with the given parameters. The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending on the paramstyle of the DBAPI. - + If the current thread has specified a transaction begin() for this engine, the statement will be executed in the context of the current transactional connection. Otherwise, a commit() will be performed immediately after execution, since the local @@ -352,6 +363,62 @@ class SQLEngine(schema.SchemaEngine): def proxy(statement=None, parameters=None): if statement is None: return cursor + + executemany = parameters is not None and isinstance(parameters, list) + + if self.positional: + if executemany: + parameters = [p.values() for p in parameters] + else: + parameters = parameters.values() + + self.execute(statement, parameters, connection=connection, cursor=cursor) + return cursor + + self.pre_exec(proxy, compiled, parameters, **kwargs) + self._process_defaults(proxy, compiled, parameters, **kwargs) + proxy(str(compiled), parameters) + self.post_exec(proxy, compiled, parameters, **kwargs) + return ResultProxy(cursor, self, typemap=compiled.typemap) + + def execute(self, statement, parameters, connection=None, cursor=None, echo = None, typemap = None, commit=False, **kwargs): + """executes the given string-based SQL statement with the given parameters. + + The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending + on the paramstyle of the DBAPI. + + If the current thread has specified a transaction begin() for this engine, the + statement will be executed in the context of the current transactional connection. + Otherwise, a commit() will be performed immediately after execution, since the local + pooled connection is returned to the pool after execution without a transaction set + up. + + In all error cases, a rollback() is immediately performed on the connection before + propigating the exception outwards. + + Other options include: + + connection - a DBAPI connection to use for the execute. If None, a connection is + pulled from this engine's connection pool. + + echo - enables echo for this execution, which causes all SQL and parameters + to be dumped to the engine's logging output before execution. + + typemap - a map of column names mapped to sqlalchemy.types.TypeEngine objects. + These will be passed to the created ResultProxy to perform + post-processing on result-set values. + + commit - if True, will automatically commit the statement after completion. """ + if parameters is None: + parameters = {} + + if connection is None: + connection = self.connection() + + if cursor is None: + cursor = connection.cursor() + + try: if echo is True or self.echo is not False: self.log(statement) self.log(repr(parameters)) @@ -359,18 +426,10 @@ class SQLEngine(schema.SchemaEngine): self._executemany(cursor, statement, parameters) else: self._execute(cursor, statement, parameters) - return cursor - - try: - self.pre_exec(proxy, statement, parameters, **kwargs) - self._process_defaults(proxy, statement, parameters, **kwargs) - proxy(statement, parameters) - self.post_exec(proxy, statement, parameters, **kwargs) - if commit or self.context.transaction is None: + if self.context.transaction is None: self.do_commit(connection) except: self.do_rollback(connection) - # TODO: wrap DB exceptions ? raise return ResultProxy(cursor, self, typemap = typemap) |