diff options
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 57 |
1 files changed, 56 insertions, 1 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index af168cd85..17a0acf20 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -20,6 +20,7 @@ from .exclusions import db_spec from .util import fail from .. import exc as sa_exc from .. import schema +from .. import sql from .. import types as sqltypes from .. import util from ..engine import default @@ -441,7 +442,61 @@ class AssertsCompiledSQL(object): if compile_kwargs: kw["compile_kwargs"] = compile_kwargs - c = clause.compile(dialect=dialect, **kw) + class DontAccess(object): + def __getattribute__(self, key): + raise NotImplementedError( + "compiler accessed .statement; use " + "compiler.current_executable" + ) + + class CheckCompilerAccess(object): + def __init__(self, test_statement): + self.test_statement = test_statement + self._annotations = {} + self.supports_execution = getattr( + test_statement, "supports_execution", False + ) + if self.supports_execution: + self._execution_options = test_statement._execution_options + + if isinstance( + test_statement, (sql.Insert, sql.Update, sql.Delete) + ): + self._returning = test_statement._returning + if isinstance(test_statement, (sql.Insert, sql.Update)): + self._inline = test_statement._inline + self._return_defaults = test_statement._return_defaults + + def _default_dialect(self): + return self.test_statement._default_dialect() + + def compile(self, dialect, **kw): + return self.test_statement.compile.__func__( + self, dialect=dialect, **kw + ) + + def _compiler(self, dialect, **kw): + return self.test_statement._compiler.__func__( + self, dialect, **kw + ) + + def _compiler_dispatch(self, compiler, **kwargs): + if hasattr(compiler, "statement"): + with mock.patch.object( + compiler, "statement", DontAccess() + ): + return self.test_statement._compiler_dispatch( + compiler, **kwargs + ) + else: + return self.test_statement._compiler_dispatch( + compiler, **kwargs + ) + + # no construct can assume it's the "top level" construct in all cases + # as anything can be nested. ensure constructs don't assume they + # are the "self.statement" element + c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw) param_str = repr(getattr(c, "params", {})) if util.py3k: |