summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2009-08-08 17:38:45 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2009-08-08 17:38:45 +0000
commitcbdccb7fd26da432ddf43ae1820656505acad37e (patch)
tree4c56e988526ef918190101d93fdc85c386313271 /lib/sqlalchemy/sql/compiler.py
parent3dc86785298c6144e832fd20dba4e372868ccc8a (diff)
downloadsqlalchemy-cbdccb7fd26da432ddf43ae1820656505acad37e.tar.gz
clean up the way we detect MSSQL's form of RETURNING
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py32
1 files changed, 22 insertions, 10 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index a47922cc5..d6187bcde 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -161,8 +161,17 @@ class SQLCompiler(engine.Compiled):
# level to define if this Compiled instance represents
# INSERT/UPDATE/DELETE
isdelete = isinsert = isupdate = False
+
+ # holds the "returning" collection of columns if
+ # the statement is CRUD and defines returning columns
+ # either implicitly or explicitly
returning = None
+ # set to True classwide to generate RETURNING
+ # clauses before the VALUES or WHERE clause (i.e. MSSQL)
+ returning_precedes_values = False
+
+
def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
@@ -699,10 +708,8 @@ class SQLCompiler(engine.Compiled):
self.returning = self.returning or insert_stmt._returning
returning_clause = self.returning_clause(insert_stmt, self.returning)
- # cheating
- if returning_clause.startswith("OUTPUT"):
+ if self.returning_precedes_values:
text += " " + returning_clause
- returning_clause = None
if not colparams and supports_default_values:
text += " DEFAULT VALUES"
@@ -710,7 +717,7 @@ class SQLCompiler(engine.Compiled):
text += " VALUES (%s)" % \
', '.join([c[1] for c in colparams])
- if self.returning and returning_clause:
+ if self.returning and not self.returning_precedes_values:
text += " " + returning_clause
return text
@@ -732,14 +739,14 @@ class SQLCompiler(engine.Compiled):
if update_stmt._returning:
self.returning = update_stmt._returning
returning_clause = self.returning_clause(update_stmt, update_stmt._returning)
- if returning_clause.startswith("OUTPUT"):
+
+ if self.returning_precedes_values:
text += " " + returning_clause
- returning_clause = None
if update_stmt._whereclause:
text += " WHERE " + self.process(update_stmt._whereclause)
- if self.returning and returning_clause:
+ if self.returning and not self.returning_precedes_values:
text += " " + returning_clause
self.stack.pop(-1)
@@ -755,6 +762,11 @@ class SQLCompiler(engine.Compiled):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
+ Also generates the Compiled object's postfetch, prefetch, and returning
+ column collections, used for default handling and ultimately
+ populating the ResultProxy's prefetch_cols() and postfetch_cols()
+ collections.
+
"""
self.postfetch = []
@@ -880,14 +892,14 @@ class SQLCompiler(engine.Compiled):
if delete_stmt._returning:
self.returning = delete_stmt._returning
returning_clause = self.returning_clause(delete_stmt, delete_stmt._returning)
- if returning_clause.startswith("OUTPUT"):
+
+ if self.returning_precedes_values:
text += " " + returning_clause
- returning_clause = None
if delete_stmt._whereclause:
text += " WHERE " + self.process(delete_stmt._whereclause)
- if self.returning and returning_clause:
+ if self.returning and not self.returning_precedes_values:
text += " " + returning_clause
self.stack.pop(-1)