diff options
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 147 |
1 files changed, 82 insertions, 65 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 95f6566e3..962e2ab60 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -4,25 +4,13 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Provide default implementations of per-dialect sqlalchemy.engine classes""" -from sqlalchemy import schema, exceptions, util, sql, types -import StringIO, sys, re +from sqlalchemy import schema, exceptions, sql, types +import sys, re from sqlalchemy.engine import base -"""Provide default implementations of the engine interfaces""" -class PoolConnectionProvider(base.ConnectionProvider): - def __init__(self, url, pool): - self.url = url - self._pool = pool - - def get_connection(self): - return self._pool.connect() - - def dispose(self): - self._pool.dispose() - self._pool = self._pool.recreate() - class DefaultDialect(base.Dialect): """Default implementation of Dialect""" @@ -33,7 +21,18 @@ class DefaultDialect(base.Dialect): self._ischema = None self.dbapi = dbapi self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle) - + + def decode_result_columnname(self, name): + """decode a name found in cursor.description to a unicode object.""" + + return name.decode(self.encoding) + + def dbapi_type_map(self): + # most DBAPIs have problems with this (such as, psycocpg2 types + # are unhashable). So far Oracle can return it. + + return {} + def create_execution_context(self, **kwargs): return DefaultExecutionContext(self, **kwargs) @@ -88,6 +87,15 @@ class DefaultDialect(base.Dialect): #print "ENGINE COMMIT ON ", connection.connection connection.commit() + + def do_savepoint(self, connection, name): + connection.execute(sql.SavepointClause(name)) + + def do_rollback_to_savepoint(self, connection, name): + connection.execute(sql.RollbackToSavepointClause(name)) + + def do_release_savepoint(self, connection, name): + connection.execute(sql.ReleaseSavepointClause(name)) def do_executemany(self, cursor, statement, parameters, **kwargs): cursor.executemany(statement, parameters) @@ -95,8 +103,8 @@ class DefaultDialect(base.Dialect): def do_execute(self, cursor, statement, parameters, **kwargs): cursor.execute(statement, parameters) - def defaultrunner(self, connection): - return base.DefaultRunner(connection) + def defaultrunner(self, context): + return base.DefaultRunner(context) def is_disconnect(self, e): return False @@ -107,23 +115,6 @@ class DefaultDialect(base.Dialect): paramstyle = property(lambda s:s._paramstyle, _set_paramstyle) - def convert_compiled_params(self, parameters): - executemany = parameters is not None and isinstance(parameters, list) - # the bind params are a CompiledParams object. but all the DBAPI's hate - # that object (or similar). so convert it to a clean - # dictionary/list/tuple of dictionary/tuple of list - if parameters is not None: - if self.positional: - if executemany: - parameters = [p.get_raw_list() for p in parameters] - else: - parameters = parameters.get_raw_list() - else: - if executemany: - parameters = [p.get_raw_dict() for p in parameters] - else: - parameters = parameters.get_raw_dict() - return parameters def _figure_paramstyle(self, paramstyle=None, default='named'): if paramstyle is not None: @@ -152,29 +143,38 @@ class DefaultDialect(base.Dialect): ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""") class DefaultExecutionContext(base.ExecutionContext): - def __init__(self, dialect, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None): + def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): self.dialect = dialect self.connection = connection self.compiled = compiled - self.compiled_parameters = compiled_parameters if compiled is not None: self.typemap = compiled.typemap self.column_labels = compiled.column_labels self.statement = unicode(compiled) - else: + if parameters is None: + self.compiled_parameters = compiled.construct_params({}) + elif not isinstance(parameters, (list, tuple)): + self.compiled_parameters = compiled.construct_params(parameters) + else: + self.compiled_parameters = [compiled.construct_params(m or {}) for m in parameters] + if len(self.compiled_parameters) == 1: + self.compiled_parameters = self.compiled_parameters[0] + elif statement is not None: self.typemap = self.column_labels = None - self.parameters = self._encode_param_keys(parameters) + self.parameters = self.__encode_param_keys(parameters) self.statement = statement - - if not dialect.supports_unicode_statements(): + else: + self.statement = None + + if self.statement is not None and not dialect.supports_unicode_statements(): self.statement = self.statement.encode(self.dialect.encoding) self.cursor = self.create_cursor() engine = property(lambda s:s.connection.engine) - def _encode_param_keys(self, params): + def __encode_param_keys(self, params): """apply string encoding to the keys of dictionary-based bind parameters""" if self.dialect.positional or self.dialect.supports_unicode_statements(): return params @@ -189,16 +189,46 @@ class DefaultExecutionContext(base.ExecutionContext): return [proc(d) for d in params] else: return proc(params) + + def __convert_compiled_params(self, parameters): + executemany = parameters is not None and isinstance(parameters, list) + encode = not self.dialect.supports_unicode_statements() + # the bind params are a CompiledParams object. but all the DBAPI's hate + # that object (or similar). so convert it to a clean + # dictionary/list/tuple of dictionary/tuple of list + if parameters is not None: + if self.dialect.positional: + if executemany: + parameters = [p.get_raw_list() for p in parameters] + else: + parameters = parameters.get_raw_list() + else: + if executemany: + parameters = [p.get_raw_dict(encode_keys=encode) for p in parameters] + else: + parameters = parameters.get_raw_dict(encode_keys=encode) + return parameters def is_select(self): - return re.match(r'SELECT', self.statement.lstrip(), re.I) + """return TRUE if the statement is expected to have result rows.""" + + return re.match(r'SELECT', self.statement.lstrip(), re.I) is not None def create_cursor(self): return self.connection.connection.cursor() - + + def pre_execution(self): + self.pre_exec() + + def post_execution(self): + self.post_exec() + + def result(self): + return self.get_result_proxy() + def pre_exec(self): self._process_defaults() - self.parameters = self._encode_param_keys(self.dialect.convert_compiled_params(self.compiled_parameters)) + self.parameters = self.__convert_compiled_params(self.compiled_parameters) def post_exec(self): pass @@ -241,7 +271,7 @@ class DefaultExecutionContext(base.ExecutionContext): inputsizes = [] for params in plist[0:1]: for key in params.positional: - typeengine = params.binds[key].type + typeengine = params.get_type(key) dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) if dbtype is not None: inputsizes.append(dbtype) @@ -250,36 +280,23 @@ class DefaultExecutionContext(base.ExecutionContext): inputsizes = {} for params in plist[0:1]: for key in params.keys(): - typeengine = params.binds[key].type + typeengine = params.get_type(key) dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) if dbtype is not None: inputsizes[key] = dbtype self.cursor.setinputsizes(**inputsizes) def _process_defaults(self): - """``INSERT`` and ``UPDATE`` statements, when compiled, may - have additional columns added to their ``VALUES`` and ``SET`` - lists corresponding to column defaults/onupdates that are - present on the ``Table`` object (i.e. ``ColumnDefault``, - ``Sequence``, ``PassiveDefault``). This method pre-execs - those ``DefaultGenerator`` objects that require pre-execution - and sets their values within the parameter list, and flags this - ExecutionContext about ``PassiveDefault`` objects that may - require post-fetching the row after it is inserted/updated. - - This method relies upon logic within the ``ANSISQLCompiler`` - in its `visit_insert` and `visit_update` methods that add the - appropriate column clauses to the statement when its being - compiled, so that these parameters can be bound to the - statement. - """ + """generate default values for compiled insert/update statements, + and generate last_inserted_ids() collection.""" + # TODO: cleanup if self.compiled.isinsert: if isinstance(self.compiled_parameters, list): plist = self.compiled_parameters else: plist = [self.compiled_parameters] - drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection)) + drunner = self.dialect.defaultrunner(self) self._lastrow_has_defaults = False for param in plist: last_inserted_ids = [] @@ -319,7 +336,7 @@ class DefaultExecutionContext(base.ExecutionContext): plist = self.compiled_parameters else: plist = [self.compiled_parameters] - drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection)) + drunner = self.dialect.defaultrunner(self) self._lastrow_has_defaults = False for param in plist: # check the "onupdate" status of each column in the table |