diff options
Diffstat (limited to 'lib/sqlalchemy/engine/base.py')
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 633 |
1 files changed, 439 insertions, 194 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index d0ca36515..fc4433a47 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -9,33 +9,10 @@ higher-level statement-construction, connection-management, execution and result contexts.""" from sqlalchemy import exceptions, sql, schema, util, types, logging -import StringIO, sys, re +import StringIO, sys, re, random -class ConnectionProvider(object): - """Define an interface that returns raw Connection objects (or compatible).""" - - def get_connection(self): - """Return a Connection or compatible object from a DBAPI which also contains a close() method. - - It is not defined what context this connection belongs to. It - may be newly connected, returned from a pool, part of some - other kind of context such as thread-local, or can be a fixed - member of this object. - """ - - raise NotImplementedError() - - def dispose(self): - """Release all resources corresponding to this ConnectionProvider. - - This includes any underlying connection pools. - """ - - raise NotImplementedError() - - -class Dialect(sql.AbstractDialect): +class Dialect(object): """Define the behavior of a specific database/DBAPI. Any aspect of metadata definition, SQL query generation, execution, @@ -70,11 +47,14 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() - def convert_compiled_params(self, parameters): - """Build DBAPI execute arguments from a [sqlalchemy.sql#ClauseParameters] instance. - - Returns an array or dictionary suitable to pass directly to this ``Dialect`` instance's DBAPI's - execute method. + def dbapi_type_map(self): + """return a mapping of DBAPI type objects present in this Dialect's DBAPI + mapped to TypeEngine implementations used by the dialect. + + This is used to apply types to result sets based on the DBAPI types + present in cursor.description; it only takes effect for result sets against + textual statements where no explicit typemap was present. Constructed SQL statements + always have type information explicitly embedded. """ raise NotImplementedError() @@ -149,11 +129,11 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() - def defaultrunner(self, connection, **kwargs): + def defaultrunner(self, execution_context): """Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults. - connection - a [sqlalchemy.engine#Connection] to use for statement execution + execution_context + a [sqlalchemy.engine#ExecutionContext] to use for statement execution """ @@ -168,11 +148,12 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns=None): """Load table description from the database. Given a [sqlalchemy.engine#Connection] and a [sqlalchemy.schema#Table] object, reflect its - columns and properties from the database. + columns and properties from the database. If include_columns (a list or set) is specified, limit the autoload + to the given column names. """ raise NotImplementedError() @@ -222,6 +203,46 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() + def do_savepoint(self, connection, name): + """Create a savepoint with the given name on a SQLAlchemy connection.""" + + raise NotImplementedError() + + def do_rollback_to_savepoint(self, connection, name): + """Rollback a SQL Alchemy connection to the named savepoint.""" + + raise NotImplementedError() + + def do_release_savepoint(self, connection, name): + """Release the named savepoint on a SQL Alchemy connection.""" + + raise NotImplementedError() + + def do_begin_twophase(self, connection, xid): + """Begin a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_prepare_twophase(self, connection, xid): + """Prepare a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + """Rollback a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + """Commit a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_recover_twophase(self, connection): + """Recover list of uncommited prepared two phase transaction identifiers on the given connection.""" + + raise NotImplementedError() + def do_executemany(self, cursor, statement, parameters): """Provide an implementation of *cursor.executemany(statement, parameters)*.""" @@ -266,19 +287,18 @@ class ExecutionContext(object): compiled if passed to constructor, sql.Compiled object being executed - compiled_parameters - if passed to constructor, sql.ClauseParameters object - statement string version of the statement to be executed. Is either passed to the constructor, or must be created from the sql.Compiled object by the time pre_exec() has completed. parameters - "raw" parameters suitable for direct execution by the - dialect. Either passed to the constructor, or must be - created from the sql.ClauseParameters object by the time - pre_exec() has completed. + bind parameters passed to the execute() method. for + compiled statements, this is a dictionary or list + of dictionaries. for textual statements, it should + be in a format suitable for the dialect's paramstyle + (i.e. dict or list of dicts for non positional, + list or list of lists/tuples for positional). The Dialect should provide an ExecutionContext via the @@ -288,24 +308,28 @@ class ExecutionContext(object): """ def create_cursor(self): - """Return a new cursor generated this ExecutionContext's connection.""" + """Return a new cursor generated from this ExecutionContext's connection. + + Some dialects may wish to change the behavior of connection.cursor(), + such as postgres which may return a PG "server side" cursor. + """ raise NotImplementedError() - def pre_exec(self): + def pre_execution(self): """Called before an execution of a compiled statement. - If compiled and compiled_parameters were passed to this + If a compiled statement was passed to this ExecutionContext, the `statement` and `parameters` datamembers must be initialized after this statement is complete. """ raise NotImplementedError() - def post_exec(self): + def post_execution(self): """Called after the execution of a compiled statement. - If compiled was passed to this ExecutionContext, + If a compiled statement was passed to this ExecutionContext, the `last_insert_ids`, `last_inserted_params`, etc. datamembers should be available after this method completes. @@ -313,8 +337,11 @@ class ExecutionContext(object): raise NotImplementedError() - def get_result_proxy(self): - """return a ResultProxy corresponding to this ExecutionContext.""" + def result(self): + """return a result object corresponding to this ExecutionContext. + + Returns a ResultProxy.""" + raise NotImplementedError() def get_rowcount(self): @@ -361,8 +388,88 @@ class ExecutionContext(object): raise NotImplementedError() +class Compiled(object): + """Represent a compiled SQL expression. + + The ``__str__`` method of the ``Compiled`` object should produce + the actual text of the statement. ``Compiled`` objects are + specific to their underlying database dialect, and also may + or may not be specific to the columns referenced within a + particular set of bind parameters. In no case should the + ``Compiled`` object be dependent on the actual values of those + bind parameters, even though it may reference those values as + defaults. + """ + + def __init__(self, dialect, statement, parameters, bind=None): + """Construct a new ``Compiled`` object. + + statement + ``ClauseElement`` to be compiled. + + parameters + Optional dictionary indicating a set of bind parameters + specified with this ``Compiled`` object. These parameters + are the *default* values corresponding to the + ``ClauseElement``'s ``_BindParamClauses`` when the + ``Compiled`` is executed. In the case of an ``INSERT`` or + ``UPDATE`` statement, these parameters will also result in + the creation of new ``_BindParamClause`` objects for each + key and will also affect the generated column list in an + ``INSERT`` statement and the ``SET`` clauses of an + ``UPDATE`` statement. The keys of the parameter dictionary + can either be the string names of columns or + ``_ColumnClause`` objects. + + bind + Optional Engine or Connection to compile this statement against. + """ + self.dialect = dialect + self.statement = statement + self.parameters = parameters + self.bind = bind + self.can_execute = statement.supports_execution() + + def compile(self): + """Produce the internal string representation of this element.""" + + raise NotImplementedError() + + def __str__(self): + """Return the string text of the generated SQL statement.""" + + raise NotImplementedError() + + def get_params(self, **params): + """Deprecated. use construct_params(). (supports unicode names) + """ + + return self.construct_params(params) + + def construct_params(self, params): + """Return the bind params for this compiled object. + + params is a dict of string/object pairs whos + values will override bind values compiled in + to the statement. + """ + raise NotImplementedError() + + def execute(self, *multiparams, **params): + """Execute this compiled object.""" + + e = self.bind + if e is None: + raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.") + return e._execute_compiled(self, multiparams, params) + + def scalar(self, *multiparams, **params): + """Execute this compiled object and return the result's scalar value.""" + + return self.execute(*multiparams, **params).scalar() + -class Connectable(sql.Executor): +class Connectable(object): """Interface for an object that can provide an Engine and a Connection object which correponds to that Engine.""" def contextual_connect(self): @@ -401,6 +508,7 @@ class Connection(Connectable): self.__connection = connection or engine.raw_connection() self.__transaction = None self.__close_with_result = close_with_result + self.__savepoint_seq = 0 def _get_connection(self): try: @@ -408,13 +516,18 @@ class Connection(Connectable): except AttributeError: raise exceptions.InvalidRequestError("This Connection is closed") + def _branch(self): + """return a new Connection which references this Connection's + engine and connection; but does not have close_with_result enabled.""" + + return Connection(self.__engine, self.__connection) + engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.") dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.") connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.") should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.") - - def _create_transaction(self, parent): - return Transaction(self, parent) + properties = property(lambda s: s._get_connection().properties, + doc="A set of per-DBAPI connection properties.") def connect(self): """connect() is implemented to return self so that an incoming Engine or Connection object can be treated similarly.""" @@ -448,12 +561,34 @@ class Connection(Connectable): self.__connection.detach() - def begin(self): + def begin(self, nested=False): if self.__transaction is None: - self.__transaction = self._create_transaction(None) - return self.__transaction + self.__transaction = RootTransaction(self) + elif nested: + self.__transaction = NestedTransaction(self, self.__transaction) else: - return self._create_transaction(self.__transaction) + return Transaction(self, self.__transaction) + return self.__transaction + + def begin_nested(self): + return self.begin(nested=True) + + def begin_twophase(self, xid=None): + if self.__transaction is not None: + raise exceptions.InvalidRequestError("Cannot start a two phase transaction when a transaction is already started.") + if xid is None: + xid = "_sa_%032x" % random.randint(0,2**128) + self.__transaction = TwoPhaseTransaction(self, xid) + return self.__transaction + + def recover_twophase(self): + return self.__engine.dialect.do_recover_twophase(self) + + def rollback_prepared(self, xid, recover=False): + self.__engine.dialect.do_rollback_twophase(self, xid, recover=recover) + + def commit_prepared(self, xid, recover=False): + self.__engine.dialect.do_commit_twophase(self, xid, recover=recover) def in_transaction(self): return self.__transaction is not None @@ -485,6 +620,45 @@ class Connection(Connectable): raise exceptions.SQLError(None, None, e) self.__transaction = None + def _savepoint_impl(self, name=None): + if name is None: + self.__savepoint_seq += 1 + name = '__sa_savepoint_%s' % self.__savepoint_seq + if self.__connection.is_valid: + self.__engine.dialect.do_savepoint(self, name) + return name + + def _rollback_to_savepoint_impl(self, name, context): + if self.__connection.is_valid: + self.__engine.dialect.do_rollback_to_savepoint(self, name) + self.__transaction = context + + def _release_savepoint_impl(self, name, context): + if self.__connection.is_valid: + self.__engine.dialect.do_release_savepoint(self, name) + self.__transaction = context + + def _begin_twophase_impl(self, xid): + if self.__connection.is_valid: + self.__engine.dialect.do_begin_twophase(self, xid) + + def _prepare_twophase_impl(self, xid): + if self.__connection.is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + self.__engine.dialect.do_prepare_twophase(self, xid) + + def _rollback_twophase_impl(self, xid, is_prepared): + if self.__connection.is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + self.__engine.dialect.do_rollback_twophase(self, xid, is_prepared) + self.__transaction = None + + def _commit_twophase_impl(self, xid, is_prepared): + if self.__connection.is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + self.__engine.dialect.do_commit_twophase(self, xid, is_prepared) + self.__transaction = None + def _autocommit(self, statement): """When no Transaction is present, this is called after executions to provide "autocommit" behavior.""" # TODO: have the dialect determine if autocommit can be set on the connection directly without this @@ -495,7 +669,7 @@ class Connection(Connectable): def _autorollback(self): if not self.in_transaction(): self._rollback_impl() - + def close(self): try: c = self.__connection @@ -514,74 +688,66 @@ class Connection(Connectable): def execute(self, object, *multiparams, **params): for c in type(object).__mro__: if c in Connection.executors: - return Connection.executors[c](self, object, *multiparams, **params) + return Connection.executors[c](self, object, multiparams, params) else: raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object))) - def execute_default(self, default, **kwargs): - return default.accept_visitor(self.__engine.dialect.defaultrunner(self)) + def _execute_default(self, default, multiparams=None, params=None): + return self.__engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default) + + def _execute_text(self, statement, multiparams, params): + parameters = self.__distill_params(multiparams, params) + context = self.__create_execution_context(statement=statement, parameters=parameters) + self.__execute_raw(context) + return context.result() - def execute_text(self, statement, *multiparams, **params): - if len(multiparams) == 0: + def __distill_params(self, multiparams, params): + if multiparams is None or len(multiparams) == 0: parameters = params or None - elif len(multiparams) == 1 and (isinstance(multiparams[0], list) or isinstance(multiparams[0], tuple) or isinstance(multiparams[0], dict)): + elif len(multiparams) == 1 and isinstance(multiparams[0], (list, tuple, dict)): parameters = multiparams[0] else: parameters = list(multiparams) - context = self._create_execution_context(statement=statement, parameters=parameters) - self._execute_raw(context) - return context.get_result_proxy() - - def _params_to_listofdicts(self, *multiparams, **params): - if len(multiparams) == 0: - return [params] - elif len(multiparams) == 1: - if multiparams[0] == None: - return [{}] - elif isinstance (multiparams[0], list) or isinstance (multiparams[0], tuple): - return multiparams[0] - else: - return [multiparams[0]] - else: - return multiparams - - def execute_function(self, func, *multiparams, **params): - return self.execute_clauseelement(func.select(), *multiparams, **params) + return parameters + + def _execute_function(self, func, multiparams, params): + return self._execute_clauseelement(func.select(), multiparams, params) - def execute_clauseelement(self, elem, *multiparams, **params): - executemany = len(multiparams) > 0 + def _execute_clauseelement(self, elem, multiparams=None, params=None): + executemany = multiparams is not None and len(multiparams) > 0 if executemany: param = multiparams[0] else: param = params - return self.execute_compiled(elem.compile(dialect=self.dialect, parameters=param), *multiparams, **params) + return self._execute_compiled(elem.compile(dialect=self.dialect, parameters=param), multiparams, params) - def execute_compiled(self, compiled, *multiparams, **params): + def _execute_compiled(self, compiled, multiparams=None, params=None): """Execute a sql.Compiled object.""" if not compiled.can_execute: raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled))) - parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)] - if len(parameters) == 1: - parameters = parameters[0] - context = self._create_execution_context(compiled=compiled, compiled_parameters=parameters) - context.pre_exec() - self._execute_raw(context) - context.post_exec() - return context.get_result_proxy() - - def _create_execution_context(self, **kwargs): + + params = self.__distill_params(multiparams, params) + context = self.__create_execution_context(compiled=compiled, parameters=params) + + context.pre_execution() + self.__execute_raw(context) + context.post_execution() + return context.result() + + def __create_execution_context(self, **kwargs): return self.__engine.dialect.create_execution_context(connection=self, **kwargs) - def _execute_raw(self, context): - self.__engine.logger.info(context.statement) - self.__engine.logger.info(repr(context.parameters)) - if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and (isinstance(context.parameters[0], list) or isinstance(context.parameters[0], tuple) or isinstance(context.parameters[0], dict)): - self._executemany(context) + def __execute_raw(self, context): + if logging.is_info_enabled(self.__engine.logger): + self.__engine.logger.info(context.statement) + self.__engine.logger.info(repr(context.parameters)) + if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and isinstance(context.parameters[0], (list, tuple, dict)): + self.__executemany(context) else: - self._execute(context) + self.__execute(context) self._autocommit(context.statement) - def _execute(self, context): + def __execute(self, context): if context.parameters is None: if context.dialect.positional: context.parameters = () @@ -592,19 +758,19 @@ class Connection(Connectable): except Exception, e: if self.dialect.is_disconnect(e): self.__connection.invalidate(e=e) - self.engine.connection_provider.dispose() + self.engine.dispose() self._autorollback() if self.__close_with_result: self.close() raise exceptions.SQLError(context.statement, context.parameters, e) - def _executemany(self, context): + def __executemany(self, context): try: context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context) except Exception, e: if self.dialect.is_disconnect(e): self.__connection.invalidate(e=e) - self.engine.connection_provider.dispose() + self.engine.dispose() self._autorollback() if self.__close_with_result: self.close() @@ -612,11 +778,11 @@ class Connection(Connectable): # poor man's multimethod/generic function thingy executors = { - sql._Function : execute_function, - sql.ClauseElement : execute_clauseelement, - sql.ClauseVisitor : execute_compiled, - schema.SchemaItem:execute_default, - str.__mro__[-2] : execute_text + sql._Function : _execute_function, + sql.ClauseElement : _execute_clauseelement, + sql.ClauseVisitor : _execute_compiled, + schema.SchemaItem:_execute_default, + str.__mro__[-2] : _execute_text } def create(self, entity, **kwargs): @@ -629,10 +795,10 @@ class Connection(Connectable): return self.__engine.drop(entity, connection=self, **kwargs) - def reflecttable(self, table, **kwargs): + def reflecttable(self, table, include_columns=None): """Reflect the columns in the given string table name from the database.""" - return self.__engine.reflecttable(table, connection=self, **kwargs) + return self.__engine.reflecttable(table, self, include_columns) def default_schema_name(self): return self.__engine.dialect.get_default_schema_name(self) @@ -647,39 +813,90 @@ class Transaction(object): """ def __init__(self, connection, parent): - self.__connection = connection - self.__parent = parent or self - self.__is_active = True - if self.__parent is self: - self.__connection._begin_impl() + self._connection = connection + self._parent = parent or self + self._is_active = True - connection = property(lambda s:s.__connection, doc="The Connection object referenced by this Transaction") - is_active = property(lambda s:s.__is_active) + connection = property(lambda s:s._connection, doc="The Connection object referenced by this Transaction") + is_active = property(lambda s:s._is_active) def rollback(self): - if not self.__parent.__is_active: + if not self._parent._is_active: return - if self.__parent is self: - self.__connection._rollback_impl() - self.__is_active = False - else: - self.__parent.rollback() + self._is_active = False + self._do_rollback() + + def _do_rollback(self): + self._parent.rollback() def commit(self): - if not self.__parent.__is_active: + if not self._parent._is_active: raise exceptions.InvalidRequestError("This transaction is inactive") - if self.__parent is self: - self.__connection._commit_impl() - self.__is_active = False + self._is_active = False + self._do_commit() + + def _do_commit(self): + pass + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + if type is None and self._is_active: + self.commit() + else: + self.rollback() + +class RootTransaction(Transaction): + def __init__(self, connection): + super(RootTransaction, self).__init__(connection, None) + self._connection._begin_impl() + + def _do_rollback(self): + self._connection._rollback_impl() + + def _do_commit(self): + self._connection._commit_impl() + +class NestedTransaction(Transaction): + def __init__(self, connection, parent): + super(NestedTransaction, self).__init__(connection, parent) + self._savepoint = self._connection._savepoint_impl() + + def _do_rollback(self): + self._connection._rollback_to_savepoint_impl(self._savepoint, self._parent) + + def _do_commit(self): + self._connection._release_savepoint_impl(self._savepoint, self._parent) + +class TwoPhaseTransaction(Transaction): + def __init__(self, connection, xid): + super(TwoPhaseTransaction, self).__init__(connection, None) + self._is_prepared = False + self.xid = xid + self._connection._begin_twophase_impl(self.xid) + + def prepare(self): + if not self._parent._is_active: + raise exceptions.InvalidRequestError("This transaction is inactive") + self._connection._prepare_twophase_impl(self.xid) + self._is_prepared = True + + def _do_rollback(self): + self._connection._rollback_twophase_impl(self.xid, self._is_prepared) + + def commit(self): + self._connection._commit_twophase_impl(self.xid, self._is_prepared) class Engine(Connectable): """ - Connects a ConnectionProvider, a Dialect and a CompilerFactory together to + Connects a Pool, a Dialect and a CompilerFactory together to provide a default implementation of SchemaEngine. """ - def __init__(self, connection_provider, dialect, echo=None): - self.connection_provider = connection_provider + def __init__(self, pool, dialect, url, echo=None): + self.pool = pool + self.url = url self._dialect=dialect self.echo = echo self.logger = logging.instance_logger(self) @@ -688,10 +905,13 @@ class Engine(Connectable): engine = property(lambda s:s) dialect = property(lambda s:s._dialect, doc="the [sqlalchemy.engine#Dialect] in use by this engine.") echo = logging.echo_property() - url = property(lambda s:s.connection_provider.url, doc="The [sqlalchemy.engine.url#URL] object representing this ``Engine`` object's datasource.") + + def __repr__(self): + return 'Engine(%s)' % str(self.url) def dispose(self): - self.connection_provider.dispose() + self.pool.dispose() + self.pool = self.pool.recreate() def create(self, entity, connection=None, **kwargs): """Create a table or index within this engine's database connection given a schema.Table object.""" @@ -703,22 +923,22 @@ class Engine(Connectable): self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs) - def execute_default(self, default, **kwargs): + def _execute_default(self, default): connection = self.contextual_connect() try: - return connection.execute_default(default, **kwargs) + return connection._execute_default(default) finally: connection.close() def _func(self): - return sql._FunctionGenerator(engine=self) + return sql._FunctionGenerator(bind=self) func = property(_func) def text(self, text, *args, **kwargs): """Return a sql.text() object for performing literal queries.""" - return sql.text(text, engine=self, *args, **kwargs) + return sql.text(text, bind=self, *args, **kwargs) def _run_visitor(self, visitorcallable, element, connection=None, **kwargs): if connection is None: @@ -726,7 +946,7 @@ class Engine(Connectable): else: conn = connection try: - element.accept_visitor(visitorcallable(conn, **kwargs)) + visitorcallable(conn, **kwargs).traverse(element) finally: if connection is None: conn.close() @@ -775,12 +995,12 @@ class Engine(Connectable): def scalar(self, statement, *multiparams, **params): return self.execute(statement, *multiparams, **params).scalar() - def execute_compiled(self, compiled, *multiparams, **params): + def _execute_compiled(self, compiled, multiparams, params): connection = self.contextual_connect(close_with_result=True) - return connection.execute_compiled(compiled, *multiparams, **params) + return connection._execute_compiled(compiled, multiparams, params) def compiler(self, statement, parameters, **kwargs): - return self.dialect.compiler(statement, parameters, engine=self, **kwargs) + return self.dialect.compiler(statement, parameters, bind=self, **kwargs) def connect(self, **kwargs): """Return a newly allocated Connection object.""" @@ -795,7 +1015,7 @@ class Engine(Connectable): return Connection(self, close_with_result=close_with_result, **kwargs) - def reflecttable(self, table, connection=None): + def reflecttable(self, table, connection=None, include_columns=None): """Given a Table object, reflects its columns and properties from the database.""" if connection is None: @@ -803,7 +1023,7 @@ class Engine(Connectable): else: conn = connection try: - self.dialect.reflecttable(conn, table) + self.dialect.reflecttable(conn, table, include_columns) finally: if connection is None: conn.close() @@ -814,7 +1034,7 @@ class Engine(Connectable): def raw_connection(self): """Return a DBAPI connection.""" - return self.connection_provider.get_connection() + return self.pool.connect() def log(self, msg): """Log a message using this SQLEngine's logger stream.""" @@ -858,28 +1078,42 @@ class ResultProxy(object): self.closed = False self.cursor = context.cursor self.__echo = logging.is_debug_enabled(context.engine.logger) - self._init_metadata() - - rowcount = property(lambda s:s.context.get_rowcount()) - connection = property(lambda s:s.context.connection) + if context.is_select(): + self._init_metadata() + self._rowcount = None + else: + self._rowcount = context.get_rowcount() + self.close() + + connection = property(lambda self:self.context.connection) + def _get_rowcount(self): + if self._rowcount is not None: + return self._rowcount + else: + return self.context.get_rowcount() + rowcount = property(_get_rowcount) lastrowid = property(lambda s:s.cursor.lastrowid) + out_parameters = property(lambda s:s.context.out_parameters) def _init_metadata(self): if hasattr(self, '_ResultProxy__props'): return - self.__key_cache = {} self.__props = {} + self._key_cache = self._create_key_cache() self.__keys = [] metadata = self.cursor.description if metadata is not None: + typemap = self.dialect.dbapi_type_map() + for i, item in enumerate(metadata): # sqlite possibly prepending table name to colnames so strip - colname = item[0].split('.')[-1] + colname = self.dialect.decode_result_columnname(item[0].split('.')[-1]) if self.context.typemap is not None: - type = self.context.typemap.get(colname.lower(), types.NULLTYPE) + type = self.context.typemap.get(colname.lower(), typemap.get(item[1], types.NULLTYPE)) else: - type = types.NULLTYPE + type = typemap.get(item[1], types.NULLTYPE) + rec = (type, type.dialect_impl(self.dialect), i) if rec[0] is None: @@ -889,6 +1123,33 @@ class ResultProxy(object): self.__keys.append(colname) self.__props[i] = rec + if self.__echo: + self.context.engine.logger.debug("Col " + repr(tuple([x[0] for x in metadata]))) + + def _create_key_cache(self): + # local copies to avoid circular ref against 'self' + props = self.__props + context = self.context + def lookup_key(key): + """Given a key, which could be a ColumnElement, string, etc., + matches it to the appropriate key we got from the result set's + metadata; then cache it locally for quick re-access.""" + + if isinstance(key, int) and key in props: + rec = props[key] + elif isinstance(key, basestring) and key.lower() in props: + rec = props[key.lower()] + elif isinstance(key, sql.ColumnElement): + label = context.column_labels.get(key._label, key.name).lower() + if label in props: + rec = props[label] + + if not "rec" in locals(): + raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key))) + + return rec + return util.PopulateDict(lookup_key) + def close(self): """Close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution. @@ -904,38 +1165,12 @@ class ResultProxy(object): self.cursor.close() if self.connection.should_close_with_result: self.connection.close() - - def _convert_key(self, key): - """Convert and cache a key. - - Given a key, which could be a ColumnElement, string, etc., - matches it to the appropriate key we got from the result set's - metadata; then cache it locally for quick re-access. - """ - - if key in self.__key_cache: - return self.__key_cache[key] - else: - if isinstance(key, int) and key in self.__props: - rec = self.__props[key] - elif isinstance(key, basestring) and key.lower() in self.__props: - rec = self.__props[key.lower()] - elif isinstance(key, sql.ColumnElement): - label = self.context.column_labels.get(key._label, key.name).lower() - if label in self.__props: - rec = self.__props[label] - - if not "rec" in locals(): - raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key))) - - self.__key_cache[key] = rec - return rec keys = property(lambda s:s.__keys) def _has_key(self, row, key): try: - self._convert_key(key) + self._key_cache[key] return True except KeyError: return False @@ -989,7 +1224,7 @@ class ResultProxy(object): return self.context.supports_sane_rowcount() def _get_col(self, row, key): - rec = self._convert_key(key) + rec = self._key_cache[key] return rec[1].convert_result_value(row[rec[2]], self.dialect) def _fetchone_impl(self): @@ -1101,7 +1336,7 @@ class BufferedColumnResultProxy(ResultProxy): """ def _get_col(self, row, key): - rec = self._convert_key(key) + rec = self._key_cache[key] return row[rec[2]] def _process_row(self, row): @@ -1152,6 +1387,9 @@ class RowProxy(object): self.__parent.close() + def __contains__(self, key): + return self.__parent._has_key(self.__row, key) + def __iter__(self): for i in range(0, len(self.__row)): yield self.__parent._get_col(self.__row, i) @@ -1168,7 +1406,11 @@ class RowProxy(object): return self.__parent._has_key(self.__row, key) def __getitem__(self, key): - return self.__parent._get_col(self.__row, key) + if isinstance(key, slice): + indices = key.indices(len(self)) + return tuple([self.__parent._get_col(self.__row, i) for i in range(*indices)]) + else: + return self.__parent._get_col(self.__row, key) def __getattr__(self, name): try: @@ -1226,19 +1468,22 @@ class DefaultRunner(schema.SchemaVisitor): DefaultRunner to allow database-specific behavior. """ - def __init__(self, connection): - self.connection = connection - self.dialect = connection.dialect + def __init__(self, context): + self.context = context + # branch the connection so it doesnt close after result + self.connection = context.connection._branch() + dialect = property(lambda self:self.context.dialect) + def get_column_default(self, column): if column.default is not None: - return column.default.accept_visitor(self) + return self.traverse_single(column.default) else: return None def get_column_onupdate(self, column): if column.onupdate is not None: - return column.onupdate.accept_visitor(self) + return self.traverse_single(column.onupdate) else: return None @@ -1260,14 +1505,14 @@ class DefaultRunner(schema.SchemaVisitor): return None def exec_default_sql(self, default): - c = sql.select([default.arg]).compile(engine=self.connection) - return self.connection.execute_compiled(c).scalar() + c = sql.select([default.arg]).compile(bind=self.connection) + return self.connection._execute_compiled(c).scalar() def visit_column_onupdate(self, onupdate): if isinstance(onupdate.arg, sql.ClauseElement): return self.exec_default_sql(onupdate) elif callable(onupdate.arg): - return onupdate.arg() + return onupdate.arg(self.context) else: return onupdate.arg @@ -1275,6 +1520,6 @@ class DefaultRunner(schema.SchemaVisitor): if isinstance(default.arg, sql.ClauseElement): return self.exec_default_sql(default) elif callable(default.arg): - return default.arg() + return default.arg(self.context) else: return default.arg |