summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/engine/base.py')
-rw-r--r--lib/sqlalchemy/engine/base.py132
1 files changed, 52 insertions, 80 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index faeb00cc9..553c8df84 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -8,7 +8,8 @@
higher-level statement-construction, connection-management,
execution and result contexts."""
-from sqlalchemy import exceptions, sql, schema, util, types, logging
+from sqlalchemy import exceptions, schema, util, types, logging
+from sqlalchemy.sql import expression, visitors
import StringIO, sys
@@ -35,6 +36,22 @@ class Dialect(object):
encoding
type of encoding to use for unicode, usually defaults to 'utf-8'
+
+ schemagenerator
+ a [sqlalchemy.schema#SchemaVisitor] class which generates schemas.
+
+ schemadropper
+ a [sqlalchemy.schema#SchemaVisitor] class which drops schemas.
+
+ defaultrunner
+ a [sqlalchemy.schema#SchemaVisitor] class which executes defaults.
+
+ statement_compiler
+ a [sqlalchemy.engine.base#Compiled] class used to compile SQL statements
+
+ preparer
+ a [sqlalchemy.sql.compiler#IdentifierPreparer] class used to quote
+ identifiers.
"""
def create_connect_args(self, url):
@@ -105,48 +122,6 @@ class Dialect(object):
raise NotImplementedError()
- def schemagenerator(self, connection, **kwargs):
- """Return a [sqlalchemy.schema#SchemaVisitor] instance that can generate schemas.
-
- connection
- a [sqlalchemy.engine#Connection] to use for statement execution
-
- `schemagenerator()` is called via the `create()` method on Table,
- Index, and others.
- """
-
- raise NotImplementedError()
-
- def schemadropper(self, connection, **kwargs):
- """Return a [sqlalchemy.schema#SchemaVisitor] instance that can drop schemas.
-
- connection
- a [sqlalchemy.engine#Connection] to use for statement execution
-
- `schemadropper()` is called via the `drop()` method on Table,
- Index, and others.
- """
-
- raise NotImplementedError()
-
- def defaultrunner(self, execution_context):
- """Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults.
-
- execution_context
- a [sqlalchemy.engine#ExecutionContext] to use for statement execution
-
- """
-
- raise NotImplementedError()
-
- def compiler(self, statement, parameters):
- """Return a [sqlalchemy.sql#Compiled] object for the given statement/parameters.
-
- The returned object is usually a subclass of [sqlalchemy.ansisql#ANSICompiler].
-
- """
-
- raise NotImplementedError()
def server_version_info(self, connection):
"""Return a tuple of the database's version number."""
@@ -266,16 +241,6 @@ class Dialect(object):
raise NotImplementedError()
-
- def compile(self, clauseelement, parameters=None):
- """Compile the given [sqlalchemy.sql#ClauseElement] using this Dialect.
-
- Returns [sqlalchemy.sql#Compiled]. A convenience method which
- flips around the compile() call on ``ClauseElement``.
- """
-
- return clauseelement.compile(dialect=self, parameters=parameters)
-
def is_disconnect(self, e):
"""Return True if the given DBAPI error indicates an invalid connection"""
@@ -304,7 +269,7 @@ class ExecutionContext(object):
DBAPI cursor procured from the connection
compiled
- if passed to constructor, sql.Compiled object being executed
+ if passed to constructor, sqlalchemy.engine.base.Compiled object being executed
statement
string version of the statement to be executed. Is either
@@ -439,6 +404,9 @@ class Compiled(object):
def __init__(self, dialect, statement, parameters, bind=None):
"""Construct a new ``Compiled`` object.
+ dialect
+ ``Dialect`` to compile against.
+
statement
``ClauseElement`` to be compiled.
@@ -724,8 +692,8 @@ class Connection(Connectable):
def scalar(self, object, *multiparams, **params):
return self.execute(object, *multiparams, **params).scalar()
- def compiler(self, statement, parameters, **kwargs):
- return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
+ def statement_compiler(self, statement, parameters, **kwargs):
+ return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs)
def execute(self, object, *multiparams, **params):
for c in type(object).__mro__:
@@ -822,9 +790,9 @@ class Connection(Connectable):
# poor man's multimethod/generic function thingy
executors = {
- sql._Function : _execute_function,
- sql.ClauseElement : _execute_clauseelement,
- sql.ClauseVisitor : _execute_compiled,
+ expression._Function : _execute_function,
+ expression.ClauseElement : _execute_clauseelement,
+ visitors.ClauseVisitor : _execute_compiled,
schema.SchemaItem:_execute_default,
str.__mro__[-2] : _execute_text
}
@@ -989,14 +957,14 @@ class Engine(Connectable):
connection.close()
def _func(self):
- return sql._FunctionGenerator(bind=self)
+ return expression._FunctionGenerator(bind=self)
func = property(_func)
def text(self, text, *args, **kwargs):
"""Return a sql.text() object for performing literal queries."""
- return sql.text(text, bind=self, *args, **kwargs)
+ return expression.text(text, bind=self, *args, **kwargs)
def _run_visitor(self, visitorcallable, element, connection=None, **kwargs):
if connection is None:
@@ -1004,7 +972,7 @@ class Engine(Connectable):
else:
conn = connection
try:
- visitorcallable(conn, **kwargs).traverse(element)
+ visitorcallable(self.dialect, conn, **kwargs).traverse(element)
finally:
if connection is None:
conn.close()
@@ -1057,8 +1025,8 @@ class Engine(Connectable):
connection = self.contextual_connect(close_with_result=True)
return connection._execute_compiled(compiled, multiparams, params)
- def compiler(self, statement, parameters, **kwargs):
- return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
+ def statement_compiler(self, statement, parameters, **kwargs):
+ return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs)
def connect(self, **kwargs):
"""Return a newly allocated Connection object."""
@@ -1159,6 +1127,7 @@ class ResultProxy(object):
self.closed = False
self.cursor = context.cursor
self.__echo = logging.is_debug_enabled(context.engine.logger)
+ self._process_row = self._row_processor()
if context.is_select():
self._init_metadata()
self._rowcount = None
@@ -1222,7 +1191,7 @@ class ResultProxy(object):
rec = props[key]
elif isinstance(key, basestring) and key.lower() in props:
rec = props[key.lower()]
- elif isinstance(key, sql.ColumnElement):
+ elif isinstance(key, expression.ColumnElement):
label = context.column_labels.get(key._label, key.name).lower()
if label in props:
rec = props[label]
@@ -1320,21 +1289,21 @@ class ResultProxy(object):
return self.cursor.fetchmany(size)
def _fetchall_impl(self):
return self.cursor.fetchall()
+
+ def _row_processor(self):
+ return RowProxy
- def _process_row(self, row):
- return RowProxy(self, row)
-
def fetchall(self):
"""Fetch all rows, just like DBAPI ``cursor.fetchall()``."""
- l = [self._process_row(row) for row in self._fetchall_impl()]
+ l = [self._process_row(self, row) for row in self._fetchall_impl()]
self.close()
return l
def fetchmany(self, size=None):
"""Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``."""
- l = [self._process_row(row) for row in self._fetchmany_impl(size)]
+ l = [self._process_row(self, row) for row in self._fetchmany_impl(size)]
if len(l) == 0:
self.close()
return l
@@ -1343,7 +1312,7 @@ class ResultProxy(object):
"""Fetch one row, just like DBAPI ``cursor.fetchone()``."""
row = self._fetchone_impl()
if row is not None:
- return self._process_row(row)
+ return self._process_row(self, row)
else:
self.close()
return None
@@ -1353,7 +1322,7 @@ class ResultProxy(object):
row = self._fetchone_impl()
try:
if row is not None:
- return self._process_row(row)[0]
+ return self._process_row(self, row)[0]
else:
return None
finally:
@@ -1425,11 +1394,9 @@ class BufferedColumnResultProxy(ResultProxy):
def _get_col(self, row, key):
rec = self._key_cache[key]
return row[rec[2]]
-
- def _process_row(self, row):
- sup = super(BufferedColumnResultProxy, self)
- row = [sup._get_col(row, i) for i in xrange(len(row))]
- return RowProxy(self, row)
+
+ def _row_processor(self):
+ return BufferedColumnRow
def fetchall(self):
l = []
@@ -1523,6 +1490,11 @@ class RowProxy(object):
def __len__(self):
return len(self.__row)
+class BufferedColumnRow(RowProxy):
+ def __init__(self, parent, row):
+ row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))]
+ super(BufferedColumnRow, self).__init__(parent, row)
+
class SchemaIterator(schema.SchemaVisitor):
"""A visitor that can gather text into a buffer and execute the contents of the buffer."""
@@ -1590,11 +1562,11 @@ class DefaultRunner(schema.SchemaVisitor):
return None
def exec_default_sql(self, default):
- c = sql.select([default.arg]).compile(bind=self.connection)
+ c = expression.select([default.arg]).compile(bind=self.connection)
return self.connection._execute_compiled(c).scalar()
def visit_column_onupdate(self, onupdate):
- if isinstance(onupdate.arg, sql.ClauseElement):
+ if isinstance(onupdate.arg, expression.ClauseElement):
return self.exec_default_sql(onupdate)
elif callable(onupdate.arg):
return onupdate.arg(self.context)
@@ -1602,7 +1574,7 @@ class DefaultRunner(schema.SchemaVisitor):
return onupdate.arg
def visit_column_default(self, default):
- if isinstance(default.arg, sql.ClauseElement):
+ if isinstance(default.arg, expression.ClauseElement):
return self.exec_default_sql(default)
elif callable(default.arg):
return default.arg(self.context)