summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/crud.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r--lib/sqlalchemy/sql/crud.py47
1 files changed, 42 insertions, 5 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index a01b72e61..58cd80995 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -25,6 +25,41 @@ values present.
""")
+ISINSERT = util.symbol('ISINSERT')
+ISUPDATE = util.symbol('ISUPDATE')
+ISDELETE = util.symbol('ISDELETE')
+
+
+def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
+ restore_isinsert = compiler.isinsert
+ restore_isupdate = compiler.isupdate
+ restore_isdelete = compiler.isdelete
+
+ should_restore = (
+ restore_isinsert or restore_isupdate or restore_isdelete
+ ) or len(compiler.stack) > 1
+
+ if local_stmt_type is ISINSERT:
+ compiler.isupdate = False
+ compiler.isinsert = True
+ elif local_stmt_type is ISUPDATE:
+ compiler.isupdate = True
+ compiler.isinsert = False
+ elif local_stmt_type is ISDELETE:
+ if not should_restore:
+ compiler.isdelete = True
+ else:
+ assert False, "ISINSERT, ISUPDATE, or ISDELETE expected"
+
+ try:
+ if local_stmt_type in (ISINSERT, ISUPDATE):
+ return _get_crud_params(compiler, stmt, **kw)
+ finally:
+ if should_restore:
+ compiler.isinsert = restore_isinsert
+ compiler.isupdate = restore_isupdate
+ compiler.isdelete = restore_isdelete
+
def _get_crud_params(compiler, stmt, **kw):
"""create a set of tuples representing column/string pairs for use
@@ -59,7 +94,7 @@ def _get_crud_params(compiler, stmt, **kw):
# but in the case of mysql multi-table update, the rules for
# .key must conditionally take tablename into account
_column_as_key, _getattr_col_key, _col_bind_name = \
- _key_getters_for_crud_column(compiler)
+ _key_getters_for_crud_column(compiler, stmt)
# if we have statement parameters - set defaults in the
# compiled params
@@ -128,15 +163,15 @@ def _create_bind_param(
return bindparam
-def _key_getters_for_crud_column(compiler):
- if compiler.isupdate and compiler.statement._extra_froms:
+def _key_getters_for_crud_column(compiler, stmt):
+ if compiler.isupdate and stmt._extra_froms:
# when extra tables are present, refer to the columns
# in those extra tables as table-qualified, including in
# dictionaries and when rendering bind param names.
# the "main" table of the statement remains unqualified,
# allowing the most compatibility with a non-multi-table
# statement.
- _et = set(compiler.statement._extra_froms)
+ _et = set(stmt._extra_froms)
def _column_as_key(key):
str_key = elements._column_as_key(key)
@@ -609,7 +644,9 @@ def _get_returning_modifiers(compiler, stmt):
stmt.table.implicit_returning and
stmt._return_defaults)
else:
- implicit_return_defaults = False
+ # this line is unused, currently we are always
+ # isinsert or isupdate
+ implicit_return_defaults = False # pragma: no cover
if implicit_return_defaults:
if stmt._return_defaults is True: