summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r--lib/sqlalchemy/testing/assertions.py57
1 files changed, 56 insertions, 1 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index af168cd85..17a0acf20 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -20,6 +20,7 @@ from .exclusions import db_spec
from .util import fail
from .. import exc as sa_exc
from .. import schema
+from .. import sql
from .. import types as sqltypes
from .. import util
from ..engine import default
@@ -441,7 +442,61 @@ class AssertsCompiledSQL(object):
if compile_kwargs:
kw["compile_kwargs"] = compile_kwargs
- c = clause.compile(dialect=dialect, **kw)
+ class DontAccess(object):
+ def __getattribute__(self, key):
+ raise NotImplementedError(
+ "compiler accessed .statement; use "
+ "compiler.current_executable"
+ )
+
+ class CheckCompilerAccess(object):
+ def __init__(self, test_statement):
+ self.test_statement = test_statement
+ self._annotations = {}
+ self.supports_execution = getattr(
+ test_statement, "supports_execution", False
+ )
+ if self.supports_execution:
+ self._execution_options = test_statement._execution_options
+
+ if isinstance(
+ test_statement, (sql.Insert, sql.Update, sql.Delete)
+ ):
+ self._returning = test_statement._returning
+ if isinstance(test_statement, (sql.Insert, sql.Update)):
+ self._inline = test_statement._inline
+ self._return_defaults = test_statement._return_defaults
+
+ def _default_dialect(self):
+ return self.test_statement._default_dialect()
+
+ def compile(self, dialect, **kw):
+ return self.test_statement.compile.__func__(
+ self, dialect=dialect, **kw
+ )
+
+ def _compiler(self, dialect, **kw):
+ return self.test_statement._compiler.__func__(
+ self, dialect, **kw
+ )
+
+ def _compiler_dispatch(self, compiler, **kwargs):
+ if hasattr(compiler, "statement"):
+ with mock.patch.object(
+ compiler, "statement", DontAccess()
+ ):
+ return self.test_statement._compiler_dispatch(
+ compiler, **kwargs
+ )
+ else:
+ return self.test_statement._compiler_dispatch(
+ compiler, **kwargs
+ )
+
+ # no construct can assume it's the "top level" construct in all cases
+ # as anything can be nested. ensure constructs don't assume they
+ # are the "self.statement" element
+ c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw)
param_str = repr(getattr(c, "params", {}))
if util.py3k: