summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-10-19 10:19:29 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-10-19 13:13:15 -0400
commit3e49b8d0519aa024842206a2fb664a4ad83796d6 (patch)
tree9eb3cf01a68c42bb77a0bf76e99865d7c6370262
parent296c84313ab29bf9599634f38caaf7dd092e4e23 (diff)
downloadsqlalchemy-3e49b8d0519aa024842206a2fb664a4ad83796d6.tar.gz
Ensure no compiler visit method tries to access .statement
Fixed structural compiler issue where some constructs such as MySQL / PostgreSQL "on conflict / on duplicate key" would rely upon the state of the :class:`_sql.Compiler` object being fixed against their statement as the top level statement, which would fail in cases where those statements are branched from a different context, such as a DDL construct linked to a SQL statement. Fixes: #5656 Change-Id: I568bf40adc7edcf72ea6c7fd6eb9d07790de189e
-rw-r--r--doc/build/changelog/unreleased_13/5656.rst11
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py12
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py2
-rw-r--r--lib/sqlalchemy/orm/persistence.py9
-rw-r--r--lib/sqlalchemy/sql/compiler.py43
-rw-r--r--lib/sqlalchemy/testing/assertions.py57
-rw-r--r--test/sql/test_compiler.py48
7 files changed, 167 insertions, 15 deletions
diff --git a/doc/build/changelog/unreleased_13/5656.rst b/doc/build/changelog/unreleased_13/5656.rst
new file mode 100644
index 000000000..cdec60842
--- /dev/null
+++ b/doc/build/changelog/unreleased_13/5656.rst
@@ -0,0 +1,11 @@
+.. change::
+ :tags: bug, sql
+ :tickets: 5656
+
+ Fixed structural compiler issue where some constructs such as MySQL /
+ PostgreSQL "on conflict / on duplicate key" would rely upon the state of
+ the :class:`_sql.Compiler` object being fixed against their statement as
+ the top level statement, which would fail in cases where those statements
+ are branched from a different context, such as a DDL construct linked to a
+ SQL statement.
+
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 3e6c676ab..77f65799c 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1461,6 +1461,8 @@ class MySQLCompiler(compiler.SQLCompiler):
return self._render_json_extract_from_binary(binary, operator, **kw)
def visit_on_duplicate_key_update(self, on_duplicate, **kw):
+ statement = self.current_executable
+
if on_duplicate._parameter_ordering:
parameter_ordering = [
coercions.expect(roles.DMLColumnRole, key)
@@ -1468,14 +1470,12 @@ class MySQLCompiler(compiler.SQLCompiler):
]
ordered_keys = set(parameter_ordering)
cols = [
- self.statement.table.c[key]
+ statement.table.c[key]
for key in parameter_ordering
- if key in self.statement.table.c
- ] + [
- c for c in self.statement.table.c if c.key not in ordered_keys
- ]
+ if key in statement.table.c
+ ] + [c for c in statement.table.c if c.key not in ordered_keys]
else:
- cols = self.statement.table.c
+ cols = statement.table.c
clauses = []
# traverses through all table columns to preserve table column order
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 44272b9d3..ea6921b2d 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -2101,7 +2101,7 @@ class PGCompiler(compiler.SQLCompiler):
"Additional column names not matching "
"any column keys in table '%s': %s"
% (
- self.statement.table.name,
+ self.current_executable.table.name,
(", ".join("'%s'" % c for c in set_parameters)),
)
)
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 022f6611f..4a0b2d07d 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -2171,10 +2171,9 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
# if we are against a lambda statement we might not be the
# topmost object that received per-execute annotations
- top_level_stmt = compiler.statement
+
if (
- top_level_stmt._annotations.get("synchronize_session", None)
- == "fetch"
+ compiler._annotations.get("synchronize_session", None) == "fetch"
and compiler.dialect.full_returning
):
new_stmt = new_stmt.returning(*mapper.primary_key)
@@ -2287,8 +2286,6 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
ext_info = statement.table._annotations["parententity"]
self.mapper = mapper = ext_info.mapper
- top_level_stmt = compiler.statement
-
self.extra_criteria_entities = {}
extra_criteria_attributes = {}
@@ -2305,7 +2302,7 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
if (
mapper
- and top_level_stmt._annotations.get("synchronize_session", None)
+ and compiler._annotations.get("synchronize_session", None)
== "fetch"
and compiler.dialect.full_returning
):
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 8718e15ea..10499975c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -377,12 +377,14 @@ class Compiled(object):
schema_translate_map = None
- execution_options = util.immutabledict()
+ execution_options = util.EMPTY_DICT
"""
Execution options propagated from the statement. In some cases,
sub-elements of the statement can modify these.
"""
+ _annotations = util.EMPTY_DICT
+
compile_state = None
"""Optional :class:`.CompileState` object that maintains additional
state used by the compiler.
@@ -474,6 +476,7 @@ class Compiled(object):
if statement is not None:
self.statement = statement
self.can_execute = statement.supports_execution
+ self._annotations = statement._annotations
if self.can_execute:
self.execution_options = statement._execution_options
self.string = self.process(self.statement, **compile_kwargs)
@@ -799,6 +802,44 @@ class SQLCompiler(Compiled):
self._process_parameters_for_postcompile(_populate_self=True)
@property
+ def current_executable(self):
+ """Return the current 'executable' that is being compiled.
+
+ This is currently the :class:`_sql.Select`, :class:`_sql.Insert`,
+ :class:`_sql.Update`, :class:`_sql.Delete`,
+ :class:`_sql.CompoundSelect` object that is being compiled.
+ Specifically it's assigned to the ``self.stack`` list of elements.
+
+ When a statement like the above is being compiled, it normally
+ is also assigned to the ``.statement`` attribute of the
+ :class:`_sql.Compiler` object. However, all SQL constructs are
+ ultimately nestable, and this attribute should never be consulted
+ by a ``visit_`` method, as it is not guaranteed to be assigned
+ nor guaranteed to correspond to the current statement being compiled.
+
+ .. versionadded:: 1.3.21
+
+ For compatibility with previous versions, use the following
+ recipe::
+
+ statement = getattr(self, "current_executable", False)
+ if statement is False:
+ statement = self.stack[-1]["selectable"]
+
+ For versions 1.4 and above, ensure only .current_executable
+ is used; the format of "self.stack" may change.
+
+
+ """
+ try:
+ return self.stack[-1]["selectable"]
+ except IndexError as ie:
+ util.raise_(
+ IndexError("Compiler does not have a stack entry"),
+ replace_context=ie,
+ )
+
+ @property
def prefetch(self):
return list(self.insert_prefetch + self.update_prefetch)
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index af168cd85..17a0acf20 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -20,6 +20,7 @@ from .exclusions import db_spec
from .util import fail
from .. import exc as sa_exc
from .. import schema
+from .. import sql
from .. import types as sqltypes
from .. import util
from ..engine import default
@@ -441,7 +442,61 @@ class AssertsCompiledSQL(object):
if compile_kwargs:
kw["compile_kwargs"] = compile_kwargs
- c = clause.compile(dialect=dialect, **kw)
+ class DontAccess(object):
+ def __getattribute__(self, key):
+ raise NotImplementedError(
+ "compiler accessed .statement; use "
+ "compiler.current_executable"
+ )
+
+ class CheckCompilerAccess(object):
+ def __init__(self, test_statement):
+ self.test_statement = test_statement
+ self._annotations = {}
+ self.supports_execution = getattr(
+ test_statement, "supports_execution", False
+ )
+ if self.supports_execution:
+ self._execution_options = test_statement._execution_options
+
+ if isinstance(
+ test_statement, (sql.Insert, sql.Update, sql.Delete)
+ ):
+ self._returning = test_statement._returning
+ if isinstance(test_statement, (sql.Insert, sql.Update)):
+ self._inline = test_statement._inline
+ self._return_defaults = test_statement._return_defaults
+
+ def _default_dialect(self):
+ return self.test_statement._default_dialect()
+
+ def compile(self, dialect, **kw):
+ return self.test_statement.compile.__func__(
+ self, dialect=dialect, **kw
+ )
+
+ def _compiler(self, dialect, **kw):
+ return self.test_statement._compiler.__func__(
+ self, dialect, **kw
+ )
+
+ def _compiler_dispatch(self, compiler, **kwargs):
+ if hasattr(compiler, "statement"):
+ with mock.patch.object(
+ compiler, "statement", DontAccess()
+ ):
+ return self.test_statement._compiler_dispatch(
+ compiler, **kwargs
+ )
+ else:
+ return self.test_statement._compiler_dispatch(
+ compiler, **kwargs
+ )
+
+ # no construct can assume it's the "top level" construct in all cases
+ # as anything can be nested. ensure constructs don't assume they
+ # are the "self.statement" element
+ c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw)
param_str = repr(getattr(c, "params", {}))
if util.py3k:
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index ef2f75b2d..c9e1d9ab4 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -77,6 +77,7 @@ from sqlalchemy.sql import operators
from sqlalchemy.sql import table
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql.elements import BooleanClauseList
+from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy.sql.expression import ClauseList
from sqlalchemy.sql.expression import HasPrefixes
from sqlalchemy.testing import assert_raises
@@ -158,6 +159,53 @@ keyed = Table(
)
+class TestCompilerFixture(fixtures.TestBase, AssertsCompiledSQL):
+ def test_dont_access_statement(self):
+ def visit_foobar(self, element, **kw):
+ self.statement.table
+
+ class Foobar(ClauseElement):
+ __visit_name__ = "foobar"
+
+ with mock.patch.object(
+ testing.db.dialect.statement_compiler,
+ "visit_foobar",
+ visit_foobar,
+ create=True,
+ ):
+ assert_raises_message(
+ NotImplementedError,
+ "compiler accessed .statement; use "
+ "compiler.current_executable",
+ self.assert_compile,
+ Foobar(),
+ "",
+ )
+
+ def test_no_stack(self):
+ def visit_foobar(self, element, **kw):
+ self.current_executable.table
+
+ class Foobar(ClauseElement):
+ __visit_name__ = "foobar"
+
+ with mock.patch.object(
+ testing.db.dialect.statement_compiler,
+ "visit_foobar",
+ visit_foobar,
+ create=True,
+ ):
+ compiler = testing.db.dialect.statement_compiler(
+ testing.db.dialect, None
+ )
+ assert_raises_message(
+ IndexError,
+ "Compiler does not have a stack entry",
+ compiler.process,
+ Foobar(),
+ )
+
+
class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"