summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/default.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r--lib/sqlalchemy/engine/default.py147
1 files changed, 82 insertions, 65 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 95f6566e3..962e2ab60 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -4,25 +4,13 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Provide default implementations of per-dialect sqlalchemy.engine classes"""
-from sqlalchemy import schema, exceptions, util, sql, types
-import StringIO, sys, re
+from sqlalchemy import schema, exceptions, sql, types
+import sys, re
from sqlalchemy.engine import base
-"""Provide default implementations of the engine interfaces"""
-class PoolConnectionProvider(base.ConnectionProvider):
- def __init__(self, url, pool):
- self.url = url
- self._pool = pool
-
- def get_connection(self):
- return self._pool.connect()
-
- def dispose(self):
- self._pool.dispose()
- self._pool = self._pool.recreate()
-
class DefaultDialect(base.Dialect):
"""Default implementation of Dialect"""
@@ -33,7 +21,18 @@ class DefaultDialect(base.Dialect):
self._ischema = None
self.dbapi = dbapi
self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle)
-
+
+ def decode_result_columnname(self, name):
+ """decode a name found in cursor.description to a unicode object."""
+
+ return name.decode(self.encoding)
+
+ def dbapi_type_map(self):
+ # most DBAPIs have problems with this (such as, psycocpg2 types
+ # are unhashable). So far Oracle can return it.
+
+ return {}
+
def create_execution_context(self, **kwargs):
return DefaultExecutionContext(self, **kwargs)
@@ -88,6 +87,15 @@ class DefaultDialect(base.Dialect):
#print "ENGINE COMMIT ON ", connection.connection
connection.commit()
+
+ def do_savepoint(self, connection, name):
+ connection.execute(sql.SavepointClause(name))
+
+ def do_rollback_to_savepoint(self, connection, name):
+ connection.execute(sql.RollbackToSavepointClause(name))
+
+ def do_release_savepoint(self, connection, name):
+ connection.execute(sql.ReleaseSavepointClause(name))
def do_executemany(self, cursor, statement, parameters, **kwargs):
cursor.executemany(statement, parameters)
@@ -95,8 +103,8 @@ class DefaultDialect(base.Dialect):
def do_execute(self, cursor, statement, parameters, **kwargs):
cursor.execute(statement, parameters)
- def defaultrunner(self, connection):
- return base.DefaultRunner(connection)
+ def defaultrunner(self, context):
+ return base.DefaultRunner(context)
def is_disconnect(self, e):
return False
@@ -107,23 +115,6 @@ class DefaultDialect(base.Dialect):
paramstyle = property(lambda s:s._paramstyle, _set_paramstyle)
- def convert_compiled_params(self, parameters):
- executemany = parameters is not None and isinstance(parameters, list)
- # the bind params are a CompiledParams object. but all the DBAPI's hate
- # that object (or similar). so convert it to a clean
- # dictionary/list/tuple of dictionary/tuple of list
- if parameters is not None:
- if self.positional:
- if executemany:
- parameters = [p.get_raw_list() for p in parameters]
- else:
- parameters = parameters.get_raw_list()
- else:
- if executemany:
- parameters = [p.get_raw_dict() for p in parameters]
- else:
- parameters = parameters.get_raw_dict()
- return parameters
def _figure_paramstyle(self, paramstyle=None, default='named'):
if paramstyle is not None:
@@ -152,29 +143,38 @@ class DefaultDialect(base.Dialect):
ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""")
class DefaultExecutionContext(base.ExecutionContext):
- def __init__(self, dialect, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None):
+ def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
self.dialect = dialect
self.connection = connection
self.compiled = compiled
- self.compiled_parameters = compiled_parameters
if compiled is not None:
self.typemap = compiled.typemap
self.column_labels = compiled.column_labels
self.statement = unicode(compiled)
- else:
+ if parameters is None:
+ self.compiled_parameters = compiled.construct_params({})
+ elif not isinstance(parameters, (list, tuple)):
+ self.compiled_parameters = compiled.construct_params(parameters)
+ else:
+ self.compiled_parameters = [compiled.construct_params(m or {}) for m in parameters]
+ if len(self.compiled_parameters) == 1:
+ self.compiled_parameters = self.compiled_parameters[0]
+ elif statement is not None:
self.typemap = self.column_labels = None
- self.parameters = self._encode_param_keys(parameters)
+ self.parameters = self.__encode_param_keys(parameters)
self.statement = statement
-
- if not dialect.supports_unicode_statements():
+ else:
+ self.statement = None
+
+ if self.statement is not None and not dialect.supports_unicode_statements():
self.statement = self.statement.encode(self.dialect.encoding)
self.cursor = self.create_cursor()
engine = property(lambda s:s.connection.engine)
- def _encode_param_keys(self, params):
+ def __encode_param_keys(self, params):
"""apply string encoding to the keys of dictionary-based bind parameters"""
if self.dialect.positional or self.dialect.supports_unicode_statements():
return params
@@ -189,16 +189,46 @@ class DefaultExecutionContext(base.ExecutionContext):
return [proc(d) for d in params]
else:
return proc(params)
+
+ def __convert_compiled_params(self, parameters):
+ executemany = parameters is not None and isinstance(parameters, list)
+ encode = not self.dialect.supports_unicode_statements()
+ # the bind params are a CompiledParams object. but all the DBAPI's hate
+ # that object (or similar). so convert it to a clean
+ # dictionary/list/tuple of dictionary/tuple of list
+ if parameters is not None:
+ if self.dialect.positional:
+ if executemany:
+ parameters = [p.get_raw_list() for p in parameters]
+ else:
+ parameters = parameters.get_raw_list()
+ else:
+ if executemany:
+ parameters = [p.get_raw_dict(encode_keys=encode) for p in parameters]
+ else:
+ parameters = parameters.get_raw_dict(encode_keys=encode)
+ return parameters
def is_select(self):
- return re.match(r'SELECT', self.statement.lstrip(), re.I)
+ """return TRUE if the statement is expected to have result rows."""
+
+ return re.match(r'SELECT', self.statement.lstrip(), re.I) is not None
def create_cursor(self):
return self.connection.connection.cursor()
-
+
+ def pre_execution(self):
+ self.pre_exec()
+
+ def post_execution(self):
+ self.post_exec()
+
+ def result(self):
+ return self.get_result_proxy()
+
def pre_exec(self):
self._process_defaults()
- self.parameters = self._encode_param_keys(self.dialect.convert_compiled_params(self.compiled_parameters))
+ self.parameters = self.__convert_compiled_params(self.compiled_parameters)
def post_exec(self):
pass
@@ -241,7 +271,7 @@ class DefaultExecutionContext(base.ExecutionContext):
inputsizes = []
for params in plist[0:1]:
for key in params.positional:
- typeengine = params.binds[key].type
+ typeengine = params.get_type(key)
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None:
inputsizes.append(dbtype)
@@ -250,36 +280,23 @@ class DefaultExecutionContext(base.ExecutionContext):
inputsizes = {}
for params in plist[0:1]:
for key in params.keys():
- typeengine = params.binds[key].type
+ typeengine = params.get_type(key)
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None:
inputsizes[key] = dbtype
self.cursor.setinputsizes(**inputsizes)
def _process_defaults(self):
- """``INSERT`` and ``UPDATE`` statements, when compiled, may
- have additional columns added to their ``VALUES`` and ``SET``
- lists corresponding to column defaults/onupdates that are
- present on the ``Table`` object (i.e. ``ColumnDefault``,
- ``Sequence``, ``PassiveDefault``). This method pre-execs
- those ``DefaultGenerator`` objects that require pre-execution
- and sets their values within the parameter list, and flags this
- ExecutionContext about ``PassiveDefault`` objects that may
- require post-fetching the row after it is inserted/updated.
-
- This method relies upon logic within the ``ANSISQLCompiler``
- in its `visit_insert` and `visit_update` methods that add the
- appropriate column clauses to the statement when its being
- compiled, so that these parameters can be bound to the
- statement.
- """
+ """generate default values for compiled insert/update statements,
+ and generate last_inserted_ids() collection."""
+ # TODO: cleanup
if self.compiled.isinsert:
if isinstance(self.compiled_parameters, list):
plist = self.compiled_parameters
else:
plist = [self.compiled_parameters]
- drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
+ drunner = self.dialect.defaultrunner(self)
self._lastrow_has_defaults = False
for param in plist:
last_inserted_ids = []
@@ -319,7 +336,7 @@ class DefaultExecutionContext(base.ExecutionContext):
plist = self.compiled_parameters
else:
plist = [self.compiled_parameters]
- drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
+ drunner = self.dialect.defaultrunner(self)
self._lastrow_has_defaults = False
for param in plist:
# check the "onupdate" status of each column in the table