diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-10-19 10:19:29 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-10-19 13:13:15 -0400 |
commit | 3e49b8d0519aa024842206a2fb664a4ad83796d6 (patch) | |
tree | 9eb3cf01a68c42bb77a0bf76e99865d7c6370262 | |
parent | 296c84313ab29bf9599634f38caaf7dd092e4e23 (diff) | |
download | sqlalchemy-3e49b8d0519aa024842206a2fb664a4ad83796d6.tar.gz |
Ensure no compiler visit method tries to access .statement
Fixed structural compiler issue where some constructs such as MySQL /
PostgreSQL "on conflict / on duplicate key" would rely upon the state of
the :class:`_sql.Compiler` object being fixed against their statement as
the top level statement, which would fail in cases where those statements
are branched from a different context, such as a DDL construct linked to a
SQL statement.
Fixes: #5656
Change-Id: I568bf40adc7edcf72ea6c7fd6eb9d07790de189e
-rw-r--r-- | doc/build/changelog/unreleased_13/5656.rst | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 43 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 57 | ||||
-rw-r--r-- | test/sql/test_compiler.py | 48 |
7 files changed, 167 insertions, 15 deletions
diff --git a/doc/build/changelog/unreleased_13/5656.rst b/doc/build/changelog/unreleased_13/5656.rst new file mode 100644 index 000000000..cdec60842 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5656.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, sql + :tickets: 5656 + + Fixed structural compiler issue where some constructs such as MySQL / + PostgreSQL "on conflict / on duplicate key" would rely upon the state of + the :class:`_sql.Compiler` object being fixed against their statement as + the top level statement, which would fail in cases where those statements + are branched from a different context, such as a DDL construct linked to a + SQL statement. + diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 3e6c676ab..77f65799c 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1461,6 +1461,8 @@ class MySQLCompiler(compiler.SQLCompiler): return self._render_json_extract_from_binary(binary, operator, **kw) def visit_on_duplicate_key_update(self, on_duplicate, **kw): + statement = self.current_executable + if on_duplicate._parameter_ordering: parameter_ordering = [ coercions.expect(roles.DMLColumnRole, key) @@ -1468,14 +1470,12 @@ class MySQLCompiler(compiler.SQLCompiler): ] ordered_keys = set(parameter_ordering) cols = [ - self.statement.table.c[key] + statement.table.c[key] for key in parameter_ordering - if key in self.statement.table.c - ] + [ - c for c in self.statement.table.c if c.key not in ordered_keys - ] + if key in statement.table.c + ] + [c for c in statement.table.c if c.key not in ordered_keys] else: - cols = self.statement.table.c + cols = statement.table.c clauses = [] # traverses through all table columns to preserve table column order diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 44272b9d3..ea6921b2d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2101,7 +2101,7 @@ class PGCompiler(compiler.SQLCompiler): "Additional column names not matching " "any column keys in table '%s': %s" % ( - self.statement.table.name, + self.current_executable.table.name, (", ".join("'%s'" % c for c in set_parameters)), ) ) diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 022f6611f..4a0b2d07d 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -2171,10 +2171,9 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): # if we are against a lambda statement we might not be the # topmost object that received per-execute annotations - top_level_stmt = compiler.statement + if ( - top_level_stmt._annotations.get("synchronize_session", None) - == "fetch" + compiler._annotations.get("synchronize_session", None) == "fetch" and compiler.dialect.full_returning ): new_stmt = new_stmt.returning(*mapper.primary_key) @@ -2287,8 +2286,6 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState): ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper - top_level_stmt = compiler.statement - self.extra_criteria_entities = {} extra_criteria_attributes = {} @@ -2305,7 +2302,7 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState): if ( mapper - and top_level_stmt._annotations.get("synchronize_session", None) + and compiler._annotations.get("synchronize_session", None) == "fetch" and compiler.dialect.full_returning ): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8718e15ea..10499975c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -377,12 +377,14 @@ class Compiled(object): schema_translate_map = None - execution_options = util.immutabledict() + execution_options = util.EMPTY_DICT """ Execution options propagated from the statement. In some cases, sub-elements of the statement can modify these. """ + _annotations = util.EMPTY_DICT + compile_state = None """Optional :class:`.CompileState` object that maintains additional state used by the compiler. @@ -474,6 +476,7 @@ class Compiled(object): if statement is not None: self.statement = statement self.can_execute = statement.supports_execution + self._annotations = statement._annotations if self.can_execute: self.execution_options = statement._execution_options self.string = self.process(self.statement, **compile_kwargs) @@ -799,6 +802,44 @@ class SQLCompiler(Compiled): self._process_parameters_for_postcompile(_populate_self=True) @property + def current_executable(self): + """Return the current 'executable' that is being compiled. + + This is currently the :class:`_sql.Select`, :class:`_sql.Insert`, + :class:`_sql.Update`, :class:`_sql.Delete`, + :class:`_sql.CompoundSelect` object that is being compiled. + Specifically it's assigned to the ``self.stack`` list of elements. + + When a statement like the above is being compiled, it normally + is also assigned to the ``.statement`` attribute of the + :class:`_sql.Compiler` object. However, all SQL constructs are + ultimately nestable, and this attribute should never be consulted + by a ``visit_`` method, as it is not guaranteed to be assigned + nor guaranteed to correspond to the current statement being compiled. + + .. versionadded:: 1.3.21 + + For compatibility with previous versions, use the following + recipe:: + + statement = getattr(self, "current_executable", False) + if statement is False: + statement = self.stack[-1]["selectable"] + + For versions 1.4 and above, ensure only .current_executable + is used; the format of "self.stack" may change. + + + """ + try: + return self.stack[-1]["selectable"] + except IndexError as ie: + util.raise_( + IndexError("Compiler does not have a stack entry"), + replace_context=ie, + ) + + @property def prefetch(self): return list(self.insert_prefetch + self.update_prefetch) diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index af168cd85..17a0acf20 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -20,6 +20,7 @@ from .exclusions import db_spec from .util import fail from .. import exc as sa_exc from .. import schema +from .. import sql from .. import types as sqltypes from .. import util from ..engine import default @@ -441,7 +442,61 @@ class AssertsCompiledSQL(object): if compile_kwargs: kw["compile_kwargs"] = compile_kwargs - c = clause.compile(dialect=dialect, **kw) + class DontAccess(object): + def __getattribute__(self, key): + raise NotImplementedError( + "compiler accessed .statement; use " + "compiler.current_executable" + ) + + class CheckCompilerAccess(object): + def __init__(self, test_statement): + self.test_statement = test_statement + self._annotations = {} + self.supports_execution = getattr( + test_statement, "supports_execution", False + ) + if self.supports_execution: + self._execution_options = test_statement._execution_options + + if isinstance( + test_statement, (sql.Insert, sql.Update, sql.Delete) + ): + self._returning = test_statement._returning + if isinstance(test_statement, (sql.Insert, sql.Update)): + self._inline = test_statement._inline + self._return_defaults = test_statement._return_defaults + + def _default_dialect(self): + return self.test_statement._default_dialect() + + def compile(self, dialect, **kw): + return self.test_statement.compile.__func__( + self, dialect=dialect, **kw + ) + + def _compiler(self, dialect, **kw): + return self.test_statement._compiler.__func__( + self, dialect, **kw + ) + + def _compiler_dispatch(self, compiler, **kwargs): + if hasattr(compiler, "statement"): + with mock.patch.object( + compiler, "statement", DontAccess() + ): + return self.test_statement._compiler_dispatch( + compiler, **kwargs + ) + else: + return self.test_statement._compiler_dispatch( + compiler, **kwargs + ) + + # no construct can assume it's the "top level" construct in all cases + # as anything can be nested. ensure constructs don't assume they + # are the "self.statement" element + c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw) param_str = repr(getattr(c, "params", {})) if util.py3k: diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index ef2f75b2d..c9e1d9ab4 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -77,6 +77,7 @@ from sqlalchemy.sql import operators from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.sql.elements import BooleanClauseList +from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.sql.expression import ClauseList from sqlalchemy.sql.expression import HasPrefixes from sqlalchemy.testing import assert_raises @@ -158,6 +159,53 @@ keyed = Table( ) +class TestCompilerFixture(fixtures.TestBase, AssertsCompiledSQL): + def test_dont_access_statement(self): + def visit_foobar(self, element, **kw): + self.statement.table + + class Foobar(ClauseElement): + __visit_name__ = "foobar" + + with mock.patch.object( + testing.db.dialect.statement_compiler, + "visit_foobar", + visit_foobar, + create=True, + ): + assert_raises_message( + NotImplementedError, + "compiler accessed .statement; use " + "compiler.current_executable", + self.assert_compile, + Foobar(), + "", + ) + + def test_no_stack(self): + def visit_foobar(self, element, **kw): + self.current_executable.table + + class Foobar(ClauseElement): + __visit_name__ = "foobar" + + with mock.patch.object( + testing.db.dialect.statement_compiler, + "visit_foobar", + visit_foobar, + create=True, + ): + compiler = testing.db.dialect.statement_compiler( + testing.db.dialect, None + ) + assert_raises_message( + IndexError, + "Compiler does not have a stack entry", + compiler.process, + Foobar(), + ) + + class SelectTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" |