diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-09-01 22:42:51 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-09-01 22:42:51 +0000 |
commit | 8386fc6dc596c74f5cc9981504e274beff8d69cc (patch) | |
tree | 7299ab6dd5cc7703b496064284817a6f130f1d7c | |
parent | e04535a79a7528440960575e3623fa620290e026 (diff) | |
download | sqlalchemy-8386fc6dc596c74f5cc9981504e274beff8d69cc.tar.gz |
sequence pre-executes dont create an ExecutionContext, use straight cursor
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 56 | ||||
-rw-r--r-- | test/sql/constraints.py | 4 |
4 files changed, 41 insertions, 33 deletions
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index fb5b512e2..5b852c185 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -672,12 +672,8 @@ class OracleSchemaDropper(compiler.SchemaDropper): self.execute() class OracleDefaultRunner(base.DefaultRunner): - def exec_default_sql(self, default): - c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection) - return self.connection.execute(c).scalar() - def visit_sequence(self, seq): - return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar() + return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL") class OracleIdentifierPreparer(compiler.IdentifierPreparer): def format_savepoint(self, savepoint): diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 701114b17..f5eaf47f2 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -611,9 +611,9 @@ class PGSchemaDropper(compiler.SchemaDropper): class PGDefaultRunner(base.DefaultRunner): def get_column_default(self, column, isinsert=True): if column.primary_key: - # passive defaults on primary keys have to be overridden + # pre-execute passive defaults on primary keys if isinstance(column.default, schema.PassiveDefault): - return self.connection.execute("select %s" % column.default.arg).scalar() + return self.execute_string("select %s" % column.default.arg) elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): sch = column.table.schema # TODO: this has to build into the Sequence object so we can get the quoting @@ -622,13 +622,13 @@ class PGDefaultRunner(base.DefaultRunner): exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) else: exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) - return self.connection.execute(exc).scalar() + return self.execute_string(exc.encode(self.dialect.encoding)) return super(PGDefaultRunner, self).get_column_default(column) def visit_sequence(self, seq): if not seq.optional: - return self.connection.execute("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).scalar() + return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).encode(self.dialect.encoding)) else: return None diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 1ab05fe03..45d84f90d 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -837,45 +837,50 @@ class Connection(Connectable): return self.__engine.dialect.create_execution_context(connection=self, **kwargs) def __execute_raw(self, context): - if self.__engine._should_log_info: - self.__engine.logger.info(context.statement) - self.__engine.logger.info(repr(context.parameters)) if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and isinstance(context.parameters[0], (list, tuple, dict)): - self.__executemany(context) + self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context) else: - self.__execute(context) + if context.parameters is None: + if context.dialect.positional: + parameters = () + else: + parameters = {} + else: + parameters = context.parameters + self._cursor_execute(context.cursor, context.statement, parameters, context=context) self._autocommit(context) - def __execute(self, context): - if context.parameters is None: - if context.dialect.positional: - context.parameters = () - else: - context.parameters = {} + def _cursor_execute(self, cursor, statement, parameters, context=None): + if self.__engine._should_log_info: + self.__engine.logger.info(statement) + self.__engine.logger.info(repr(parameters)) try: - context.dialect.do_execute(context.cursor, context.statement, context.parameters, context=context) + self.dialect.do_execute(cursor, statement, parameters, context=context) except Exception, e: if self.dialect.is_disconnect(e): self.__connection.invalidate(e=e) self.engine.dispose() - context.cursor.close() + cursor.close() self._autorollback() if self.__close_with_result: self.close() - raise exceptions.DBAPIError.instance(context.statement, context.parameters, e) + raise exceptions.DBAPIError.instance(statement, parameters, e) - def __executemany(self, context): + def _cursor_executemany(self, cursor, statement, parameters, context=None): + if self.__engine._should_log_info: + self.__engine.logger.info(statement) + self.__engine.logger.info(repr(parameters)) try: - context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context) + self.dialect.do_executemany(cursor, statement, parameters, context=context) except Exception, e: if self.dialect.is_disconnect(e): self.__connection.invalidate(e=e) self.engine.dispose() - context.cursor.close() + cursor.close() self._autorollback() if self.__close_with_result: self.close() - raise exceptions.DBAPIError.instance(context.statement, context.parameters, e) + raise exceptions.DBAPIError.instance(statement, parameters, e) # poor man's multimethod/generic function thingy executors = { @@ -1632,7 +1637,6 @@ class DefaultRunner(schema.SchemaVisitor): def __init__(self, context): self.context = context - self.connection = context._connection._branch() self.dialect = context.dialect def get_column_default(self, column): @@ -1665,9 +1669,17 @@ class DefaultRunner(schema.SchemaVisitor): return None def exec_default_sql(self, default): - c = expression.select([default.arg]).compile(bind=self.connection) - return self.connection._execute_compiled(c).scalar() - + conn = self.context.connection + c = expression.select([default.arg]).compile(bind=conn) + return conn._execute_compiled(c).scalar() + + def execute_string(self, stmt, params=None): + """execute a string statement, using the raw cursor, + and return a scalar result.""" + conn = self.context._connection + conn._cursor_execute(self.context.cursor, stmt, params) + return self.context.cursor.fetchone()[0] + def visit_column_onupdate(self, onupdate): if isinstance(onupdate.arg, expression.ClauseElement): return self.exec_default_sql(onupdate) diff --git a/test/sql/constraints.py b/test/sql/constraints.py index 93ba231ab..c5320ada3 100644 --- a/test/sql/constraints.py +++ b/test/sql/constraints.py @@ -174,12 +174,12 @@ class ConstraintTest(AssertMixin): capt = [] connection = testbase.db.connect() # TODO: hacky, put a real connection proxy in - ex = connection._Connection__execute + ex = connection._Connection__execute_raw def proxy(context): capt.append(context.statement) capt.append(repr(context.parameters)) ex(context) - connection._Connection__execute = proxy + connection._Connection__execute_raw = proxy schemagen = testbase.db.dialect.schemagenerator(testbase.db.dialect, connection) schemagen.traverse(events) |