diff options
Diffstat (limited to 'lib/sqlalchemy/engine/base.py')
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 113 |
1 files changed, 59 insertions, 54 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index cf0d35035..c154a1d68 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -255,6 +255,9 @@ class Dialect(sql.AbstractDialect): class ExecutionContext(object): """A messenger object for a Dialect that corresponds to a single execution. + ExecutionContext should have a datamember "cursor" which is created + at initialization time. + The Dialect should provide an ExecutionContext via the create_execution_context() method. The `pre_exec` and `post_exec` methods will be called for compiled statements, afterwhich it is @@ -263,7 +266,7 @@ class ExecutionContext(object): applicable. """ - def pre_exec(self, engine, proxy, compiled, parameters): + def pre_exec(self): """Called before an execution of a compiled statement. `proxy` is a callable that takes a string statement and a bind @@ -272,7 +275,7 @@ class ExecutionContext(object): raise NotImplementedError() - def post_exec(self, engine, proxy, compiled, parameters): + def post_exec(self): """Called after the execution of a compiled statement. `proxy` is a callable that takes a string statement and a bind @@ -281,7 +284,11 @@ class ExecutionContext(object): raise NotImplementedError() - def get_rowcount(self, cursor): + def get_result_proxy(self): + """return a ResultProxy corresponding to this ExecutionContext.""" + raise NotImplementedError() + + def get_rowcount(self): """Return the count of rows updated/deleted for an UPDATE/DELETE statement.""" raise NotImplementedError() @@ -497,68 +504,32 @@ class Connection(Connectable): """Execute a sql.Compiled object.""" if not compiled.can_execute: raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled))) - cursor = self.__engine.dialect.create_cursor(self.connection) parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)] if len(parameters) == 1: parameters = parameters[0] - def proxy(statement=None, parameters=None): - if statement is None: - return cursor - - parameters = self.__engine.dialect.convert_compiled_params(parameters) - self._execute_raw(statement, parameters, cursor=cursor, context=context) - return cursor - context = self.__engine.dialect.create_execution_context() - context.pre_exec(self.__engine, proxy, compiled, parameters) - proxy(unicode(compiled), parameters) - context.post_exec(self.__engine, proxy, compiled, parameters) - rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor) - return ResultProxy(self.__engine, self, cursor, context, typemap=compiled.typemap, column_labels=compiled.column_labels, **rpargs) - - # poor man's multimethod/generic function thingy - executors = { - sql._Function : execute_function, - sql.ClauseElement : execute_clauseelement, - sql.ClauseVisitor : execute_compiled, - schema.SchemaItem:execute_default, - str.__mro__[-2] : execute_text - } - - def create(self, entity, **kwargs): - """Create a table or index given an appropriate schema object.""" - - return self.__engine.create(entity, connection=self, **kwargs) - - def drop(self, entity, **kwargs): - """Drop a table or index given an appropriate schema object.""" - - return self.__engine.drop(entity, connection=self, **kwargs) - - def reflecttable(self, table, **kwargs): - """Reflect the columns in the given table from the database.""" - - return self.__engine.reflecttable(table, connection=self, **kwargs) - - def default_schema_name(self): - return self.__engine.dialect.get_default_schema_name(self) - - def run_callable(self, callable_): - return callable_(self) - - def _execute_raw(self, statement, parameters=None, cursor=None, context=None, **kwargs): - if cursor is None: - cursor = self.__engine.dialect.create_cursor(self.connection) + context = self.__engine.dialect.create_execution_context(compiled=compiled, parameters=parameters, connection=self, engine=self.__engine) + context.pre_exec() + self.execute_compiled_impl(compiled, parameters, context) + context.post_exec() + return context.get_result_proxy() + + def _execute_compiled_impl(self, compiled, parameters, context): + self._execute_raw(unicode(compiled), self.__engine.dialect.convert_compiled_params(parameters), context=context) + + def _execute_raw(self, statement, parameters=None, context=None, **kwargs): if not self.__engine.dialect.supports_unicode_statements(): # encode to ascii, with full error handling statement = statement.encode('ascii') + if context is None: + context = self.__engine.dialect.create_execution_context(statement=statement, parameters=parameters, connection=self, engine=self.__engine) self.__engine.logger.info(statement) self.__engine.logger.info(repr(parameters)) if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)): - self._executemany(cursor, statement, parameters, context=context) + self._executemany(context.cursor, statement, parameters, context=context) else: - self._execute(cursor, statement, parameters, context=context) + self._execute(context.cursor, statement, parameters, context=context) self._autocommit(statement) - return cursor + return context.cursor def _execute(self, c, statement, parameters, context=None): if parameters is None: @@ -585,6 +556,40 @@ class Connection(Connectable): self.close() raise exceptions.SQLError(statement, parameters, e) + + + + # poor man's multimethod/generic function thingy + executors = { + sql._Function : execute_function, + sql.ClauseElement : execute_clauseelement, + sql.ClauseVisitor : execute_compiled, + schema.SchemaItem:execute_default, + str.__mro__[-2] : execute_text + } + + def create(self, entity, **kwargs): + """Create a table or index given an appropriate schema object.""" + + return self.__engine.create(entity, connection=self, **kwargs) + + def drop(self, entity, **kwargs): + """Drop a table or index given an appropriate schema object.""" + + return self.__engine.drop(entity, connection=self, **kwargs) + + def reflecttable(self, table, **kwargs): + """Reflect the columns in the given table from the database.""" + + return self.__engine.reflecttable(table, connection=self, **kwargs) + + def default_schema_name(self): + return self.__engine.dialect.get_default_schema_name(self) + + def run_callable(self, callable_): + return callable_(self) + + def proxy(self, statement=None, parameters=None): """Execute the given statement string and parameter object. |