Diffstat (limited to 'lib/sqlalchemy/engine/base.py')
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 518 |
1 file changed, 268 insertions, 250 deletions
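The patch below moves statement execution off the old engine/proxy callables and onto a per-execution ExecutionContext: the Connection asks the Dialect for a context, runs pre_exec(), hands the context's cursor, statement, and parameters to do_execute(), runs post_exec(), and builds the ResultProxy from the context. For orientation, here is a minimal, hypothetical sketch of that flow using plain sqlite3; the class and function names are illustrative, not SQLAlchemy's.

import sqlite3

class ToyExecutionContext(object):
    def __init__(self, connection, statement=None, parameters=None):
        self.connection = connection
        self.statement = statement
        self.parameters = parameters or ()
        self.cursor = self.create_cursor()

    def create_cursor(self):
        # in the new API the ExecutionContext, not the Dialect, procures the cursor
        return self.connection.cursor()

    def pre_exec(self):
        pass  # a real dialect would finish building statement/parameters here

    def post_exec(self):
        pass  # a real dialect would collect last_inserted_ids() etc. here

    def get_result_proxy(self):
        return self.cursor.fetchall()  # stand-in for ResultProxy(context)

def execute(dbapi_connection, statement, parameters=()):
    context = ToyExecutionContext(dbapi_connection, statement, parameters)
    context.pre_exec()
    context.cursor.execute(context.statement, context.parameters)
    context.post_exec()
    return context.get_result_proxy()

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
conn.execute("INSERT INTO users (name) VALUES ('ed')")
print(execute(conn, "SELECT id, name FROM users WHERE name = ?", ("ed",)))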
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 0baaeb826..d8a9c5299 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -83,7 +83,7 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() def type_descriptor(self, typeobj): - """Trasform the type from generic to database-specific. + """Transform the type from generic to database-specific. Provides a database-specific TypeEngine object, given the generic object which comes from the types module. Subclasses @@ -105,6 +105,10 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() + def supports_alter(self): + """return True if the database supports ALTER TABLE.""" + raise NotImplementedError() + def max_identifier_length(self): """Return the maximum length of identifier names. @@ -118,32 +122,43 @@ class Dialect(sql.AbstractDialect): def supports_sane_rowcount(self): """Indicate whether the dialect properly implements statements rowcount. - Provided to indicate when MySQL is being used, which does not - have standard behavior for the "rowcount" function on a statement handle. + This was needed for MySQL which had non-standard behavior of rowcount, + but this issue has since been resolved. """ raise NotImplementedError() - def schemagenerator(self, engine, proxy, **params): + def schemagenerator(self, connection, **kwargs): """Return a ``schema.SchemaVisitor`` instance that can generate schemas. + connection + a Connection to use for statement execution + `schemagenerator()` is called via the `create()` method on Table, Index, and others. """ raise NotImplementedError() - def schemadropper(self, engine, proxy, **params): + def schemadropper(self, connection, **kwargs): """Return a ``schema.SchemaVisitor`` instance that can drop schemas. + connection + a Connection to use for statement execution + `schemadropper()` is called via the `drop()` method on Table, Index, and others. """ raise NotImplementedError() - def defaultrunner(self, engine, proxy, **params): - """Return a ``schema.SchemaVisitor`` instance that can execute defaults.""" + def defaultrunner(self, connection, **kwargs): + """Return a ``schema.SchemaVisitor`` instance that can execute defaults. + + connection + a Connection to use for statement execution + + """ raise NotImplementedError() @@ -154,7 +169,6 @@ class Dialect(sql.AbstractDialect): ansisql.ANSICompiler, and will produce a string representation of the given ClauseElement and `parameters` dictionary. - `compiler()` is called within the context of the compile() method. """ raise NotImplementedError() @@ -188,23 +202,13 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() - def dbapi(self): - """Establish a connection to the database. - - Subclasses override this method to provide the DBAPI module - used to establish connections. 
- """ - - raise NotImplementedError() - def get_default_schema_name(self, connection): """Return the currently selected schema given a connection""" raise NotImplementedError() - def execution_context(self): + def create_execution_context(self, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None): """Return a new ExecutionContext object.""" - raise NotImplementedError() def do_begin(self, connection): @@ -232,15 +236,6 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() - def create_cursor(self, connection): - """Return a new cursor generated from the given connection.""" - - raise NotImplementedError() - - def create_result_proxy_args(self, connection, cursor): - """Return a dictionary of arguments that should be passed to ResultProxy().""" - - raise NotImplementedError() def compile(self, clauseelement, parameters=None): """Compile the given ClauseElement using this Dialect. @@ -255,42 +250,74 @@ class Dialect(sql.AbstractDialect): class ExecutionContext(object): """A messenger object for a Dialect that corresponds to a single execution. + ExecutionContext should have these datamembers: + + connection + Connection object which initiated the call to the + dialect to create this ExecutionContext. + + dialect + dialect which created this ExecutionContext. + + cursor + DBAPI cursor procured from the connection + + compiled + if passed to constructor, sql.Compiled object being executed + + compiled_parameters + if passed to constructor, sql.ClauseParameters object + + statement + string version of the statement to be executed. Is either + passed to the constructor, or must be created from the + sql.Compiled object by the time pre_exec() has completed. + + parameters + "raw" parameters suitable for direct execution by the + dialect. Either passed to the constructor, or must be + created from the sql.ClauseParameters object by the time + pre_exec() has completed. + + The Dialect should provide an ExecutionContext via the create_execution_context() method. The `pre_exec` and `post_exec` - methods will be called for compiled statements, afterwhich it is - expected that the various methods `last_inserted_ids`, - `last_inserted_params`, etc. will contain appropriate values, if - applicable. + methods will be called for compiled statements. + """ - def pre_exec(self, engine, proxy, compiled, parameters): - """Called before an execution of a compiled statement. + def create_cursor(self): + """Return a new cursor generated this ExecutionContext's connection.""" - `proxy` is a callable that takes a string statement and a bind - parameter list/dictionary. + raise NotImplementedError() + + def pre_exec(self): + """Called before an execution of a compiled statement. + + If compiled and compiled_parameters were passed to this + ExecutionContext, the `statement` and `parameters` datamembers + must be initialized after this statement is complete. """ raise NotImplementedError() - def post_exec(self, engine, proxy, compiled, parameters): + def post_exec(self): """Called after the execution of a compiled statement. - - `proxy` is a callable that takes a string statement and a bind - parameter list/dictionary. + + If compiled was passed to this ExecutionContext, + the `last_insert_ids`, `last_inserted_params`, etc. + datamembers should be available after this method + completes. 
""" raise NotImplementedError() - - def get_rowcount(self, cursor): - """Return the count of rows updated/deleted for an UPDATE/DELETE statement.""" - + + def get_result_proxy(self): + """return a ResultProxy corresponding to this ExecutionContext.""" raise NotImplementedError() - - def supports_sane_rowcount(self): - """Indicate if the "rowcount" DBAPI cursor function works properly. - - Currently, MySQLDB does not properly implement this function. - """ + + def get_rowcount(self): + """Return the count of rows updated/deleted for an UPDATE/DELETE statement.""" raise NotImplementedError() @@ -299,7 +326,7 @@ class ExecutionContext(object): This does not apply to straight textual clauses; only to ``sql.Insert`` objects compiled against a ``schema.Table`` object, - which are executed via `statement.execute()`. The order of + which are executed via `execute()`. The order of items in the list is the same as that of the Table's 'primary_key' attribute. @@ -337,7 +364,7 @@ class ExecutionContext(object): raise NotImplementedError() -class Connectable(object): +class Connectable(sql.Executor): """Interface for an object that can provide an Engine and a Connection object which correponds to that Engine.""" def contextual_connect(self): @@ -362,6 +389,7 @@ class Connectable(object): raise NotImplementedError() engine = property(_not_impl, doc="The Engine which this Connectable is associated with.") + dialect = property(_not_impl, doc="Dialect which this Connectable is associated with.") class Connection(Connectable): """Represent a single DBAPI connection returned from the underlying connection pool. @@ -385,7 +413,8 @@ class Connection(Connectable): except AttributeError: raise exceptions.InvalidRequestError("This Connection is closed") - engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated (read only)") + engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.") + dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.") connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.") should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.") @@ -429,7 +458,7 @@ class Connection(Connectable): """When no Transaction is present, this is called after executions to provide "autocommit" behavior.""" # TODO: have the dialect determine if autocommit can be set on the connection directly without this # extra step - if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip().upper()): + if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip(), re.I): self._commit_impl() def _autorollback(self): @@ -448,6 +477,9 @@ class Connection(Connectable): def scalar(self, object, *multiparams, **params): return self.execute(object, *multiparams, **params).scalar() + def compiler(self, statement, parameters, **kwargs): + return self.dialect.compiler(statement, parameters, engine=self.engine, **kwargs) + def execute(self, object, *multiparams, **params): for c in type(object).__mro__: if c in Connection.executors: @@ -456,7 +488,7 @@ class Connection(Connectable): raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object))) def execute_default(self, default, **kwargs): - return 
default.accept_visitor(self.__engine.dialect.defaultrunner(self.__engine, self.proxy, **kwargs)) + return default.accept_visitor(self.__engine.dialect.defaultrunner(self)) def execute_text(self, statement, *multiparams, **params): if len(multiparams) == 0: @@ -465,9 +497,9 @@ class Connection(Connectable): parameters = multiparams[0] else: parameters = list(multiparams) - cursor = self._execute_raw(statement, parameters) - rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor) - return ResultProxy(self.__engine, self, cursor, **rpargs) + context = self._create_execution_context(statement=statement, parameters=parameters) + self._execute_raw(context) + return context.get_result_proxy() def _params_to_listofdicts(self, *multiparams, **params): if len(multiparams) == 0: @@ -491,29 +523,57 @@ class Connection(Connectable): param = multiparams[0] else: param = params - return self.execute_compiled(elem.compile(engine=self.__engine, parameters=param), *multiparams, **params) + return self.execute_compiled(elem.compile(dialect=self.dialect, parameters=param), *multiparams, **params) def execute_compiled(self, compiled, *multiparams, **params): """Execute a sql.Compiled object.""" if not compiled.can_execute: raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled))) - cursor = self.__engine.dialect.create_cursor(self.connection) parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)] if len(parameters) == 1: parameters = parameters[0] - def proxy(statement=None, parameters=None): - if statement is None: - return cursor - - parameters = self.__engine.dialect.convert_compiled_params(parameters) - self._execute_raw(statement, parameters, cursor=cursor, context=context) - return cursor - context = self.__engine.dialect.create_execution_context() - context.pre_exec(self.__engine, proxy, compiled, parameters) - proxy(unicode(compiled), parameters) - context.post_exec(self.__engine, proxy, compiled, parameters) - rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor) - return ResultProxy(self.__engine, self, cursor, context, typemap=compiled.typemap, column_labels=compiled.column_labels, **rpargs) + context = self._create_execution_context(compiled=compiled, compiled_parameters=parameters) + context.pre_exec() + self._execute_raw(context) + context.post_exec() + return context.get_result_proxy() + + def _create_execution_context(self, **kwargs): + return self.__engine.dialect.create_execution_context(connection=self, **kwargs) + + def _execute_raw(self, context): + self.__engine.logger.info(context.statement) + self.__engine.logger.info(repr(context.parameters)) + if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and (isinstance(context.parameters[0], list) or isinstance(context.parameters[0], dict)): + self._executemany(context) + else: + self._execute(context) + self._autocommit(context.statement) + + def _execute(self, context): + if context.parameters is None: + if context.dialect.positional: + context.parameters = () + else: + context.parameters = {} + try: + context.dialect.do_execute(context.cursor, context.statement, context.parameters, context=context) + except Exception, e: + self._autorollback() + #self._rollback_impl() + if self.__close_with_result: + self.close() + raise exceptions.SQLError(context.statement, context.parameters, e) + + def _executemany(self, context): + try: + context.dialect.do_executemany(context.cursor, 
context.statement, context.parameters, context=context) + except Exception, e: + self._autorollback() + #self._rollback_impl() + if self.__close_with_result: + self.close() + raise exceptions.SQLError(context.statement, context.parameters, e) # poor man's multimethod/generic function thingy executors = { @@ -525,17 +585,17 @@ class Connection(Connectable): } def create(self, entity, **kwargs): - """Create a table or index given an appropriate schema object.""" + """Create a Table or Index given an appropriate Schema object.""" return self.__engine.create(entity, connection=self, **kwargs) def drop(self, entity, **kwargs): - """Drop a table or index given an appropriate schema object.""" + """Drop a Table or Index given an appropriate Schema object.""" return self.__engine.drop(entity, connection=self, **kwargs) def reflecttable(self, table, **kwargs): - """Reflect the columns in the given table from the database.""" + """Reflect the columns in the given string table name from the database.""" return self.__engine.reflecttable(table, connection=self, **kwargs) @@ -545,59 +605,6 @@ class Connection(Connectable): def run_callable(self, callable_): return callable_(self) - def _execute_raw(self, statement, parameters=None, cursor=None, context=None, **kwargs): - if cursor is None: - cursor = self.__engine.dialect.create_cursor(self.connection) - if not self.__engine.dialect.supports_unicode_statements(): - # encode to ascii, with full error handling - statement = statement.encode('ascii') - self.__engine.logger.info(statement) - self.__engine.logger.info(repr(parameters)) - if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)): - self._executemany(cursor, statement, parameters, context=context) - else: - self._execute(cursor, statement, parameters, context=context) - self._autocommit(statement) - return cursor - - def _execute(self, c, statement, parameters, context=None): - if parameters is None: - if self.__engine.dialect.positional: - parameters = () - else: - parameters = {} - try: - self.__engine.dialect.do_execute(c, statement, parameters, context=context) - except Exception, e: - self._autorollback() - #self._rollback_impl() - if self.__close_with_result: - self.close() - raise exceptions.SQLError(statement, parameters, e) - - def _executemany(self, c, statement, parameters, context=None): - try: - self.__engine.dialect.do_executemany(c, statement, parameters, context=context) - except Exception, e: - self._autorollback() - #self._rollback_impl() - if self.__close_with_result: - self.close() - raise exceptions.SQLError(statement, parameters, e) - - def proxy(self, statement=None, parameters=None): - """Execute the given statement string and parameter object. - - The parameter object is expected to be the result of a call to - ``compiled.get_params()``. This callable is a generic version - of a connection/cursor-specific callable that is produced - within the execute_compiled method, and is used for objects - that require this style of proxy when outside of an - execute_compiled method, primarily the DefaultRunner. - """ - parameters = self.__engine.dialect.convert_compiled_params(parameters) - return self._execute_raw(statement, parameters) - class Transaction(object): """Represent a Transaction in progress. 
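The rewritten _execute_raw above chooses between cursor.execute() and cursor.executemany() purely from the shape of context.parameters: a non-empty list whose first element is itself a list or dict is treated as an executemany batch, anything else goes through a single execute. A small, hypothetical helper restating that rule (the name wants_executemany is ours, not the patch's):

# Hypothetical helper, not part of the patch: it restates the parameter-shape
# test Connection._execute_raw uses to pick execute() vs. executemany().
def wants_executemany(parameters):
    return (isinstance(parameters, list)
            and len(parameters) > 0
            and isinstance(parameters[0], (list, dict)))

assert wants_executemany([{"name": "ed"}, {"name": "wendy"}]) is True
assert wants_executemany({"name": "ed"}) is False   # single dict -> execute()
assert wants_executemany([]) is False                # empty list -> execute()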
@@ -630,7 +637,7 @@ class Transaction(object): self.__connection._commit_impl() self.__is_active = False -class Engine(sql.Executor, Connectable): +class Engine(Connectable): """ Connects a ConnectionProvider, a Dialect and a CompilerFactory together to provide a default implementation of SchemaEngine. @@ -638,12 +645,13 @@ class Engine(sql.Executor, Connectable): def __init__(self, connection_provider, dialect, echo=None): self.connection_provider = connection_provider - self.dialect=dialect + self._dialect=dialect self.echo = echo self.logger = logging.instance_logger(self) name = property(lambda s:sys.modules[s.dialect.__module__].descriptor()['name']) engine = property(lambda s:s) + dialect = property(lambda s:s._dialect) echo = logging.echo_property() def dispose(self): @@ -678,11 +686,11 @@ class Engine(sql.Executor, Connectable): def _run_visitor(self, visitorcallable, element, connection=None, **kwargs): if connection is None: - conn = self.contextual_connect() + conn = self.contextual_connect(close_with_result=False) else: conn = connection try: - element.accept_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs)) + element.accept_visitor(visitorcallable(conn, **kwargs)) finally: if connection is None: conn.close() @@ -807,55 +815,39 @@ class ResultProxy(object): def convert_result_value(self, arg, engine): raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key)) - def __new__(cls, *args, **kwargs): - if cls is ResultProxy and kwargs.has_key('should_prefetch') and kwargs['should_prefetch']: - return PrefetchingResultProxy(*args, **kwargs) - else: - return object.__new__(cls, *args, **kwargs) - - def __init__(self, engine, connection, cursor, executioncontext=None, typemap=None, column_labels=None, should_prefetch=None): + def __init__(self, context): """ResultProxy objects are constructed via the execute() method on SQLEngine.""" - - self.connection = connection - self.dialect = engine.dialect - self.cursor = cursor - self.engine = engine + self.context = context self.closed = False - self.column_labels = column_labels - if executioncontext is not None: - self.__executioncontext = executioncontext - self.rowcount = executioncontext.get_rowcount(cursor) - else: - self.rowcount = cursor.rowcount - self.__key_cache = {} - self.__echo = engine.echo == 'debug' - metadata = cursor.description - self.props = {} - self.keys = [] - i = 0 + self.cursor = context.cursor + self.__echo = logging.is_debug_enabled(context.engine.logger) + self._init_metadata() + dialect = property(lambda s:s.context.dialect) + rowcount = property(lambda s:s.context.get_rowcount()) + connection = property(lambda s:s.context.connection) + + def _init_metadata(self): + if hasattr(self, '_ResultProxy__props'): + return + self.__key_cache = {} + self.__props = {} + self.__keys = [] + metadata = self.cursor.description if metadata is not None: - for item in metadata: + for i, item in enumerate(metadata): # sqlite possibly prepending table name to colnames so strip - colname = item[0].split('.')[-1].lower() - if typemap is not None: - rec = (typemap.get(colname, types.NULLTYPE), i) + colname = item[0].split('.')[-1] + if self.context.typemap is not None: + rec = (self.context.typemap.get(colname.lower(), types.NULLTYPE), i) else: rec = (types.NULLTYPE, i) if rec[0] is None: raise DBAPIError("None for metadata " + colname) - if self.props.setdefault(colname, rec) is not rec: - self.props[colname] = 
(ResultProxy.AmbiguousColumn(colname), 0) - self.keys.append(colname) - self.props[i] = rec - i+=1 - - def _executioncontext(self): - try: - return self.__executioncontext - except AttributeError: - raise exceptions.InvalidRequestError("This ResultProxy does not have an execution context with which to complete this operation. Execution contexts are not generated for literal SQL execution.") - executioncontext = property(_executioncontext) + if self.__props.setdefault(colname.lower(), rec) is not rec: + self.__props[colname.lower()] = (ResultProxy.AmbiguousColumn(colname), 0) + self.__keys.append(colname) + self.__props[i] = rec def close(self): """Close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution. @@ -867,13 +859,12 @@ class ResultProxy(object): This method is also called automatically when all result rows are exhausted. """ - if not self.closed: self.closed = True self.cursor.close() if self.connection.should_close_with_result and self.dialect.supports_autoclose_results: self.connection.close() - + def _convert_key(self, key): """Convert and cache a key. @@ -882,25 +873,26 @@ class ResultProxy(object): metadata; then cache it locally for quick re-access. """ - try: + if key in self.__key_cache: return self.__key_cache[key] - except KeyError: - if isinstance(key, int) and key in self.props: - rec = self.props[key] - elif isinstance(key, basestring) and key.lower() in self.props: - rec = self.props[key.lower()] + else: + if isinstance(key, int) and key in self.__props: + rec = self.__props[key] + elif isinstance(key, basestring) and key.lower() in self.__props: + rec = self.__props[key.lower()] elif isinstance(key, sql.ColumnElement): - label = self.column_labels.get(key._label, key.name).lower() - if label in self.props: - rec = self.props[label] + label = self.context.column_labels.get(key._label, key.name).lower() + if label in self.__props: + rec = self.__props[label] if not "rec" in locals(): raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (repr(key))) self.__key_cache[key] = rec return rec - - + + keys = property(lambda s:s.__keys) + def _has_key(self, row, key): try: self._convert_key(key) @@ -908,10 +900,6 @@ class ResultProxy(object): except KeyError: return False - def _get_col(self, row, key): - rec = self._convert_key(key) - return rec[0].dialect_impl(self.dialect).convert_result_value(row[rec[1]], self.dialect) - def __iter__(self): while True: row = self.fetchone() @@ -926,7 +914,7 @@ class ResultProxy(object): See ExecutionContext for details. """ - return self.executioncontext.last_inserted_ids() + return self.context.last_inserted_ids() def last_updated_params(self): """Return ``last_updated_params()`` from the underlying ExecutionContext. @@ -934,7 +922,7 @@ class ResultProxy(object): See ExecutionContext for details. """ - return self.executioncontext.last_updated_params() + return self.context.last_updated_params() def last_inserted_params(self): """Return ``last_inserted_params()`` from the underlying ExecutionContext. @@ -942,7 +930,7 @@ class ResultProxy(object): See ExecutionContext for details. """ - return self.executioncontext.last_inserted_params() + return self.context.last_inserted_params() def lastrow_has_defaults(self): """Return ``lastrow_has_defaults()`` from the underlying ExecutionContext. @@ -950,7 +938,7 @@ class ResultProxy(object): See ExecutionContext for details. 
""" - return self.executioncontext.lastrow_has_defaults() + return self.context.lastrow_has_defaults() def supports_sane_rowcount(self): """Return ``supports_sane_rowcount()`` from the underlying ExecutionContext. @@ -958,71 +946,122 @@ class ResultProxy(object): See ExecutionContext for details. """ - return self.executioncontext.supports_sane_rowcount() + return self.context.supports_sane_rowcount() + def _get_col(self, row, key): + rec = self._convert_key(key) + return rec[0].dialect_impl(self.dialect).convert_result_value(row[rec[1]], self.dialect) + + def _fetchone_impl(self): + return self.cursor.fetchone() + def _fetchmany_impl(self, size=None): + return self.cursor.fetchmany(size) + def _fetchall_impl(self): + return self.cursor.fetchall() + + def _process_row(self, row): + return RowProxy(self, row) + def fetchall(self): """Fetch all rows, just like DBAPI ``cursor.fetchall()``.""" - l = [] - for row in self.cursor.fetchall(): - l.append(RowProxy(self, row)) + l = [self._process_row(row) for row in self._fetchall_impl()] self.close() return l def fetchmany(self, size=None): """Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``.""" - if size is None: - rows = self.cursor.fetchmany() - else: - rows = self.cursor.fetchmany(size) - l = [] - for row in rows: - l.append(RowProxy(self, row)) + l = [self._process_row(row) for row in self._fetchmany_impl(size)] if len(l) == 0: self.close() return l def fetchone(self): """Fetch one row, just like DBAPI ``cursor.fetchone()``.""" - - row = self.cursor.fetchone() + row = self._fetchone_impl() if row is not None: - return RowProxy(self, row) + return self._process_row(row) else: self.close() return None def scalar(self): """Fetch the first column of the first row, and close the result set.""" - - row = self.cursor.fetchone() + row = self._fetchone_impl() try: if row is not None: - return RowProxy(self, row)[0] + return self._process_row(row)[0] else: return None finally: self.close() -class PrefetchingResultProxy(ResultProxy): +class BufferedRowResultProxy(ResultProxy): + def _init_metadata(self): + self.__buffer_rows() + super(BufferedRowResultProxy, self)._init_metadata() + + # this is a "growth chart" for the buffering of rows. + # each successive __buffer_rows call will use the next + # value in the list for the buffer size until the max + # is reached + size_growth = { + 1 : 5, + 5 : 10, + 10 : 20, + 20 : 50, + 50 : 100 + } + + def __buffer_rows(self): + size = getattr(self, '_bufsize', 1) + self.__rowbuffer = self.cursor.fetchmany(size) + #self.context.engine.logger.debug("Buffered %d rows" % size) + self._bufsize = self.size_growth.get(size, size) + + def _fetchone_impl(self): + if self.closed: + return None + if len(self.__rowbuffer) == 0: + self.__buffer_rows() + if len(self.__rowbuffer) == 0: + return None + return self.__rowbuffer.pop(0) + + def _fetchmany_impl(self, size=None): + result = [] + for x in range(0, size): + row = self._fetchone_impl() + if row is None: + break + result.append(row) + return result + + def _fetchall_impl(self): + return self.__rowbuffer + list(self.cursor.fetchall()) + +class BufferedColumnResultProxy(ResultProxy): """ResultProxy that loads all columns into memory each time fetchone() is called. If fetchmany() or fetchall() are called, the full grid of results is fetched. 
""" - def _get_col(self, row, key): rec = self._convert_key(key) return row[rec[1]] + + def _process_row(self, row): + sup = super(BufferedColumnResultProxy, self) + row = [sup._get_col(row, i) for i in xrange(len(row))] + return RowProxy(self, row) def fetchall(self): l = [] while True: row = self.fetchone() - if row is not None: - l.append(row) - else: + if row is None: break + l.append(row) return l def fetchmany(self, size=None): @@ -1031,24 +1070,13 @@ class PrefetchingResultProxy(ResultProxy): l = [] for i in xrange(size): row = self.fetchone() - if row is not None: - l.append(row) - else: + if row is None: break + l.append(row) return l - def fetchone(self): - sup = super(PrefetchingResultProxy, self) - row = self.cursor.fetchone() - if row is not None: - row = [sup._get_col(row, i) for i in xrange(len(row))] - return RowProxy(self, row) - else: - self.close() - return None - class RowProxy(object): - """Proxie a single cursor row for a parent ResultProxy. + """Proxy a single cursor row for a parent ResultProxy. Mostly follows "ordered dictionary" behavior, mapping result values to the string-based column name, the integer position of @@ -1063,7 +1091,7 @@ class RowProxy(object): self.__parent = parent self.__row = row if self.__parent._ResultProxy__echo: - self.__parent.engine.logger.debug("Row " + repr(row)) + self.__parent.context.engine.logger.debug("Row " + repr(row)) def close(self): """Close the parent ResultProxy.""" @@ -1115,20 +1143,10 @@ class RowProxy(object): class SchemaIterator(schema.SchemaVisitor): """A visitor that can gather text into a buffer and execute the contents of the buffer.""" - def __init__(self, engine, proxy, **params): + def __init__(self, connection): """Construct a new SchemaIterator. - - engine - the Engine used by this SchemaIterator - - proxy - a callable which takes a statement and bind parameters and - executes it, returning the cursor (the actual DBAPI cursor). - The callable should use the same cursor repeatedly. """ - - self.proxy = proxy - self.engine = engine + self.connection = connection self.buffer = StringIO.StringIO() def append(self, s): @@ -1140,7 +1158,7 @@ class SchemaIterator(schema.SchemaVisitor): """Execute the contents of the SchemaIterator's buffer.""" try: - return self.proxy(self.buffer.getvalue(), None) + return self.connection.execute(self.buffer.getvalue()) finally: self.buffer.truncate(0) @@ -1154,10 +1172,10 @@ class DefaultRunner(schema.SchemaVisitor): DefaultRunner to allow database-specific behavior. """ - def __init__(self, engine, proxy): - self.proxy = proxy - self.engine = engine - + def __init__(self, connection): + self.connection = connection + self.dialect = connection.dialect + def get_column_default(self, column): if column.default is not None: return column.default.accept_visitor(self) @@ -1188,8 +1206,8 @@ class DefaultRunner(schema.SchemaVisitor): return None def exec_default_sql(self, default): - c = sql.select([default.arg], engine=self.engine).compile() - return self.proxy(str(c), c.get_params()).fetchone()[0] + c = sql.select([default.arg]).compile(engine=self.connection) + return self.connection.execute_compiled(c).scalar() def visit_column_onupdate(self, onupdate): if isinstance(onupdate.arg, sql.ClauseElement): |