summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/databases/mssql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/databases/mssql.py')
-rw-r--r--lib/sqlalchemy/databases/mssql.py26
1 files changed, 25 insertions, 1 deletions
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index 934a64401..553e07a48 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -291,6 +291,17 @@ class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
super(MSSQLExecutionContext_pyodbc, self).pre_exec()
+ # where appropriate, issue "select scope_identity()" in the same statement
+ if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) and self.dialect.use_scope_identity:
+ self.statement += "; select scope_identity()"
+
+ def post_exec(self):
+ if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) and self.dialect.use_scope_identity:
+ # do nothing - id was fetched in dialect.do_execute()
+ self.HASIDENT = False
+ else:
+ super(MSSQLExecutionContext_pyodbc, self).post_exec()
+
class MSSQLDialect(ansisql.ANSIDialect):
colspecs = {
@@ -709,11 +720,24 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
return [[";".join (connectors)], {}]
def is_disconnect(self, e):
- return isinstance(e, self.dbapi.Error) and '[08S01]' in e.args[1]
+ return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e)
def create_execution_context(self, *args, **kwargs):
return MSSQLExecutionContext_pyodbc(self, *args, **kwargs)
+ def do_execute(self, cursor, statement, parameters, context=None, **kwargs):
+ super(MSSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs)
+ if context and context.HASIDENT and (not context.IINSERT) and context.dialect.use_scope_identity:
+ import pyodbc
+ # fetch the last inserted id from the manipulated statement (pre_exec).
+ try:
+ row = cursor.fetchone()
+ except pyodbc.Error, e:
+ # if nocount OFF fetchone throws an exception and we have to jump over
+ # the rowcount to the resultset
+ cursor.nextset()
+ row = cursor.fetchone()
+ context._last_inserted_ids = [int(row[0])]
class MSSQLDialect_adodbapi(MSSQLDialect):
def import_dbapi(cls):