summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertsql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r--lib/sqlalchemy/testing/assertsql.py22
1 files changed, 14 insertions, 8 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index a6e3c8764..4416fe630 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -67,10 +67,13 @@ class CursorSQL(SQLMatchRule):
class CompiledSQL(SQLMatchRule):
- def __init__(self, statement, params=None, dialect="default"):
+ def __init__(
+ self, statement, params=None, dialect="default", enable_returning=False
+ ):
self.statement = statement
self.params = params
self.dialect = dialect
+ self.enable_returning = enable_returning
def _compare_sql(self, execute_observed, received_statement):
stmt = re.sub(r"[\n\t]", "", self.statement)
@@ -82,14 +85,14 @@ class CompiledSQL(SQLMatchRule):
# this is currently what tests are expecting
# dialect.supports_default_values = True
dialect.supports_default_metavalue = True
+
+ if self.enable_returning:
+ dialect.insert_returning = (
+ dialect.update_returning
+ ) = dialect.delete_returning = True
return dialect
else:
- # ugh
- if self.dialect == "postgresql":
- params = {"implicit_returning": True}
- else:
- params = {}
- return url.URL.create(self.dialect).get_dialect()(**params)
+ return url.URL.create(self.dialect).get_dialect()()
def _received_statement(self, execute_observed):
"""reconstruct the statement and params in terms
@@ -221,12 +224,15 @@ class CompiledSQL(SQLMatchRule):
class RegexSQL(CompiledSQL):
- def __init__(self, regex, params=None, dialect="default"):
+ def __init__(
+ self, regex, params=None, dialect="default", enable_returning=False
+ ):
SQLMatchRule.__init__(self)
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
self.dialect = dialect
+ self.enable_returning = enable_returning
def _failure_message(self, execute_observed, expected_params):
return (