diff options
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index a6e3c8764..4416fe630 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -67,10 +67,13 @@ class CursorSQL(SQLMatchRule): class CompiledSQL(SQLMatchRule): - def __init__(self, statement, params=None, dialect="default"): + def __init__( + self, statement, params=None, dialect="default", enable_returning=False + ): self.statement = statement self.params = params self.dialect = dialect + self.enable_returning = enable_returning def _compare_sql(self, execute_observed, received_statement): stmt = re.sub(r"[\n\t]", "", self.statement) @@ -82,14 +85,14 @@ class CompiledSQL(SQLMatchRule): # this is currently what tests are expecting # dialect.supports_default_values = True dialect.supports_default_metavalue = True + + if self.enable_returning: + dialect.insert_returning = ( + dialect.update_returning + ) = dialect.delete_returning = True return dialect else: - # ugh - if self.dialect == "postgresql": - params = {"implicit_returning": True} - else: - params = {} - return url.URL.create(self.dialect).get_dialect()(**params) + return url.URL.create(self.dialect).get_dialect()() def _received_statement(self, execute_observed): """reconstruct the statement and params in terms @@ -221,12 +224,15 @@ class CompiledSQL(SQLMatchRule): class RegexSQL(CompiledSQL): - def __init__(self, regex, params=None, dialect="default"): + def __init__( + self, regex, params=None, dialect="default", enable_returning=False + ): SQLMatchRule.__init__(self) self.regex = re.compile(regex) self.orig_regex = regex self.params = params self.dialect = dialect + self.enable_returning = enable_returning def _failure_message(self, execute_observed, expected_params): return ( |