diff options
Diffstat (limited to 'lib/sqlalchemy/engine/base.py')
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 132 |
1 files changed, 52 insertions, 80 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index faeb00cc9..553c8df84 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -8,7 +8,8 @@ higher-level statement-construction, connection-management, execution and result contexts.""" -from sqlalchemy import exceptions, sql, schema, util, types, logging +from sqlalchemy import exceptions, schema, util, types, logging +from sqlalchemy.sql import expression, visitors import StringIO, sys @@ -35,6 +36,22 @@ class Dialect(object): encoding type of encoding to use for unicode, usually defaults to 'utf-8' + + schemagenerator + a [sqlalchemy.schema#SchemaVisitor] class which generates schemas. + + schemadropper + a [sqlalchemy.schema#SchemaVisitor] class which drops schemas. + + defaultrunner + a [sqlalchemy.schema#SchemaVisitor] class which executes defaults. + + statement_compiler + a [sqlalchemy.engine.base#Compiled] class used to compile SQL statements + + preparer + a [sqlalchemy.sql.compiler#IdentifierPreparer] class used to quote + identifiers. """ def create_connect_args(self, url): @@ -105,48 +122,6 @@ class Dialect(object): raise NotImplementedError() - def schemagenerator(self, connection, **kwargs): - """Return a [sqlalchemy.schema#SchemaVisitor] instance that can generate schemas. - - connection - a [sqlalchemy.engine#Connection] to use for statement execution - - `schemagenerator()` is called via the `create()` method on Table, - Index, and others. - """ - - raise NotImplementedError() - - def schemadropper(self, connection, **kwargs): - """Return a [sqlalchemy.schema#SchemaVisitor] instance that can drop schemas. - - connection - a [sqlalchemy.engine#Connection] to use for statement execution - - `schemadropper()` is called via the `drop()` method on Table, - Index, and others. - """ - - raise NotImplementedError() - - def defaultrunner(self, execution_context): - """Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults. - - execution_context - a [sqlalchemy.engine#ExecutionContext] to use for statement execution - - """ - - raise NotImplementedError() - - def compiler(self, statement, parameters): - """Return a [sqlalchemy.sql#Compiled] object for the given statement/parameters. - - The returned object is usually a subclass of [sqlalchemy.ansisql#ANSICompiler]. - - """ - - raise NotImplementedError() def server_version_info(self, connection): """Return a tuple of the database's version number.""" @@ -266,16 +241,6 @@ class Dialect(object): raise NotImplementedError() - - def compile(self, clauseelement, parameters=None): - """Compile the given [sqlalchemy.sql#ClauseElement] using this Dialect. - - Returns [sqlalchemy.sql#Compiled]. A convenience method which - flips around the compile() call on ``ClauseElement``. - """ - - return clauseelement.compile(dialect=self, parameters=parameters) - def is_disconnect(self, e): """Return True if the given DBAPI error indicates an invalid connection""" @@ -304,7 +269,7 @@ class ExecutionContext(object): DBAPI cursor procured from the connection compiled - if passed to constructor, sql.Compiled object being executed + if passed to constructor, sqlalchemy.engine.base.Compiled object being executed statement string version of the statement to be executed. Is either @@ -439,6 +404,9 @@ class Compiled(object): def __init__(self, dialect, statement, parameters, bind=None): """Construct a new ``Compiled`` object. + dialect + ``Dialect`` to compile against. + statement ``ClauseElement`` to be compiled. @@ -724,8 +692,8 @@ class Connection(Connectable): def scalar(self, object, *multiparams, **params): return self.execute(object, *multiparams, **params).scalar() - def compiler(self, statement, parameters, **kwargs): - return self.dialect.compiler(statement, parameters, bind=self, **kwargs) + def statement_compiler(self, statement, parameters, **kwargs): + return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs) def execute(self, object, *multiparams, **params): for c in type(object).__mro__: @@ -822,9 +790,9 @@ class Connection(Connectable): # poor man's multimethod/generic function thingy executors = { - sql._Function : _execute_function, - sql.ClauseElement : _execute_clauseelement, - sql.ClauseVisitor : _execute_compiled, + expression._Function : _execute_function, + expression.ClauseElement : _execute_clauseelement, + visitors.ClauseVisitor : _execute_compiled, schema.SchemaItem:_execute_default, str.__mro__[-2] : _execute_text } @@ -989,14 +957,14 @@ class Engine(Connectable): connection.close() def _func(self): - return sql._FunctionGenerator(bind=self) + return expression._FunctionGenerator(bind=self) func = property(_func) def text(self, text, *args, **kwargs): """Return a sql.text() object for performing literal queries.""" - return sql.text(text, bind=self, *args, **kwargs) + return expression.text(text, bind=self, *args, **kwargs) def _run_visitor(self, visitorcallable, element, connection=None, **kwargs): if connection is None: @@ -1004,7 +972,7 @@ class Engine(Connectable): else: conn = connection try: - visitorcallable(conn, **kwargs).traverse(element) + visitorcallable(self.dialect, conn, **kwargs).traverse(element) finally: if connection is None: conn.close() @@ -1057,8 +1025,8 @@ class Engine(Connectable): connection = self.contextual_connect(close_with_result=True) return connection._execute_compiled(compiled, multiparams, params) - def compiler(self, statement, parameters, **kwargs): - return self.dialect.compiler(statement, parameters, bind=self, **kwargs) + def statement_compiler(self, statement, parameters, **kwargs): + return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs) def connect(self, **kwargs): """Return a newly allocated Connection object.""" @@ -1159,6 +1127,7 @@ class ResultProxy(object): self.closed = False self.cursor = context.cursor self.__echo = logging.is_debug_enabled(context.engine.logger) + self._process_row = self._row_processor() if context.is_select(): self._init_metadata() self._rowcount = None @@ -1222,7 +1191,7 @@ class ResultProxy(object): rec = props[key] elif isinstance(key, basestring) and key.lower() in props: rec = props[key.lower()] - elif isinstance(key, sql.ColumnElement): + elif isinstance(key, expression.ColumnElement): label = context.column_labels.get(key._label, key.name).lower() if label in props: rec = props[label] @@ -1320,21 +1289,21 @@ class ResultProxy(object): return self.cursor.fetchmany(size) def _fetchall_impl(self): return self.cursor.fetchall() + + def _row_processor(self): + return RowProxy - def _process_row(self, row): - return RowProxy(self, row) - def fetchall(self): """Fetch all rows, just like DBAPI ``cursor.fetchall()``.""" - l = [self._process_row(row) for row in self._fetchall_impl()] + l = [self._process_row(self, row) for row in self._fetchall_impl()] self.close() return l def fetchmany(self, size=None): """Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``.""" - l = [self._process_row(row) for row in self._fetchmany_impl(size)] + l = [self._process_row(self, row) for row in self._fetchmany_impl(size)] if len(l) == 0: self.close() return l @@ -1343,7 +1312,7 @@ class ResultProxy(object): """Fetch one row, just like DBAPI ``cursor.fetchone()``.""" row = self._fetchone_impl() if row is not None: - return self._process_row(row) + return self._process_row(self, row) else: self.close() return None @@ -1353,7 +1322,7 @@ class ResultProxy(object): row = self._fetchone_impl() try: if row is not None: - return self._process_row(row)[0] + return self._process_row(self, row)[0] else: return None finally: @@ -1425,11 +1394,9 @@ class BufferedColumnResultProxy(ResultProxy): def _get_col(self, row, key): rec = self._key_cache[key] return row[rec[2]] - - def _process_row(self, row): - sup = super(BufferedColumnResultProxy, self) - row = [sup._get_col(row, i) for i in xrange(len(row))] - return RowProxy(self, row) + + def _row_processor(self): + return BufferedColumnRow def fetchall(self): l = [] @@ -1523,6 +1490,11 @@ class RowProxy(object): def __len__(self): return len(self.__row) +class BufferedColumnRow(RowProxy): + def __init__(self, parent, row): + row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))] + super(BufferedColumnRow, self).__init__(parent, row) + class SchemaIterator(schema.SchemaVisitor): """A visitor that can gather text into a buffer and execute the contents of the buffer.""" @@ -1590,11 +1562,11 @@ class DefaultRunner(schema.SchemaVisitor): return None def exec_default_sql(self, default): - c = sql.select([default.arg]).compile(bind=self.connection) + c = expression.select([default.arg]).compile(bind=self.connection) return self.connection._execute_compiled(c).scalar() def visit_column_onupdate(self, onupdate): - if isinstance(onupdate.arg, sql.ClauseElement): + if isinstance(onupdate.arg, expression.ClauseElement): return self.exec_default_sql(onupdate) elif callable(onupdate.arg): return onupdate.arg(self.context) @@ -1602,7 +1574,7 @@ class DefaultRunner(schema.SchemaVisitor): return onupdate.arg def visit_column_default(self, default): - if isinstance(default.arg, sql.ClauseElement): + if isinstance(default.arg, expression.ClauseElement): return self.exec_default_sql(default) elif callable(default.arg): return default.arg(self.context) |